mirror of
https://github.com/langgenius/dify.git
synced 2026-02-07 08:33:55 +00:00
Compare commits
9 Commits
1.12.1-ote
...
deploy/dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02206cc64b | ||
|
|
cba157038f | ||
|
|
d606de26f1 | ||
|
|
49f46a05e7 | ||
|
|
0d74ac634b | ||
|
|
30b73f2765 | ||
|
|
468990cc39 | ||
|
|
64e769f96e | ||
|
|
778aabb485 |
@@ -106,10 +106,10 @@ ignore = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
"PT019", # @patch-injected params look like unused fixtures
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
|
||||
@@ -122,8 +122,7 @@ These commands assume you start from the repository root.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
# Note: enterprise_telemetry queue is only used in Enterprise Edition
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
@@ -81,7 +81,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_enterprise_telemetry,
|
||||
ext_fastopenapi,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
@@ -132,7 +131,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_fastopenapi,
|
||||
ext_otel,
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
from .extra import ExtraServiceConfig
|
||||
from .feature import FeatureConfig
|
||||
from .middleware import MiddlewareConfig
|
||||
@@ -73,8 +73,6 @@ class DifyConfig(
|
||||
# Enterprise feature configs
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
# Enterprise telemetry configs
|
||||
EnterpriseTelemetryConfig,
|
||||
):
|
||||
model_config = SettingsConfigDict(
|
||||
# read from dotenv format config file
|
||||
|
||||
@@ -18,44 +18,3 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTelemetryConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for enterprise telemetry.
|
||||
"""
|
||||
|
||||
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
|
||||
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_ENDPOINT: str = Field(
|
||||
description="Enterprise OTEL collector endpoint.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_HEADERS: str = Field(
|
||||
description="Auth headers for OTLP export (key=value,key2=value2).",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_PROTOCOL: str = Field(
|
||||
description="OTLP protocol: 'http' or 'grpc' (default: http).",
|
||||
default="http",
|
||||
)
|
||||
|
||||
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
|
||||
description="Include input/output content in traces (privacy toggle).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
ENTERPRISE_SERVICE_NAME: str = Field(
|
||||
description="Service name for OTEL resource.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
|
||||
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -12,10 +11,12 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
@@ -26,32 +27,14 @@ from services.workflow_service import WorkflowService
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
@@ -67,6 +50,7 @@ reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
reg(ModelConfig)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
@@ -82,17 +66,10 @@ class RuleGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
@@ -118,16 +95,12 @@ class RuleCodeGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
args=args,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -154,15 +127,12 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
args=args,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -189,14 +159,14 @@ class InstructionGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
app_id = args.app_id or args.flow_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
@@ -213,33 +183,33 @@ class InstructionGenerateApi(Resource):
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
args=RuleGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
),
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
args=RuleGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
),
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
args=RuleCodeGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
),
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args.node_id == "" and args.current != "":
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
@@ -247,10 +217,8 @@ class InstructionGenerateApi(Resource):
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
if args.node_id != "" and args.current != "":
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
@@ -260,8 +228,6 @@ class InstructionGenerateApi(Resource):
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
@@ -78,10 +77,7 @@ class TraceAppConfigApi(Resource):
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id,
|
||||
tracing_provider=args.tracing_provider,
|
||||
tracing_config=args.tracing_config,
|
||||
account_id=current_user.id,
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@@ -106,10 +102,7 @@ class TraceAppConfigApi(Resource):
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id,
|
||||
tracing_provider=args.tracing_provider,
|
||||
tracing_config=args.tracing_config,
|
||||
account_id=current_user.id,
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@@ -131,9 +124,7 @@ class TraceAppConfigApi(Resource):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id
|
||||
)
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str
|
||||
files: list | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
class TextToSpeechRequest(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str = ""
|
||||
files: list | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
@@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
@@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
@@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = ChatRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
# Validate UUID values if provided
|
||||
if args.get("conversation_id"):
|
||||
args["conversation_id"] = uuid_value(args["conversation_id"])
|
||||
if args.get("parent_message_id"):
|
||||
args["parent_message_id"] = uuid_value(args["parent_message_id"])
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = CompletionRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner):
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id)
|
||||
self.agent_callback = DifyAgentCallbackHandler()
|
||||
# init dataset tools
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=queue_manager,
|
||||
|
||||
@@ -63,8 +63,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
@@ -566,6 +564,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle stop events."""
|
||||
_ = trace_manager
|
||||
resolved_state = None
|
||||
if self._workflow_run_id:
|
||||
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
|
||||
@@ -580,7 +579,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
@@ -589,7 +589,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
):
|
||||
# When hitting input-moderation or annotation-reply, the workflow will not start
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
# Save message
|
||||
self._save_message(session=session)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -598,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
event: QueueAdvancedChatMessageEndEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
@@ -616,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -770,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if self._conversation_name_generate_thread:
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
):
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
message = self._get_message(session=session)
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
@@ -832,22 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
if trace_manager:
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"conversation_id": str(message.conversation_id),
|
||||
"message_id": str(message.id),
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@@ -147,12 +147,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras: dict[str, Any] = {
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
parent_trace_context = args.get("_parent_trace_context")
|
||||
if parent_trace_context:
|
||||
extras["parent_trace_context"] = parent_trace_context
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
|
||||
@@ -52,11 +52,10 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -410,19 +409,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"conversation_id": self._conversation_id,
|
||||
"message_id": self._message_id,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
)
|
||||
|
||||
message_was_created.send(
|
||||
|
||||
@@ -15,7 +15,8 @@ from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
@@ -372,7 +373,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
self._enqueue_node_trace_task(domain_execution)
|
||||
|
||||
def _fail_running_node_executions(self, *, error_message: str) -> None:
|
||||
now = naive_utc_now()
|
||||
@@ -390,131 +390,17 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"workflow_execution": execution,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": self._trace_manager.user_id,
|
||||
"external_trace_id": external_trace_id,
|
||||
"parent_trace_context": parent_trace_context,
|
||||
},
|
||||
),
|
||||
trace_manager=self._trace_manager,
|
||||
)
|
||||
|
||||
def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None:
|
||||
if not self._trace_manager:
|
||||
return
|
||||
|
||||
execution = self._get_workflow_execution()
|
||||
meta = domain_execution.metadata or {}
|
||||
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
|
||||
|
||||
node_data: dict[str, Any] = {
|
||||
"workflow_id": domain_execution.workflow_id,
|
||||
"workflow_execution_id": execution.id_,
|
||||
"tenant_id": self._application_generate_entity.app_config.tenant_id,
|
||||
"app_id": self._application_generate_entity.app_config.app_id,
|
||||
"node_execution_id": domain_execution.id,
|
||||
"node_id": domain_execution.node_id,
|
||||
"node_type": str(domain_execution.node_type.value),
|
||||
"title": domain_execution.title,
|
||||
"status": str(domain_execution.status.value),
|
||||
"error": domain_execution.error,
|
||||
"elapsed_time": domain_execution.elapsed_time,
|
||||
"index": domain_execution.index,
|
||||
"predecessor_node_id": domain_execution.predecessor_node_id,
|
||||
"created_at": domain_execution.created_at,
|
||||
"finished_at": domain_execution.finished_at,
|
||||
"total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
|
||||
"prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS),
|
||||
"completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS),
|
||||
"total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
|
||||
"currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
|
||||
"tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
|
||||
if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
|
||||
else None,
|
||||
"iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
|
||||
"iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
|
||||
"loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
|
||||
"loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
|
||||
"parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
|
||||
"node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None,
|
||||
"node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None,
|
||||
"process_data": dict(domain_execution.process_data) if domain_execution.process_data else None,
|
||||
}
|
||||
node_data["invoke_from"] = self._application_generate_entity.invoke_from.value
|
||||
node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value)
|
||||
|
||||
if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs:
|
||||
results = domain_execution.outputs.get("result") or []
|
||||
dataset_ids: list[str] = []
|
||||
dataset_names: list[str] = []
|
||||
for doc in results:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
dname = doc_meta.get("dataset_name")
|
||||
if did and did not in dataset_ids:
|
||||
dataset_ids.append(did)
|
||||
if dname and dname not in dataset_names:
|
||||
dataset_names.append(dname)
|
||||
if dataset_ids:
|
||||
node_data["dataset_ids"] = dataset_ids
|
||||
if dataset_names:
|
||||
node_data["dataset_names"] = dataset_names
|
||||
|
||||
tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO)
|
||||
if isinstance(tool_info, dict):
|
||||
plugin_id = tool_info.get("plugin_unique_identifier")
|
||||
if plugin_id:
|
||||
node_data["plugin_name"] = plugin_id
|
||||
credential_id = tool_info.get("credential_id")
|
||||
if credential_id:
|
||||
node_data["credential_id"] = credential_id
|
||||
node_data["credential_provider_type"] = tool_info.get("provider_type")
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
if conversation_id:
|
||||
node_data["conversation_id"] = conversation_id
|
||||
|
||||
if parent_trace_context:
|
||||
node_data["parent_trace_context"] = parent_trace_context
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=node_data.get("tenant_id"),
|
||||
user_id=node_data.get("user_id"),
|
||||
app_id=node_data.get("app_id"),
|
||||
),
|
||||
payload={"node_execution_data": node_data},
|
||||
),
|
||||
trace_manager=self._trace_manager,
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
|
||||
@@ -4,9 +4,8 @@ from typing import Any, TextIO, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
@@ -37,15 +36,13 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
|
||||
color: str | None = ""
|
||||
current_loop: int = 1
|
||||
tenant_id: str | None = None
|
||||
|
||||
def __init__(self, color: str | None = None, tenant_id: str | None = None):
|
||||
def __init__(self, color: str | None = None):
|
||||
super().__init__()
|
||||
"""Initialize callback handler."""
|
||||
# use a specific color is not specified
|
||||
self.color = color or "green"
|
||||
self.current_loop = 1
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
@@ -74,23 +71,15 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
print_text("\n")
|
||||
|
||||
if trace_manager:
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.TOOL_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=trace_manager.app_id,
|
||||
user_id=trace_manager.user_id,
|
||||
),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_inputs": tool_inputs,
|
||||
"tool_outputs": tool_outputs,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.TOOL_TRACE,
|
||||
message_id=message_id,
|
||||
tool_name=tool_name,
|
||||
tool_inputs=tool_inputs,
|
||||
tool_outputs=tool_outputs,
|
||||
timer=timer,
|
||||
)
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Protocol, cast
|
||||
|
||||
import json_repair
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
@@ -25,11 +27,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.entities.trace_entity import OperationType
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@@ -72,8 +73,8 @@ class LLMGenerator:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
if answer is None:
|
||||
answer = response.message.get_text_content()
|
||||
if answer == "":
|
||||
return ""
|
||||
try:
|
||||
result_dict = json.loads(answer)
|
||||
@@ -95,17 +96,15 @@ class LLMGenerator:
|
||||
name = name[:75] + "..."
|
||||
|
||||
# get tracing instance
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
|
||||
payload={
|
||||
"conversation_id": conversation_id,
|
||||
"generate_conversation_name": name,
|
||||
"inputs": prompt,
|
||||
"timer": timer,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
trace_manager = TraceQueueManager(app_id=app_id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.GENERATE_NAME_TRACE,
|
||||
conversation_id=conversation_id,
|
||||
generate_conversation_name=name,
|
||||
inputs=prompt,
|
||||
timer=timer,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -154,27 +153,19 @@ class LLMGenerator:
|
||||
return questions
|
||||
|
||||
@classmethod
|
||||
def generate_rule_config(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
no_variable: bool,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
|
||||
output_parser = RuleConfigGeneratorOutputParser()
|
||||
|
||||
error = ""
|
||||
error_step = ""
|
||||
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
if no_variable:
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
if args.no_variable:
|
||||
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
|
||||
|
||||
prompt_generate = prompt_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -186,45 +177,26 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
)
|
||||
|
||||
llm_result = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
rule_config["prompt"] = cast(str, llm_result.message.content)
|
||||
rule_config["prompt"] = response.message.get_text_content()
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
prompt_value = rule_config.get("prompt", "")
|
||||
generated_output = str(prompt_value) if prompt_value else ""
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=generated_output,
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
# get rule config prompt, parameter and statement
|
||||
@@ -239,7 +211,7 @@ class LLMGenerator:
|
||||
# format the prompt_generate_prompt
|
||||
prompt_generate_prompt = prompt_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -250,125 +222,84 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
)
|
||||
|
||||
llm_result = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
llm_result = prompt_content
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate prefix prompt"
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output="",
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
rule_config["prompt"] = cast(str, prompt_content.message.content)
|
||||
|
||||
if not isinstance(prompt_content.message.content, str):
|
||||
raise NotImplementedError("prompt content is not a string")
|
||||
parameter_generate_prompt = parameter_template.format(
|
||||
inputs={
|
||||
"INPUT_TEXT": prompt_content.message.content,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
"INPUT_TEXT": prompt_content.message.content,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["variables"] = re.findall(
|
||||
r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
|
||||
try:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = cast(str, statement_content.message.content)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
rule_config["error"] = str(e)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate prefix prompt"
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
return rule_config
|
||||
|
||||
rule_config["prompt"] = prompt_content.message.get_text_content()
|
||||
|
||||
parameter_generate_prompt = parameter_template.format(
|
||||
inputs={
|
||||
"INPUT_TEXT": prompt_content.message.get_text_content(),
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"INPUT_TEXT": prompt_content.message.get_text_content(),
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
|
||||
try:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content()
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
generated_output = rule_config.get("prompt", "")
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=str(generated_output) if generated_output else "",
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
def generate_code(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
code_language: str = "javascript",
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
args: RuleCodeGeneratePayload,
|
||||
):
|
||||
if code_language == "python":
|
||||
if args.code_language == "python":
|
||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
else:
|
||||
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
"INSTRUCTION": instruction,
|
||||
"CODE_LANGUAGE": code_language,
|
||||
"INSTRUCTION": args.instruction,
|
||||
"CODE_LANGUAGE": args.code_language,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -377,49 +308,28 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
|
||||
llm_result = None
|
||||
error = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_code = cast(str, llm_result.message.content)
|
||||
result = {"code": generated_code, "language": code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
|
||||
)
|
||||
error = str(e)
|
||||
result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.CODE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=result.get("code", ""),
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
return result
|
||||
generated_code = response.message.get_text_content()
|
||||
return {"code": generated_code, "language": args.code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
|
||||
)
|
||||
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
@@ -445,76 +355,49 @@ class LLMGenerator:
|
||||
raise TypeError("Expected LLMResult when stream=False")
|
||||
response = result
|
||||
|
||||
answer = cast(str, response.message.content)
|
||||
answer = response.message.get_text_content()
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_structured_output(
|
||||
cls, tenant_id: str, instruction: str, model_config: dict, user_id: str | None = None, app_id: str | None = None
|
||||
):
|
||||
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
)
|
||||
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
|
||||
UserPromptMessage(content=instruction),
|
||||
UserPromptMessage(content=args.instruction),
|
||||
]
|
||||
model_parameters = model_config.get("model_parameters", {})
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
|
||||
llm_result = None
|
||||
error = None
|
||||
result = {"output": "", "error": ""}
|
||||
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
raw_content = llm_result.message.content
|
||||
|
||||
if not isinstance(raw_content, str):
|
||||
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
|
||||
|
||||
try:
|
||||
parsed_content = json.loads(raw_content)
|
||||
except json.JSONDecodeError:
|
||||
parsed_content = json_repair.loads(raw_content)
|
||||
|
||||
if not isinstance(parsed_content, dict | list):
|
||||
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
|
||||
|
||||
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
|
||||
result = {"output": generated_json_schema, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
error = str(e)
|
||||
result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.STRUCTURED_OUTPUT,
|
||||
instruction=instruction,
|
||||
generated_output=result.get("output", ""),
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
return result
|
||||
raw_content = response.message.get_text_content()
|
||||
|
||||
try:
|
||||
parsed_content = json.loads(raw_content)
|
||||
except json.JSONDecodeError:
|
||||
parsed_content = json_repair.loads(raw_content)
|
||||
|
||||
if not isinstance(parsed_content, dict | list):
|
||||
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
|
||||
|
||||
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
|
||||
return {"output": generated_json_schema, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
@@ -522,16 +405,14 @@ class LLMGenerator:
|
||||
flow_id: str,
|
||||
current: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
)
|
||||
if not last_run:
|
||||
result = LLMGenerator.__instruction_modify_common(
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=None,
|
||||
@@ -540,28 +421,22 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
else:
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
result = LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
return result
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_workflow(
|
||||
@@ -570,11 +445,9 @@ class LLMGenerator:
|
||||
node_id: str,
|
||||
current: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
workflow_service: WorkflowServiceInterface,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
session = db.session()
|
||||
|
||||
@@ -605,8 +478,6 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type=node_type,
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
@@ -640,22 +511,18 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type=last_run.node_type,
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __instruction_modify_common(
|
||||
tenant_id: str,
|
||||
model_config: dict,
|
||||
model_config: ModelConfig,
|
||||
last_run: dict | None,
|
||||
current: str | None,
|
||||
error_message: str | None,
|
||||
instruction: str,
|
||||
node_type: str,
|
||||
ideal_output: str | None,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
LAST_RUN = "{{#last_run#}}"
|
||||
CURRENT = "{{#current#}}"
|
||||
@@ -670,8 +537,8 @@ class LLMGenerator:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
match node_type:
|
||||
case "llm" | "agent":
|
||||
@@ -695,122 +562,24 @@ class LLMGenerator:
|
||||
]
|
||||
model_parameters = {"temperature": 0.4}
|
||||
|
||||
llm_result = None
|
||||
error = None
|
||||
result = {}
|
||||
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_raw = llm_result.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
result = data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
|
||||
)
|
||||
error = str(e)
|
||||
result = {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
generated_output = ""
|
||||
if isinstance(result, dict):
|
||||
for key in ["prompt", "code", "output", "modified"]:
|
||||
if result.get(key):
|
||||
generated_output = str(result[key])
|
||||
break
|
||||
|
||||
LLMGenerator._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.INSTRUCTION_MODIFY,
|
||||
instruction=instruction,
|
||||
generated_output=generated_output,
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _emit_prompt_generation_trace(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
app_id: str | None,
|
||||
operation_type: OperationType,
|
||||
instruction: str,
|
||||
generated_output: str,
|
||||
llm_result: LLMResult | None,
|
||||
model_config: dict | None = None,
|
||||
timer=None,
|
||||
error: str | None = None,
|
||||
):
|
||||
if llm_result:
|
||||
prompt_tokens = llm_result.usage.prompt_tokens
|
||||
completion_tokens = llm_result.usage.completion_tokens
|
||||
total_tokens = llm_result.usage.total_tokens
|
||||
model_name = llm_result.model
|
||||
# Extract provider from model_config if available, otherwise fall back to parsing model name
|
||||
if model_config and model_config.get("provider"):
|
||||
model_provider = model_config.get("provider", "")
|
||||
else:
|
||||
model_provider = model_name.split("/")[0] if "/" in model_name else ""
|
||||
latency = llm_result.usage.latency
|
||||
total_price = float(llm_result.usage.total_price) if llm_result.usage.total_price else None
|
||||
currency = llm_result.usage.currency
|
||||
else:
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_tokens = 0
|
||||
model_provider = model_config.get("provider", "") if model_config else ""
|
||||
model_name = model_config.get("name", "") if model_config else ""
|
||||
latency = 0.0
|
||||
if timer:
|
||||
start_time = timer.get("start")
|
||||
end_time = timer.get("end")
|
||||
if start_time and end_time:
|
||||
latency = (end_time - start_time).total_seconds()
|
||||
total_price = None
|
||||
currency = None
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.PROMPT_GENERATION_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id),
|
||||
payload={
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"app_id": app_id,
|
||||
"operation_type": operation_type,
|
||||
"instruction": instruction,
|
||||
"generated_output": generated_output,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
"latency": latency,
|
||||
"total_price": total_price,
|
||||
"currency": currency,
|
||||
"timer": timer,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
)
|
||||
generated_raw = response.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
return data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@@ -15,23 +15,16 @@ class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event)
|
||||
existing_trace_id = getattr(record, "trace_id", "")
|
||||
if not existing_trace_id:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
# Keep existing trace_id; only fill span_id if missing
|
||||
if not getattr(record, "span_id", ""):
|
||||
record.span_id = ""
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = get_request_id()
|
||||
@@ -62,12 +55,9 @@ class IdentityContextFilter(logging.Filter):
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
if not getattr(record, "tenant_id", ""):
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
if not getattr(record, "user_id", ""):
|
||||
record.user_id = identity.get("user_id", "")
|
||||
if not getattr(record, "user_type", ""):
|
||||
record.user_type = identity.get("user_type", "")
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
record.user_id = identity.get("user_id", "")
|
||||
record.user_type = identity.get("user_type", "")
|
||||
return True
|
||||
|
||||
def _extract_identity(self) -> dict[str, str]:
|
||||
|
||||
@@ -5,10 +5,9 @@ from typing import Any
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,18 +49,14 @@ class InputModeration:
|
||||
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
|
||||
|
||||
if trace_manager:
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"moderation_result": moderation_result,
|
||||
"inputs": inputs,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MODERATION_TRACE,
|
||||
message_id=message_id,
|
||||
moderation_result=moderation_result,
|
||||
inputs=inputs,
|
||||
timer=timer,
|
||||
)
|
||||
)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
|
||||
@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
class BaseTraceInfo(BaseModel):
|
||||
message_id: str | None = None
|
||||
message_data: Any | None = None
|
||||
inputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
outputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
inputs: Union[str, dict[str, Any], list] | None = None
|
||||
outputs: Union[str, dict[str, Any], list] | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
metadata: dict[str, Any]
|
||||
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
|
||||
def ensure_type(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str | dict | list):
|
||||
@@ -48,14 +48,10 @@ class WorkflowTraceInfo(BaseTraceInfo):
|
||||
workflow_run_version: str
|
||||
error: str | None = None
|
||||
total_tokens: int
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
file_list: list[str]
|
||||
query: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
invoked_by: str | None = None
|
||||
|
||||
|
||||
class MessageTraceInfo(BaseTraceInfo):
|
||||
conversation_model: str
|
||||
@@ -63,7 +59,7 @@ class MessageTraceInfo(BaseTraceInfo):
|
||||
answer_tokens: int
|
||||
total_tokens: int
|
||||
error: str | None = None
|
||||
file_list: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
file_list: Union[str, dict[str, Any], list] | None = None
|
||||
message_file_data: Any | None = None
|
||||
conversation_mode: str
|
||||
gen_ai_server_time_to_first_token: float | None = None
|
||||
@@ -110,7 +106,7 @@ class ToolTraceInfo(BaseTraceInfo):
|
||||
tool_config: dict[str, Any]
|
||||
time_cost: Union[int, float]
|
||||
tool_parameters: dict[str, Any]
|
||||
file_url: Union[str, None, list[str]] = None
|
||||
file_url: Union[str, None, list] = None
|
||||
|
||||
|
||||
class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
@@ -118,79 +114,6 @@ class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class PromptGenerationTraceInfo(BaseTraceInfo):
|
||||
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
app_id: str | None = None
|
||||
|
||||
operation_type: str
|
||||
instruction: str
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
model_provider: str
|
||||
model_name: str
|
||||
|
||||
latency: float
|
||||
|
||||
total_price: float | None = None
|
||||
currency: str | None = None
|
||||
|
||||
error: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class WorkflowNodeTraceInfo(BaseTraceInfo):
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
tenant_id: str
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
|
||||
status: str
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
|
||||
index: int
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
total_tokens: int = 0
|
||||
total_price: float = 0.0
|
||||
currency: str | None = None
|
||||
|
||||
model_provider: str | None = None
|
||||
model_name: str | None = None
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
|
||||
tool_name: str | None = None
|
||||
|
||||
iteration_id: str | None = None
|
||||
iteration_index: int | None = None
|
||||
loop_id: str | None = None
|
||||
loop_index: int | None = None
|
||||
parallel_id: str | None = None
|
||||
|
||||
node_inputs: Mapping[str, Any] | None = None
|
||||
node_outputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
|
||||
invoked_by: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
|
||||
pass
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
app_id: str
|
||||
trace_info_type: str
|
||||
@@ -205,38 +128,16 @@ trace_info_info_map = {
|
||||
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
|
||||
"ToolTraceInfo": ToolTraceInfo,
|
||||
"GenerateNameTraceInfo": GenerateNameTraceInfo,
|
||||
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
|
||||
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
|
||||
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
|
||||
}
|
||||
|
||||
|
||||
class OperationType(StrEnum):
|
||||
"""Operation type for token metric labels.
|
||||
|
||||
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
|
||||
counters so consumers can break down token usage by operation.
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
MESSAGE = "message"
|
||||
RULE_GENERATE = "rule_generate"
|
||||
CODE_GENERATE = "code_generate"
|
||||
STRUCTURED_OUTPUT = "structured_output"
|
||||
INSTRUCTION_MODIFY = "instruction_modify"
|
||||
|
||||
|
||||
class TraceTaskName(StrEnum):
|
||||
CONVERSATION_TRACE = "conversation"
|
||||
WORKFLOW_TRACE = "workflow"
|
||||
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
|
||||
MESSAGE_TRACE = "message"
|
||||
MODERATION_TRACE = "moderation"
|
||||
SUGGESTED_QUESTION_TRACE = "suggested_question"
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
PROMPT_GENERATION_TRACE = "prompt_generation"
|
||||
DATASOURCE_TRACE = "datasource"
|
||||
NODE_EXECUTION_TRACE = "node_execution"
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from langfuse import Langfuse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -31,7 +30,7 @@ from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -72,50 +71,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
# Check for parent_trace_context to detect nested workflow
|
||||
parent_trace_context = trace_info.metadata.get("parent_trace_context")
|
||||
|
||||
if parent_trace_context:
|
||||
# Nested workflow: create span under outer trace
|
||||
outer_trace_id = parent_trace_context.get("trace_id")
|
||||
parent_node_execution_id = parent_trace_context.get("parent_node_execution_id")
|
||||
parent_conversation_id = parent_trace_context.get("parent_conversation_id")
|
||||
parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id")
|
||||
|
||||
# Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id
|
||||
if parent_conversation_id:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
with session_factory() as session:
|
||||
message_data_stmt = select(Message.id).where(
|
||||
Message.conversation_id == parent_conversation_id,
|
||||
Message.workflow_run_id == parent_workflow_run_id,
|
||||
)
|
||||
resolved_message_id = session.scalar(message_data_stmt)
|
||||
if resolved_message_id:
|
||||
outer_trace_id = resolved_message_id
|
||||
else:
|
||||
outer_trace_id = parent_workflow_run_id
|
||||
else:
|
||||
outer_trace_id = parent_workflow_run_id
|
||||
|
||||
# Create inner workflow span under outer trace
|
||||
workflow_span_data = LangfuseSpan(
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
trace_id=outer_trace_id,
|
||||
parent_observation_id=parent_node_execution_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=metadata,
|
||||
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
|
||||
status_message=trace_info.error or "",
|
||||
)
|
||||
self.add_span(langfuse_span_data=workflow_span_data)
|
||||
# Use outer_trace_id for all node spans/generations
|
||||
trace_id = outer_trace_id
|
||||
elif trace_info.message_id:
|
||||
if trace_info.message_id:
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_data = LangfuseTrace(
|
||||
@@ -218,11 +174,6 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
# Determine parent_observation_id for nested workflows
|
||||
node_parent_observation_id = None
|
||||
if parent_trace_context or trace_info.message_id:
|
||||
node_parent_observation_id = trace_info.workflow_run_id
|
||||
|
||||
# add generation span
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
@@ -255,7 +206,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=node_parent_observation_id,
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
@@ -274,7 +225,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=node_parent_observation_id,
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -31,7 +30,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,35 +64,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
# Check for parent_trace_context for cross-workflow linking
|
||||
parent_trace_context = trace_info.metadata.get("parent_trace_context")
|
||||
|
||||
if parent_trace_context:
|
||||
# Inner workflow: resolve outer trace_id and link to parent node
|
||||
outer_trace_id = parent_trace_context.get("parent_workflow_run_id")
|
||||
|
||||
# Try to resolve message_id from conversation_id if available
|
||||
if parent_trace_context.get("parent_conversation_id"):
|
||||
try:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
with session_factory() as session:
|
||||
message_data_stmt = select(Message.id).where(
|
||||
Message.conversation_id == parent_trace_context["parent_conversation_id"],
|
||||
Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"],
|
||||
)
|
||||
resolved_message_id = session.scalar(message_data_stmt)
|
||||
if resolved_message_id:
|
||||
outer_trace_id = resolved_message_id
|
||||
except Exception as e:
|
||||
logger.debug("Failed to resolve message_id from conversation_id: %s", str(e))
|
||||
|
||||
trace_id = outer_trace_id
|
||||
parent_run_id = parent_trace_context.get("parent_node_execution_id")
|
||||
else:
|
||||
# Outer workflow: existing behavior
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
parent_run_id = trace_info.message_id or None
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
message_dotted_order = (
|
||||
@@ -107,8 +78,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
# Only create message_run for outer workflows (no parent_trace_context)
|
||||
if trace_info.message_id and not parent_trace_context:
|
||||
if trace_info.message_id:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
@@ -151,9 +121,9 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
},
|
||||
error=trace_info.error,
|
||||
tags=["workflow"],
|
||||
parent_run_id=parent_run_id,
|
||||
parent_run_id=trace_info.message_id or None,
|
||||
trace_id=trace_id,
|
||||
dotted_order=None if parent_trace_context else workflow_dotted_order,
|
||||
dotted_order=workflow_dotted_order,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
|
||||
@@ -21,25 +21,19 @@ from core.ops.entities.config_entity import (
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
TaskData,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from models.workflow import WorkflowAppLog
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
@@ -49,44 +43,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
|
||||
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
|
||||
app_name = ""
|
||||
workspace_name = ""
|
||||
if not app_id and not tenant_id:
|
||||
return app_name, workspace_name
|
||||
with Session(db.engine) as session:
|
||||
if app_id:
|
||||
name = session.scalar(select(App.name).where(App.id == app_id))
|
||||
if name:
|
||||
app_name = name
|
||||
if tenant_id:
|
||||
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
|
||||
if name:
|
||||
workspace_name = name
|
||||
return app_name, workspace_name
|
||||
|
||||
|
||||
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
|
||||
"builtin": BuiltinToolProvider,
|
||||
"plugin": BuiltinToolProvider,
|
||||
"api": ApiToolProvider,
|
||||
"workflow": WorkflowToolProvider,
|
||||
"mcp": MCPToolProvider,
|
||||
}
|
||||
|
||||
|
||||
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
|
||||
if not credential_id:
|
||||
return ""
|
||||
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
|
||||
if not model_cls:
|
||||
return ""
|
||||
with Session(db.engine) as session:
|
||||
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id))
|
||||
return str(name) if name else ""
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
@@ -361,10 +317,6 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
# Handle storage_id format (tenant-{uuid}) - not a real app_id
|
||||
if isinstance(app_id, str) and app_id.startswith("tenant-"):
|
||||
return None
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
@@ -527,56 +479,6 @@ class TraceTask:
|
||||
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return cls._workflow_run_repo
|
||||
|
||||
@classmethod
|
||||
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
|
||||
"""Extract user ID from metadata, prioritizing end_user over account.
|
||||
|
||||
Returns the actual user ID (end_user or account) who invoked the workflow,
|
||||
regardless of invoke_from context.
|
||||
"""
|
||||
# Priority 1: End user (external users via API/WebApp)
|
||||
if user_id := metadata.get("from_end_user_id"):
|
||||
return f"end_user:{user_id}"
|
||||
|
||||
# Priority 2: Account user (internal users via console/debugger)
|
||||
if user_id := metadata.get("from_account_id"):
|
||||
return f"account:{user_id}"
|
||||
|
||||
# Priority 3: User (internal users via console/debugger)
|
||||
if user_id := metadata.get("user_id"):
|
||||
return f"user:{user_id}"
|
||||
|
||||
return "anonymous"
|
||||
|
||||
@classmethod
|
||||
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
with Session(db.engine) as session:
|
||||
node_executions = session.scalars(
|
||||
select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
total_prompt = 0
|
||||
total_completion = 0
|
||||
|
||||
for node_exec in node_executions:
|
||||
metadata = node_exec.execution_metadata_dict
|
||||
|
||||
prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
|
||||
if prompt is not None:
|
||||
total_prompt += prompt
|
||||
|
||||
completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
|
||||
if completion is not None:
|
||||
total_completion += completion
|
||||
|
||||
return (total_prompt, total_completion)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_type: Any,
|
||||
@@ -597,8 +499,6 @@ class TraceTask:
|
||||
self.app_id = None
|
||||
self.trace_id = None
|
||||
self.kwargs = kwargs
|
||||
if user_id is not None and "user_id" not in self.kwargs:
|
||||
self.kwargs["user_id"] = user_id
|
||||
external_trace_id = kwargs.get("external_trace_id")
|
||||
if external_trace_id:
|
||||
self.trace_id = external_trace_id
|
||||
@@ -612,7 +512,7 @@ class TraceTask:
|
||||
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
|
||||
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
|
||||
),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
|
||||
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
|
||||
message_id=self.message_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
@@ -628,9 +528,6 @@ class TraceTask:
|
||||
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
|
||||
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
|
||||
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
|
||||
}
|
||||
|
||||
return preprocess_map.get(self.trace_type, lambda: None)()
|
||||
@@ -666,10 +563,6 @@ class TraceTask:
|
||||
|
||||
total_tokens = workflow_run.total_tokens
|
||||
|
||||
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
|
||||
workflow_run_id=workflow_run_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
file_list = workflow_run_inputs.get("sys.file") or []
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
@@ -690,9 +583,7 @@ class TraceTask:
|
||||
)
|
||||
message_id = session.scalar(message_data_stmt)
|
||||
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
metadata = {
|
||||
"workflow_id": workflow_id,
|
||||
"conversation_id": conversation_id,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
@@ -705,14 +596,8 @@ class TraceTask:
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
|
||||
parent_trace_context = self.kwargs.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
@@ -727,8 +612,6 @@ class TraceTask:
|
||||
workflow_run_version=workflow_run_version,
|
||||
error=error,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
file_list=file_list,
|
||||
query=query,
|
||||
metadata=metadata,
|
||||
@@ -736,11 +619,10 @@ class TraceTask:
|
||||
message_id=message_id,
|
||||
start_time=workflow_run.created_at,
|
||||
end_time=workflow_run.finished_at,
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
return workflow_trace_info
|
||||
|
||||
def message_trace(self, message_id: str | None, **kwargs):
|
||||
def message_trace(self, message_id: str | None):
|
||||
if not message_id:
|
||||
return {}
|
||||
message_data = get_message_data(message_id)
|
||||
@@ -763,14 +645,6 @@ class TraceTask:
|
||||
|
||||
streaming_metrics = self._extract_streaming_metrics(message_data)
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
|
||||
metadata = {
|
||||
"conversation_id": message_data.conversation_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -782,14 +656,7 @@ class TraceTask:
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"message_id": message_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
message_tokens = message_data.message_tokens
|
||||
|
||||
@@ -831,8 +698,6 @@ class TraceTask:
|
||||
"preset_response": moderation_result.preset_response,
|
||||
"query": moderation_result.query,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
@@ -874,8 +739,6 @@ class TraceTask:
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
@@ -915,36 +778,6 @@ class TraceTask:
|
||||
if not message_data:
|
||||
return {}
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
|
||||
doc_list = [doc.model_dump() for doc in documents] if documents else []
|
||||
dataset_ids: set[str] = set()
|
||||
for doc in doc_list:
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
if did:
|
||||
dataset_ids.add(did)
|
||||
|
||||
embedding_models: dict[str, dict[str, str]] = {}
|
||||
if dataset_ids:
|
||||
with Session(db.engine) as session:
|
||||
rows = session.execute(
|
||||
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
|
||||
Dataset.id.in_(list(dataset_ids))
|
||||
)
|
||||
).all()
|
||||
for row in rows:
|
||||
embedding_models[str(row[0])] = {
|
||||
"embedding_model": row[1] or "",
|
||||
"embedding_model_provider": row[2] or "",
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -955,21 +788,13 @@ class TraceTask:
|
||||
"agent_based": message_data.agent_based,
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"embedding_models": embedding_models,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
inputs=message_data.query or message_data.inputs,
|
||||
documents=doc_list,
|
||||
documents=[doc.model_dump() for doc in documents] if documents else [],
|
||||
start_time=timer.get("start"),
|
||||
end_time=timer.get("end"),
|
||||
metadata=metadata,
|
||||
@@ -1012,10 +837,6 @@ class TraceTask:
|
||||
"error": error,
|
||||
"tool_parameters": tool_parameters,
|
||||
}
|
||||
if message_data.workflow_run_id:
|
||||
metadata["workflow_run_id"] = message_data.workflow_run_id
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
file_url = ""
|
||||
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
||||
@@ -1070,8 +891,6 @@ class TraceTask:
|
||||
"conversation_id": conversation_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
generate_name_trace_info = GenerateNameTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
@@ -1086,158 +905,6 @@ class TraceTask:
|
||||
|
||||
return generate_name_trace_info
|
||||
|
||||
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
|
||||
tenant_id = kwargs.get("tenant_id", "")
|
||||
user_id = kwargs.get("user_id", "")
|
||||
app_id = kwargs.get("app_id")
|
||||
operation_type = kwargs.get("operation_type", "")
|
||||
instruction = kwargs.get("instruction", "")
|
||||
generated_output = kwargs.get("generated_output", "")
|
||||
|
||||
prompt_tokens = kwargs.get("prompt_tokens", 0)
|
||||
completion_tokens = kwargs.get("completion_tokens", 0)
|
||||
total_tokens = kwargs.get("total_tokens", 0)
|
||||
|
||||
model_provider = kwargs.get("model_provider", "")
|
||||
model_name = kwargs.get("model_name", "")
|
||||
|
||||
latency = kwargs.get("latency", 0.0)
|
||||
|
||||
timer = kwargs.get("timer")
|
||||
start_time = timer.get("start") if timer else None
|
||||
end_time = timer.get("end") if timer else None
|
||||
|
||||
total_price = kwargs.get("total_price")
|
||||
currency = kwargs.get("currency")
|
||||
|
||||
error = kwargs.get("error")
|
||||
|
||||
app_name = None
|
||||
workspace_name = None
|
||||
if app_id:
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
|
||||
|
||||
metadata = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"app_id": app_id or "",
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"operation_type": operation_type,
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
return PromptGenerationTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
inputs=instruction,
|
||||
outputs=generated_output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
metadata=metadata,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=operation_type,
|
||||
instruction=instruction,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
latency=latency,
|
||||
total_price=total_price,
|
||||
currency=currency,
|
||||
error=error,
|
||||
)
|
||||
|
||||
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
|
||||
node_data: dict = kwargs.get("node_execution_data", {})
|
||||
if not node_data:
|
||||
return {}
|
||||
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(node_data.get("app_id"), node_data.get("tenant_id"))
|
||||
|
||||
credential_name = _lookup_credential_name(
|
||||
node_data.get("credential_id"), node_data.get("credential_provider_type")
|
||||
)
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"tenant_id": node_data.get("tenant_id"),
|
||||
"app_id": node_data.get("app_id"),
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"user_id": node_data.get("user_id"),
|
||||
"dataset_ids": node_data.get("dataset_ids"),
|
||||
"dataset_names": node_data.get("dataset_names"),
|
||||
"plugin_name": node_data.get("plugin_name"),
|
||||
"credential_name": credential_name,
|
||||
}
|
||||
|
||||
parent_trace_context = node_data.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
message_id: str | None = None
|
||||
conversation_id = node_data.get("conversation_id")
|
||||
workflow_execution_id = node_data.get("workflow_execution_id")
|
||||
if conversation_id and workflow_execution_id and not parent_trace_context:
|
||||
with Session(db.engine) as session:
|
||||
msg_id = session.scalar(
|
||||
select(Message.id).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.workflow_run_id == workflow_execution_id,
|
||||
)
|
||||
)
|
||||
if msg_id:
|
||||
message_id = str(msg_id)
|
||||
metadata["message_id"] = message_id
|
||||
|
||||
return WorkflowNodeTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
start_time=node_data.get("created_at"),
|
||||
end_time=node_data.get("finished_at"),
|
||||
metadata=metadata,
|
||||
workflow_id=node_data.get("workflow_id", ""),
|
||||
workflow_run_id=node_data.get("workflow_execution_id", ""),
|
||||
tenant_id=node_data.get("tenant_id", ""),
|
||||
node_execution_id=node_data.get("node_execution_id", ""),
|
||||
node_id=node_data.get("node_id", ""),
|
||||
node_type=node_data.get("node_type", ""),
|
||||
title=node_data.get("title", ""),
|
||||
status=node_data.get("status", ""),
|
||||
error=node_data.get("error"),
|
||||
elapsed_time=node_data.get("elapsed_time", 0.0),
|
||||
index=node_data.get("index", 0),
|
||||
predecessor_node_id=node_data.get("predecessor_node_id"),
|
||||
total_tokens=node_data.get("total_tokens", 0),
|
||||
total_price=node_data.get("total_price", 0.0),
|
||||
currency=node_data.get("currency"),
|
||||
model_provider=node_data.get("model_provider"),
|
||||
model_name=node_data.get("model_name"),
|
||||
prompt_tokens=node_data.get("prompt_tokens"),
|
||||
completion_tokens=node_data.get("completion_tokens"),
|
||||
tool_name=node_data.get("tool_name"),
|
||||
iteration_id=node_data.get("iteration_id"),
|
||||
iteration_index=node_data.get("iteration_index"),
|
||||
loop_id=node_data.get("loop_id"),
|
||||
loop_index=node_data.get("loop_index"),
|
||||
parallel_id=node_data.get("parallel_id"),
|
||||
node_inputs=node_data.get("node_inputs"),
|
||||
node_outputs=node_data.get("node_outputs"),
|
||||
process_data=node_data.get("process_data"),
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
|
||||
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
|
||||
node_trace = self.node_execution_trace(**kwargs)
|
||||
if not node_trace or not isinstance(node_trace, WorkflowNodeTraceInfo):
|
||||
return node_trace
|
||||
return DraftNodeExecutionTrace(**node_trace.model_dump())
|
||||
|
||||
def _extract_streaming_metrics(self, message_data) -> dict:
|
||||
if not message_data.message_metadata:
|
||||
return {}
|
||||
@@ -1271,17 +938,13 @@ class TraceQueueManager:
|
||||
self.user_id = user_id
|
||||
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
||||
self.flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
from core.telemetry import is_enterprise_telemetry_enabled
|
||||
|
||||
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
|
||||
if trace_manager_timer is None:
|
||||
self.start_timer()
|
||||
|
||||
def add_trace_task(self, trace_task: TraceTask):
|
||||
global trace_manager_timer, trace_manager_queue
|
||||
try:
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
if self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception:
|
||||
@@ -1317,27 +980,20 @@ class TraceQueueManager:
|
||||
def send_to_celery(self, tasks: list[TraceTask]):
|
||||
with self.flask_app.app_context():
|
||||
for task in tasks:
|
||||
storage_id = task.app_id
|
||||
if storage_id is None:
|
||||
tenant_id = task.kwargs.get("tenant_id")
|
||||
if tenant_id:
|
||||
storage_id = f"tenant-{tenant_id}"
|
||||
else:
|
||||
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
|
||||
continue
|
||||
|
||||
if task.app_id is None:
|
||||
continue
|
||||
file_id = uuid4().hex
|
||||
trace_info = task.execute()
|
||||
|
||||
task_data = TaskData(
|
||||
app_id=storage_id,
|
||||
app_id=task.app_id,
|
||||
trace_info_type=type(trace_info).__name__,
|
||||
trace_info=trace_info.model_dump() if trace_info else None,
|
||||
)
|
||||
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
|
||||
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
|
||||
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
|
||||
file_info = {
|
||||
"file_id": file_id,
|
||||
"app_id": storage_id,
|
||||
"app_id": task.app_id,
|
||||
}
|
||||
process_trace_tasks.delay(file_info) # type: ignore
|
||||
|
||||
@@ -27,7 +27,8 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
@@ -55,8 +56,6 @@ from core.rag.retrieval.template_prompts import (
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
@@ -729,21 +728,10 @@ class DatasetRetrieval:
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
app_config = self.application_generate_entity.app_config if self.application_generate_entity else None
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=app_config.tenant_id if app_config else None,
|
||||
app_id=app_config.app_id if app_config else None,
|
||||
),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"documents": documents,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
def _on_query(
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
"""Community telemetry helpers.
|
||||
|
||||
Provides ``emit()`` which enqueues trace events into the CE trace pipeline
|
||||
(``TraceQueueManager`` → ``ops_trace`` Celery queue → Langfuse / LangSmith / etc.).
|
||||
|
||||
Enterprise-only traces (node execution, draft node execution, prompt generation)
|
||||
are silently dropped when enterprise telemetry is disabled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.events import TelemetryContext, TelemetryEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
_ENTERPRISE_ONLY_TRACES: frozenset[TraceTaskName] = frozenset(
|
||||
{
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_enterprise_telemetry_enabled() -> bool:
|
||||
try:
|
||||
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
|
||||
|
||||
return is_enterprise_telemetry_enabled()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
|
||||
from core.ops.ops_trace_manager import TraceTask
|
||||
|
||||
if event.name in _ENTERPRISE_ONLY_TRACES and not _is_enterprise_telemetry_enabled():
|
||||
return
|
||||
|
||||
queue_manager = trace_manager or LocalTraceQueueManager(
|
||||
app_id=event.context.app_id,
|
||||
user_id=event.context.user_id,
|
||||
)
|
||||
queue_manager.add_trace_task(TraceTask(event.name, **event.payload))
|
||||
|
||||
|
||||
is_enterprise_telemetry_enabled = _is_enterprise_telemetry_enabled
|
||||
|
||||
__all__ = [
|
||||
"TelemetryContext",
|
||||
"TelemetryEvent",
|
||||
"TraceTaskName",
|
||||
"emit",
|
||||
"is_enterprise_telemetry_enabled",
|
||||
]
|
||||
@@ -1,21 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelemetryContext:
|
||||
tenant_id: str | None = None
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelemetryEvent:
|
||||
name: TraceTaskName
|
||||
context: TelemetryContext
|
||||
payload: dict[str, Any]
|
||||
@@ -50,7 +50,6 @@ class WorkflowTool(Tool):
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.label = label
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
self.parent_trace_context: dict[str, str] | None = None
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
@@ -91,15 +90,11 @@ class WorkflowTool(Tool):
|
||||
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
|
||||
if self.parent_trace_context:
|
||||
args["_parent_trace_context"] = self.parent_trace_context
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
|
||||
@@ -232,8 +232,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROMPT_TOKENS = "prompt_tokens"
|
||||
COMPLETION_TOKENS = "completion_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
|
||||
@@ -322,8 +322,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
|
||||
@@ -61,7 +61,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
"provider_type": self.node_data.provider_type.value,
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
"credential_id": self.node_data.credential_id,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
@@ -106,20 +105,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
if isinstance(tool_runtime, WorkflowTool):
|
||||
workflow_run_id_var = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID]
|
||||
)
|
||||
tool_runtime.parent_trace_context = {
|
||||
"trace_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
|
||||
"parent_node_execution_id": self.execution_id,
|
||||
"parent_workflow_run_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
|
||||
"parent_app_id": self.app_id,
|
||||
"parent_conversation_id": conversation_id.text if conversation_id else None,
|
||||
}
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
@@ -446,8 +431,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
}
|
||||
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Telemetry gateway contracts and data structures.
|
||||
|
||||
This module defines the envelope format for telemetry events and the routing
|
||||
configuration that determines how each event type is processed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class TelemetryCase(StrEnum):
|
||||
"""Enumeration of all known telemetry event cases."""
|
||||
|
||||
WORKFLOW_RUN = "workflow_run"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
DRAFT_NODE_EXECUTION = "draft_node_execution"
|
||||
MESSAGE_RUN = "message_run"
|
||||
TOOL_EXECUTION = "tool_execution"
|
||||
MODERATION_CHECK = "moderation_check"
|
||||
SUGGESTED_QUESTION = "suggested_question"
|
||||
DATASET_RETRIEVAL = "dataset_retrieval"
|
||||
GENERATE_NAME = "generate_name"
|
||||
PROMPT_GENERATION = "prompt_generation"
|
||||
APP_CREATED = "app_created"
|
||||
APP_UPDATED = "app_updated"
|
||||
APP_DELETED = "app_deleted"
|
||||
FEEDBACK_CREATED = "feedback_created"
|
||||
|
||||
|
||||
class SignalType(StrEnum):
|
||||
"""Signal routing type for telemetry cases."""
|
||||
|
||||
TRACE = "trace"
|
||||
METRIC_LOG = "metric_log"
|
||||
|
||||
|
||||
class CaseRoute(BaseModel):
|
||||
"""Routing configuration for a telemetry case.
|
||||
|
||||
Attributes:
|
||||
signal_type: The type of signal (trace or metric_log).
|
||||
ce_eligible: Whether this case is eligible for community edition tracing.
|
||||
"""
|
||||
|
||||
signal_type: SignalType
|
||||
ce_eligible: bool
|
||||
|
||||
|
||||
class TelemetryEnvelope(BaseModel):
|
||||
"""Envelope for telemetry events.
|
||||
|
||||
Attributes:
|
||||
case: The telemetry case type.
|
||||
tenant_id: The tenant identifier.
|
||||
event_id: Unique event identifier for deduplication.
|
||||
payload: The main event payload.
|
||||
payload_fallback: Fallback payload (max 64KB).
|
||||
metadata: Optional metadata dictionary.
|
||||
"""
|
||||
|
||||
case: TelemetryCase
|
||||
tenant_id: str
|
||||
event_id: str
|
||||
payload: dict[str, Any]
|
||||
payload_fallback: bytes | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("payload_fallback")
|
||||
@classmethod
|
||||
def validate_payload_fallback_size(cls, v: bytes | None) -> bytes | None:
|
||||
"""Validate that payload_fallback does not exceed 64KB."""
|
||||
if v is not None and len(v) > 65536: # 64 * 1024
|
||||
raise ValueError("payload_fallback must not exceed 64KB")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
"""Pydantic configuration."""
|
||||
|
||||
use_enum_values = False
|
||||
@@ -1,77 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
def enqueue_draft_node_execution_trace(
|
||||
*,
|
||||
execution: WorkflowNodeExecutionModel,
|
||||
outputs: Mapping[str, Any] | None,
|
||||
workflow_execution_id: str | None,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
node_data = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=outputs,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=execution.tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=execution.app_id,
|
||||
),
|
||||
payload={"node_execution_data": node_data},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _build_node_execution_data(
|
||||
*,
|
||||
execution: WorkflowNodeExecutionModel,
|
||||
outputs: Mapping[str, Any] | None,
|
||||
workflow_execution_id: str | None,
|
||||
) -> dict[str, Any]:
|
||||
metadata = execution.execution_metadata_dict
|
||||
node_outputs = outputs if outputs is not None else execution.outputs_dict
|
||||
execution_id = workflow_execution_id or execution.workflow_run_id or execution.id
|
||||
|
||||
return {
|
||||
"workflow_id": execution.workflow_id,
|
||||
"workflow_execution_id": execution_id,
|
||||
"tenant_id": execution.tenant_id,
|
||||
"app_id": execution.app_id,
|
||||
"node_execution_id": execution.id,
|
||||
"node_id": execution.node_id,
|
||||
"node_type": execution.node_type,
|
||||
"title": execution.title,
|
||||
"status": execution.status,
|
||||
"error": execution.error,
|
||||
"elapsed_time": execution.elapsed_time,
|
||||
"index": execution.index,
|
||||
"predecessor_node_id": execution.predecessor_node_id,
|
||||
"created_at": execution.created_at,
|
||||
"finished_at": execution.finished_at,
|
||||
"total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
|
||||
"total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
|
||||
"currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
|
||||
"tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
|
||||
if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
|
||||
else None,
|
||||
"iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
|
||||
"iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
|
||||
"loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
|
||||
"loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
|
||||
"parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
|
||||
"node_inputs": execution.inputs_dict,
|
||||
"node_outputs": node_outputs,
|
||||
"process_data": execution.process_data_dict,
|
||||
}
|
||||
@@ -1,903 +0,0 @@
|
||||
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
|
||||
|
||||
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
|
||||
Only requires a matching ``trace(trace_info)`` method signature.
|
||||
|
||||
Signal strategy:
|
||||
- **Traces (spans)**: workflow run, node execution, draft node execution only.
|
||||
- **Metrics + structured logs**: all other event types.
|
||||
|
||||
Token metric labels (unified structure):
|
||||
All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
|
||||
same label set for consistent filtering and aggregation:
|
||||
- tenant_id: Tenant identifier
|
||||
- app_id: Application identifier
|
||||
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
|
||||
- model_provider: LLM provider name (empty string if not applicable)
|
||||
- model_name: LLM model name (empty string if not applicable)
|
||||
- node_type: Workflow node type (empty string if not node_execution)
|
||||
|
||||
This unified structure allows filtering by operation_type to separate:
|
||||
- Workflow-level aggregates (operation_type=workflow)
|
||||
- Individual node executions (operation_type=node_execution)
|
||||
- Direct message calls (operation_type=message)
|
||||
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)
|
||||
|
||||
Without this, tokens are double-counted when querying totals (workflow totals include
|
||||
node totals, since workflow.total_tokens is the sum of all node tokens).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
OperationType,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from enterprise.telemetry.entities import (
|
||||
EnterpriseTelemetryCounter,
|
||||
EnterpriseTelemetryEvent,
|
||||
EnterpriseTelemetryHistogram,
|
||||
EnterpriseTelemetrySpan,
|
||||
TokenMetricLabels,
|
||||
)
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseOtelTrace:
|
||||
"""Duck-typed enterprise trace handler.
|
||||
|
||||
``*_trace`` methods emit spans (workflow/node only) or structured logs
|
||||
(all other events), plus metrics at 100 % accuracy.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if exporter is None:
|
||||
raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
|
||||
self._exporter = exporter
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self._workflow_trace(trace_info)
|
||||
elif isinstance(trace_info, MessageTraceInfo):
|
||||
self._message_trace(trace_info)
|
||||
elif isinstance(trace_info, ToolTraceInfo):
|
||||
self._tool_trace(trace_info)
|
||||
elif isinstance(trace_info, DraftNodeExecutionTrace):
|
||||
self._draft_node_execution_trace(trace_info)
|
||||
elif isinstance(trace_info, WorkflowNodeTraceInfo):
|
||||
self._node_execution_trace(trace_info)
|
||||
elif isinstance(trace_info, ModerationTraceInfo):
|
||||
self._moderation_trace(trace_info)
|
||||
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self._suggested_question_trace(trace_info)
|
||||
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self._dataset_retrieval_trace(trace_info)
|
||||
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self._generate_name_trace(trace_info)
|
||||
elif isinstance(trace_info, PromptGenerationTraceInfo):
|
||||
self._prompt_generation_trace(trace_info)
|
||||
|
||||
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
|
||||
metadata = self._metadata(trace_info)
|
||||
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
|
||||
return {
|
||||
"dify.trace_id": trace_info.trace_id,
|
||||
"dify.tenant_id": tenant_id,
|
||||
"dify.app_id": app_id,
|
||||
"dify.app.name": metadata.get("app_name"),
|
||||
"dify.workspace.name": metadata.get("workspace_name"),
|
||||
"gen_ai.user.id": user_id,
|
||||
"dify.message.id": trace_info.message_id,
|
||||
}
|
||||
|
||||
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
|
||||
return trace_info.metadata
|
||||
|
||||
def _context_ids(
|
||||
self,
|
||||
trace_info: BaseTraceInfo,
|
||||
metadata: dict[str, Any],
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
|
||||
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
|
||||
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
|
||||
return tenant_id, app_id, user_id
|
||||
|
||||
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
|
||||
return dict(values)
|
||||
|
||||
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cast(dict[str, Any], value)
|
||||
if isinstance(value, list):
|
||||
items: list[object] = []
|
||||
for item in cast(list[object], value):
|
||||
items.append(item)
|
||||
return items
|
||||
return None
|
||||
|
||||
def _content_or_ref(self, value: Any, ref: str) -> Any:
|
||||
if self._exporter.include_content:
|
||||
return self._maybe_json(value)
|
||||
return ref
|
||||
|
||||
def _maybe_json(self, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, default=str)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SPAN-emitting handlers (workflow, node execution, draft node)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
# -- Slim span attrs: identity + structure + status + timing only --
|
||||
span_attrs: dict[str, Any] = {
|
||||
"dify.trace_id": info.trace_id,
|
||||
"dify.tenant_id": tenant_id,
|
||||
"dify.app_id": app_id,
|
||||
"dify.workflow.id": info.workflow_id,
|
||||
"dify.workflow.run_id": info.workflow_run_id,
|
||||
"dify.workflow.status": info.workflow_run_status,
|
||||
"dify.workflow.error": info.error,
|
||||
"dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
|
||||
"dify.invoke_from": metadata.get("triggered_from"),
|
||||
"dify.conversation.id": info.conversation_id,
|
||||
"dify.message.id": info.message_id,
|
||||
"dify.invoked_by": info.invoked_by,
|
||||
}
|
||||
|
||||
trace_correlation_override: str | None = None
|
||||
parent_span_id_source: str | None = None
|
||||
|
||||
parent_ctx = metadata.get("parent_trace_context")
|
||||
if isinstance(parent_ctx, dict):
|
||||
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
|
||||
span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
|
||||
span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
|
||||
span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
|
||||
span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")
|
||||
|
||||
trace_override_value = parent_ctx_dict.get("parent_workflow_run_id")
|
||||
if isinstance(trace_override_value, str):
|
||||
trace_correlation_override = trace_override_value
|
||||
parent_span_value = parent_ctx_dict.get("parent_node_execution_id")
|
||||
if isinstance(parent_span_value, str):
|
||||
parent_span_id_source = parent_span_value
|
||||
|
||||
self._exporter.export_span(
|
||||
EnterpriseTelemetrySpan.WORKFLOW_RUN,
|
||||
span_attrs,
|
||||
correlation_id=info.workflow_run_id,
|
||||
span_id_source=info.workflow_run_id,
|
||||
start_time=info.start_time,
|
||||
end_time=info.end_time,
|
||||
trace_correlation_override=trace_correlation_override,
|
||||
parent_span_id_source=parent_span_id_source,
|
||||
)
|
||||
|
||||
# -- Companion log: ALL attrs (span + detail) for full picture --
|
||||
log_attrs: dict[str, Any] = {**span_attrs}
|
||||
log_attrs.update(
|
||||
{
|
||||
"dify.app.name": metadata.get("app_name"),
|
||||
"dify.workspace.name": metadata.get("workspace_name"),
|
||||
"gen_ai.user.id": user_id,
|
||||
"gen_ai.usage.total_tokens": info.total_tokens,
|
||||
"dify.workflow.version": info.workflow_run_version,
|
||||
}
|
||||
)
|
||||
|
||||
ref = f"ref:workflow_run_id={info.workflow_run_id}"
|
||||
log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
|
||||
log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
|
||||
log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)
|
||||
|
||||
emit_telemetry_log(
|
||||
event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
|
||||
attributes=log_attrs,
|
||||
signal="span_detail",
|
||||
trace_id_source=info.workflow_run_id,
|
||||
span_id_source=info.workflow_run_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# -- Metrics --
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
)
|
||||
token_labels = TokenMetricLabels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
operation_type=OperationType.WORKFLOW,
|
||||
model_provider="",
|
||||
model_name="",
|
||||
node_type="",
|
||||
).to_dict()
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
|
||||
if info.prompt_tokens is not None and info.prompt_tokens > 0:
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
|
||||
if info.completion_tokens is not None and info.completion_tokens > 0:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
|
||||
)
|
||||
invoke_from = metadata.get("triggered_from", "")
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="workflow",
|
||||
status=info.workflow_run_status,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
)
|
||||
self._exporter.record_histogram(
|
||||
EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
|
||||
float(info.workflow_run_elapsed_time),
|
||||
self._labels(
|
||||
**labels,
|
||||
status=info.workflow_run_status,
|
||||
),
|
||||
)
|
||||
|
||||
if info.error:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.ERRORS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="workflow",
|
||||
),
|
||||
)
|
||||
|
||||
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
|
||||
self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node")
|
||||
|
||||
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
|
||||
self._emit_node_execution_trace(
|
||||
info,
|
||||
EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
|
||||
"draft_node",
|
||||
correlation_id_override=info.node_execution_id,
|
||||
trace_correlation_override_param=info.workflow_run_id,
|
||||
)
|
||||
|
||||
def _emit_node_execution_trace(
|
||||
self,
|
||||
info: WorkflowNodeTraceInfo,
|
||||
span_name: EnterpriseTelemetrySpan,
|
||||
request_type: str,
|
||||
correlation_id_override: str | None = None,
|
||||
trace_correlation_override_param: str | None = None,
|
||||
) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
# -- Slim span attrs: identity + structure + status + timing --
|
||||
span_attrs: dict[str, Any] = {
|
||||
"dify.trace_id": info.trace_id,
|
||||
"dify.tenant_id": tenant_id,
|
||||
"dify.app_id": app_id,
|
||||
"dify.workflow.id": info.workflow_id,
|
||||
"dify.workflow.run_id": info.workflow_run_id,
|
||||
"dify.message.id": info.message_id,
|
||||
"dify.conversation.id": metadata.get("conversation_id"),
|
||||
"dify.node.execution_id": info.node_execution_id,
|
||||
"dify.node.id": info.node_id,
|
||||
"dify.node.type": info.node_type,
|
||||
"dify.node.title": info.title,
|
||||
"dify.node.status": info.status,
|
||||
"dify.node.error": info.error,
|
||||
"dify.node.elapsed_time": info.elapsed_time,
|
||||
"dify.node.index": info.index,
|
||||
"dify.node.predecessor_node_id": info.predecessor_node_id,
|
||||
"dify.node.iteration_id": info.iteration_id,
|
||||
"dify.node.loop_id": info.loop_id,
|
||||
"dify.node.parallel_id": info.parallel_id,
|
||||
"dify.node.invoked_by": info.invoked_by,
|
||||
}
|
||||
|
||||
trace_correlation_override = trace_correlation_override_param
|
||||
parent_ctx = metadata.get("parent_trace_context")
|
||||
if isinstance(parent_ctx, dict):
|
||||
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
|
||||
override_value = parent_ctx_dict.get("parent_workflow_run_id")
|
||||
if isinstance(override_value, str):
|
||||
trace_correlation_override = override_value
|
||||
|
||||
effective_correlation_id = correlation_id_override or info.workflow_run_id
|
||||
self._exporter.export_span(
|
||||
span_name,
|
||||
span_attrs,
|
||||
correlation_id=effective_correlation_id,
|
||||
span_id_source=info.node_execution_id,
|
||||
start_time=info.start_time,
|
||||
end_time=info.end_time,
|
||||
trace_correlation_override=trace_correlation_override,
|
||||
)
|
||||
|
||||
# -- Companion log: ALL attrs (span + detail) --
|
||||
log_attrs: dict[str, Any] = {**span_attrs}
|
||||
log_attrs.update(
|
||||
{
|
||||
"dify.app.name": metadata.get("app_name"),
|
||||
"dify.workspace.name": metadata.get("workspace_name"),
|
||||
"dify.invoke_from": metadata.get("invoke_from"),
|
||||
"gen_ai.user.id": user_id,
|
||||
"gen_ai.usage.total_tokens": info.total_tokens,
|
||||
"dify.node.total_price": info.total_price,
|
||||
"dify.node.currency": info.currency,
|
||||
"gen_ai.provider.name": info.model_provider,
|
||||
"gen_ai.request.model": info.model_name,
|
||||
"gen_ai.tool.name": info.tool_name,
|
||||
"dify.node.iteration_index": info.iteration_index,
|
||||
"dify.node.loop_index": info.loop_index,
|
||||
"dify.plugin.name": metadata.get("plugin_name"),
|
||||
"dify.credential.name": metadata.get("credential_name"),
|
||||
"dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
|
||||
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
|
||||
}
|
||||
)
|
||||
|
||||
ref = f"ref:node_execution_id={info.node_execution_id}"
|
||||
log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref)
|
||||
log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref)
|
||||
log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref)
|
||||
|
||||
emit_telemetry_log(
|
||||
event_name=span_name.value,
|
||||
attributes=log_attrs,
|
||||
signal="span_detail",
|
||||
trace_id_source=info.workflow_run_id,
|
||||
span_id_source=info.node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# -- Metrics --
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
node_type=info.node_type,
|
||||
model_provider=info.model_provider or "",
|
||||
)
|
||||
if info.total_tokens:
|
||||
token_labels = TokenMetricLabels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
operation_type=OperationType.NODE_EXECUTION,
|
||||
model_provider=info.model_provider or "",
|
||||
model_name=info.model_name or "",
|
||||
node_type=info.node_type,
|
||||
).to_dict()
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
|
||||
if info.prompt_tokens is not None and info.prompt_tokens > 0:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels
|
||||
)
|
||||
if info.completion_tokens is not None and info.completion_tokens > 0:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
|
||||
)
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type=request_type,
|
||||
status=info.status,
|
||||
),
|
||||
)
|
||||
duration_labels = dict(labels)
|
||||
plugin_name = metadata.get("plugin_name")
|
||||
if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
|
||||
duration_labels["plugin_name"] = plugin_name
|
||||
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)
|
||||
|
||||
if info.error:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.ERRORS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type=request_type,
|
||||
),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# METRIC-ONLY handlers (structured log + counters/histograms)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _message_trace(self, info: MessageTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = self._common_attrs(info)
|
||||
attrs.update(
|
||||
{
|
||||
"dify.invoke_from": metadata.get("from_source"),
|
||||
"dify.conversation.id": metadata.get("conversation_id"),
|
||||
"dify.conversation.mode": info.conversation_mode,
|
||||
"gen_ai.provider.name": metadata.get("ls_provider"),
|
||||
"gen_ai.request.model": metadata.get("ls_model_name"),
|
||||
"gen_ai.usage.input_tokens": info.message_tokens,
|
||||
"gen_ai.usage.output_tokens": info.answer_tokens,
|
||||
"gen_ai.usage.total_tokens": info.total_tokens,
|
||||
"dify.message.status": metadata.get("status"),
|
||||
"dify.message.error": info.error,
|
||||
"dify.message.from_source": metadata.get("from_source"),
|
||||
"dify.message.from_end_user_id": metadata.get("from_end_user_id"),
|
||||
"dify.message.from_account_id": metadata.get("from_account_id"),
|
||||
"dify.streaming": info.is_streaming_request,
|
||||
"dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
|
||||
"dify.message.streaming_duration": info.llm_streaming_time_to_generate,
|
||||
"dify.workflow.run_id": metadata.get("workflow_run_id"),
|
||||
}
|
||||
)
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
ref = f"ref:message_id={info.message_id}"
|
||||
inputs = self._safe_payload_value(info.inputs)
|
||||
outputs = self._safe_payload_value(info.outputs)
|
||||
attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
|
||||
attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
|
||||
attributes=attrs,
|
||||
trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None,
|
||||
span_id_source=node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
model_provider=metadata.get("ls_provider", ""),
|
||||
model_name=metadata.get("ls_model_name", ""),
|
||||
)
|
||||
token_labels = TokenMetricLabels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
operation_type=OperationType.MESSAGE,
|
||||
model_provider=metadata.get("ls_provider", ""),
|
||||
model_name=metadata.get("ls_model_name", ""),
|
||||
node_type="",
|
||||
).to_dict()
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
|
||||
if info.message_tokens > 0:
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels)
|
||||
if info.answer_tokens > 0:
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels)
|
||||
invoke_from = metadata.get("from_source", "")
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="message",
|
||||
status=metadata.get("status", ""),
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
)
|
||||
|
||||
if info.start_time and info.end_time:
|
||||
duration = (info.end_time - info.start_time).total_seconds()
|
||||
self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)
|
||||
|
||||
if info.gen_ai_server_time_to_first_token is not None:
|
||||
self._exporter.record_histogram(
|
||||
EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
|
||||
)
|
||||
|
||||
if info.error:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.ERRORS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="message",
|
||||
),
|
||||
)
|
||||
|
||||
def _tool_trace(self, info: ToolTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = self._common_attrs(info)
|
||||
attrs.update(
|
||||
{
|
||||
"gen_ai.tool.name": info.tool_name,
|
||||
"dify.tool.time_cost": info.time_cost,
|
||||
"dify.tool.error": info.error,
|
||||
"dify.workflow.run_id": metadata.get("workflow_run_id"),
|
||||
}
|
||||
)
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
ref = f"ref:message_id={info.message_id}"
|
||||
attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref)
|
||||
attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref)
|
||||
attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref)
|
||||
attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
|
||||
attributes=attrs,
|
||||
span_id_source=node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
tool_name=info.tool_name,
|
||||
)
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="tool",
|
||||
),
|
||||
)
|
||||
self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels)
|
||||
|
||||
if info.error:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.ERRORS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="tool",
|
||||
),
|
||||
)
|
||||
|
||||
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = self._common_attrs(info)
|
||||
attrs.update(
|
||||
{
|
||||
"dify.moderation.flagged": info.flagged,
|
||||
"dify.moderation.action": info.action,
|
||||
"dify.moderation.preset_response": info.preset_response,
|
||||
"dify.workflow.run_id": metadata.get("workflow_run_id"),
|
||||
}
|
||||
)
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
attrs["dify.moderation.query"] = self._content_or_ref(
|
||||
info.query,
|
||||
f"ref:message_id={info.message_id}",
|
||||
)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
|
||||
attributes=attrs,
|
||||
span_id_source=node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
)
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="moderation",
|
||||
),
|
||||
)
|
||||
|
||||
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = self._common_attrs(info)
|
||||
attrs.update(
|
||||
{
|
||||
"gen_ai.usage.total_tokens": info.total_tokens,
|
||||
"dify.suggested_question.status": info.status,
|
||||
"dify.suggested_question.error": info.error,
|
||||
"gen_ai.provider.name": info.model_provider,
|
||||
"gen_ai.request.model": info.model_id,
|
||||
"dify.suggested_question.count": len(info.suggested_question),
|
||||
"dify.workflow.run_id": metadata.get("workflow_run_id"),
|
||||
}
|
||||
)
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
attrs["dify.suggested_question.questions"] = self._content_or_ref(
|
||||
info.suggested_question,
|
||||
f"ref:message_id={info.message_id}",
|
||||
)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
|
||||
attributes=attrs,
|
||||
span_id_source=node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
)
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="suggested_question",
|
||||
),
|
||||
)
|
||||
|
||||
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = self._common_attrs(info)
|
||||
attrs["dify.dataset.error"] = info.error
|
||||
attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
docs: list[dict[str, Any]] = []
|
||||
documents_any: Any = info.documents
|
||||
documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
|
||||
for entry in documents_list:
|
||||
if isinstance(entry, dict):
|
||||
entry_dict: dict[str, Any] = cast(dict[str, Any], entry)
|
||||
docs.append(entry_dict)
|
||||
dataset_ids: list[str] = []
|
||||
dataset_names: list[str] = []
|
||||
structured_docs: list[dict[str, Any]] = []
|
||||
for doc in docs:
|
||||
meta_raw = doc.get("metadata")
|
||||
meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
|
||||
did = meta.get("dataset_id")
|
||||
dname = meta.get("dataset_name")
|
||||
if did and did not in dataset_ids:
|
||||
dataset_ids.append(did)
|
||||
if dname and dname not in dataset_names:
|
||||
dataset_names.append(dname)
|
||||
structured_docs.append(
|
||||
{
|
||||
"dataset_id": did,
|
||||
"document_id": meta.get("document_id"),
|
||||
"segment_id": meta.get("segment_id"),
|
||||
"score": meta.get("score"),
|
||||
}
|
||||
)
|
||||
|
||||
attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids)
|
||||
attrs["dify.dataset.names"] = self._maybe_json(dataset_names)
|
||||
attrs["dify.retrieval.document_count"] = len(docs)
|
||||
|
||||
embedding_models_raw: Any = metadata.get("embedding_models")
|
||||
embedding_models: dict[str, Any] = (
|
||||
cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
|
||||
)
|
||||
if embedding_models:
|
||||
providers: list[str] = []
|
||||
models: list[str] = []
|
||||
for ds_info in embedding_models.values():
|
||||
if isinstance(ds_info, dict):
|
||||
ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
|
||||
p = ds_info_dict.get("embedding_model_provider", "")
|
||||
m = ds_info_dict.get("embedding_model", "")
|
||||
if p and p not in providers:
|
||||
providers.append(p)
|
||||
if m and m not in models:
|
||||
models.append(m)
|
||||
attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
|
||||
attrs["dify.dataset.embedding_models"] = self._maybe_json(models)
|
||||
|
||||
ref = f"ref:message_id={info.message_id}"
|
||||
retrieval_inputs = self._safe_payload_value(info.inputs)
|
||||
attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
|
||||
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
|
||||
attributes=attrs,
|
||||
trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None,
|
||||
span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None),
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
)
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="dataset_retrieval",
|
||||
),
|
||||
)
|
||||
|
||||
for did in dataset_ids:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
dataset_id=did,
|
||||
),
|
||||
)
|
||||
|
||||
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = self._common_attrs(info)
|
||||
attrs["dify.conversation.id"] = info.conversation_id
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
ref = f"ref:conversation_id={info.conversation_id}"
|
||||
inputs = self._safe_payload_value(info.inputs)
|
||||
outputs = self._safe_payload_value(info.outputs)
|
||||
attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref)
|
||||
attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
|
||||
attributes=attrs,
|
||||
span_id_source=node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
)
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="generate_name",
|
||||
),
|
||||
)
|
||||
|
||||
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
|
||||
metadata = self._metadata(info)
|
||||
tenant_id, app_id, user_id = self._context_ids(info, metadata)
|
||||
attrs = {
|
||||
"dify.trace_id": info.trace_id,
|
||||
"dify.tenant_id": tenant_id,
|
||||
"dify.user.id": user_id,
|
||||
"dify.app.id": app_id or "",
|
||||
"dify.app.name": metadata.get("app_name"),
|
||||
"dify.workspace.name": metadata.get("workspace_name"),
|
||||
"dify.operation.type": info.operation_type,
|
||||
"gen_ai.provider.name": info.model_provider,
|
||||
"gen_ai.request.model": info.model_name,
|
||||
"gen_ai.usage.input_tokens": info.prompt_tokens,
|
||||
"gen_ai.usage.output_tokens": info.completion_tokens,
|
||||
"gen_ai.usage.total_tokens": info.total_tokens,
|
||||
"dify.prompt_generation.latency": info.latency,
|
||||
"dify.prompt_generation.error": info.error,
|
||||
}
|
||||
node_execution_id = metadata.get("node_execution_id")
|
||||
if node_execution_id:
|
||||
attrs["dify.node.execution_id"] = node_execution_id
|
||||
|
||||
if info.total_price is not None:
|
||||
attrs["dify.prompt_generation.total_price"] = info.total_price
|
||||
attrs["dify.prompt_generation.currency"] = info.currency
|
||||
|
||||
ref = f"ref:trace_id={info.trace_id}"
|
||||
outputs = self._safe_payload_value(info.outputs)
|
||||
attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref)
|
||||
attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref)
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
|
||||
attributes=attrs,
|
||||
span_id_source=node_execution_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
token_labels = TokenMetricLabels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
operation_type=info.operation_type,
|
||||
model_provider=info.model_provider,
|
||||
model_name=info.model_name,
|
||||
node_type="",
|
||||
).to_dict()
|
||||
|
||||
labels = self._labels(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id or "",
|
||||
operation_type=info.operation_type,
|
||||
model_provider=info.model_provider,
|
||||
model_name=info.model_name,
|
||||
)
|
||||
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
|
||||
if info.prompt_tokens > 0:
|
||||
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
|
||||
if info.completion_tokens > 0:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
|
||||
)
|
||||
|
||||
status = "failed" if info.error else "success"
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.REQUESTS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="prompt_generation",
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
|
||||
self._exporter.record_histogram(
|
||||
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
|
||||
info.latency,
|
||||
labels,
|
||||
)
|
||||
|
||||
if info.error:
|
||||
self._exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.ERRORS,
|
||||
1,
|
||||
self._labels(
|
||||
**labels,
|
||||
type="prompt_generation",
|
||||
),
|
||||
)
|
||||
@@ -1,121 +0,0 @@
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class EnterpriseTelemetrySpan(StrEnum):
|
||||
WORKFLOW_RUN = "dify.workflow.run"
|
||||
NODE_EXECUTION = "dify.node.execution"
|
||||
DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
|
||||
|
||||
|
||||
class EnterpriseTelemetryEvent(StrEnum):
|
||||
"""Event names for enterprise telemetry logs."""
|
||||
|
||||
APP_CREATED = "dify.app.created"
|
||||
APP_UPDATED = "dify.app.updated"
|
||||
APP_DELETED = "dify.app.deleted"
|
||||
FEEDBACK_CREATED = "dify.feedback.created"
|
||||
WORKFLOW_RUN = "dify.workflow.run"
|
||||
MESSAGE_RUN = "dify.message.run"
|
||||
TOOL_EXECUTION = "dify.tool.execution"
|
||||
MODERATION_CHECK = "dify.moderation.check"
|
||||
SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
|
||||
DATASET_RETRIEVAL = "dify.dataset.retrieval"
|
||||
GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
|
||||
PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
|
||||
REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
|
||||
|
||||
|
||||
class EnterpriseTelemetryCounter(StrEnum):
|
||||
TOKENS = "tokens"
|
||||
INPUT_TOKENS = "input_tokens"
|
||||
OUTPUT_TOKENS = "output_tokens"
|
||||
REQUESTS = "requests"
|
||||
ERRORS = "errors"
|
||||
FEEDBACK = "feedback"
|
||||
DATASET_RETRIEVALS = "dataset_retrievals"
|
||||
APP_CREATED = "app_created"
|
||||
APP_UPDATED = "app_updated"
|
||||
APP_DELETED = "app_deleted"
|
||||
|
||||
|
||||
class EnterpriseTelemetryHistogram(StrEnum):
|
||||
WORKFLOW_DURATION = "workflow_duration"
|
||||
NODE_DURATION = "node_duration"
|
||||
MESSAGE_DURATION = "message_duration"
|
||||
MESSAGE_TTFT = "message_ttft"
|
||||
TOOL_DURATION = "tool_duration"
|
||||
PROMPT_GENERATION_DURATION = "prompt_generation_duration"
|
||||
|
||||
|
||||
class TokenMetricLabels(BaseModel):
|
||||
"""Unified label structure for all dify.token.* metrics.
|
||||
|
||||
All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST
|
||||
use this exact label set to ensure consistent filtering and aggregation across
|
||||
different operation types.
|
||||
|
||||
Attributes:
|
||||
tenant_id: Tenant identifier.
|
||||
app_id: Application identifier.
|
||||
operation_type: Source of token usage (workflow | node_execution | message |
|
||||
rule_generate | code_generate | structured_output | instruction_modify).
|
||||
model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level).
|
||||
model_name: LLM model name. Empty string if not applicable (e.g., workflow-level).
|
||||
node_type: Workflow node type. Empty string unless operation_type=node_execution.
|
||||
|
||||
Usage:
|
||||
labels = TokenMetricLabels(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
operation_type=OperationType.WORKFLOW,
|
||||
model_provider="",
|
||||
model_name="",
|
||||
node_type="",
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.INPUT_TOKENS,
|
||||
100,
|
||||
labels.to_dict()
|
||||
)
|
||||
|
||||
Design rationale:
|
||||
Without this unified structure, tokens get double-counted when querying totals
|
||||
because workflow.total_tokens is already the sum of all node tokens. The
|
||||
operation_type label allows filtering to separate workflow-level aggregates from
|
||||
node-level detail, while keeping the same label cardinality for consistent queries.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
operation_type: str
|
||||
model_provider: str
|
||||
model_name: str
|
||||
node_type: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
def to_dict(self) -> dict[str, AttributeValue]:
|
||||
return cast(
|
||||
dict[str, AttributeValue],
|
||||
{
|
||||
"tenant_id": self.tenant_id,
|
||||
"app_id": self.app_id,
|
||||
"operation_type": self.operation_type,
|
||||
"model_provider": self.model_provider,
|
||||
"model_name": self.model_name,
|
||||
"node_type": self.node_type,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EnterpriseTelemetryCounter",
|
||||
"EnterpriseTelemetryEvent",
|
||||
"EnterpriseTelemetryHistogram",
|
||||
"EnterpriseTelemetrySpan",
|
||||
"TokenMetricLabels",
|
||||
]
|
||||
@@ -1,130 +0,0 @@
|
||||
"""Blinker signal handlers for enterprise telemetry.
|
||||
|
||||
Registered at import time via ``@signal.connect`` decorators.
|
||||
Import must happen during ``ext_enterprise_telemetry.init_app()`` to ensure handlers fire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from events.app_event import app_was_created, app_was_deleted, app_was_updated
|
||||
from events.feedback_event import feedback_was_created
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"_handle_app_created",
|
||||
"_handle_app_deleted",
|
||||
"_handle_app_updated",
|
||||
"_handle_feedback_created",
|
||||
]
|
||||
|
||||
|
||||
@app_was_created.connect
|
||||
def _handle_app_created(sender: object, **kwargs: object) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
return
|
||||
|
||||
tenant_id = str(getattr(sender, "tenant_id", "") or "")
|
||||
payload = {
|
||||
"app_id": getattr(sender, "id", None),
|
||||
"mode": getattr(sender, "mode", None),
|
||||
}
|
||||
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id=tenant_id,
|
||||
event_id=str(uuid.uuid4()),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
process_enterprise_telemetry.delay(envelope.model_dump_json())
|
||||
|
||||
|
||||
@app_was_deleted.connect
|
||||
def _handle_app_deleted(sender: object, **kwargs: object) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
return
|
||||
|
||||
tenant_id = str(getattr(sender, "tenant_id", "") or "")
|
||||
payload = {
|
||||
"app_id": getattr(sender, "id", None),
|
||||
}
|
||||
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_DELETED,
|
||||
tenant_id=tenant_id,
|
||||
event_id=str(uuid.uuid4()),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
process_enterprise_telemetry.delay(envelope.model_dump_json())
|
||||
|
||||
|
||||
@app_was_updated.connect
|
||||
def _handle_app_updated(sender: object, **kwargs: object) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
return
|
||||
|
||||
tenant_id = str(getattr(sender, "tenant_id", "") or "")
|
||||
payload = {
|
||||
"app_id": getattr(sender, "id", None),
|
||||
}
|
||||
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_UPDATED,
|
||||
tenant_id=tenant_id,
|
||||
event_id=str(uuid.uuid4()),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
process_enterprise_telemetry.delay(envelope.model_dump_json())
|
||||
|
||||
|
||||
@feedback_was_created.connect
|
||||
def _handle_feedback_created(sender: object, **kwargs: object) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
return
|
||||
|
||||
tenant_id = str(kwargs.get("tenant_id", "") or "")
|
||||
payload = {
|
||||
"message_id": getattr(sender, "message_id", None),
|
||||
"app_id": getattr(sender, "app_id", None),
|
||||
"conversation_id": getattr(sender, "conversation_id", None),
|
||||
"from_end_user_id": getattr(sender, "from_end_user_id", None),
|
||||
"from_account_id": getattr(sender, "from_account_id", None),
|
||||
"rating": getattr(sender, "rating", None),
|
||||
"from_source": getattr(sender, "from_source", None),
|
||||
"content": getattr(sender, "content", None),
|
||||
}
|
||||
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id=tenant_id,
|
||||
event_id=str(uuid.uuid4()),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
process_enterprise_telemetry.delay(envelope.model_dump_json())
|
||||
@@ -1,255 +0,0 @@
|
||||
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
|
||||
|
||||
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
|
||||
independent from ext_otel.py infrastructure).
|
||||
|
||||
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
|
||||
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace import SpanContext, TraceFlags
|
||||
from opentelemetry.util.types import Attributes, AttributeValue
|
||||
|
||||
from configs import dify_config
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
|
||||
from enterprise.telemetry.id_generator import (
|
||||
CorrelationIdGenerator,
|
||||
compute_deterministic_span_id,
|
||||
set_correlation_id,
|
||||
set_span_id_source,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_enterprise_telemetry_enabled() -> bool:
|
||||
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
|
||||
|
||||
|
||||
def _parse_otlp_headers(raw: str) -> dict[str, str]:
|
||||
"""Parse ``key=value,key2=value2`` into a dict."""
|
||||
if not raw:
|
||||
return {}
|
||||
headers: dict[str, str] = {}
|
||||
for pair in raw.split(","):
|
||||
if "=" not in pair:
|
||||
continue
|
||||
k, v = pair.split("=", 1)
|
||||
headers[k.strip()] = v.strip()
|
||||
return headers
|
||||
|
||||
|
||||
def _datetime_to_ns(dt: datetime) -> int:
|
||||
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
class _ExporterFactory:
|
||||
def __init__(self, protocol: str, endpoint: str, headers: dict[str, str]):
|
||||
self._protocol = protocol
|
||||
self._endpoint = endpoint
|
||||
self._headers = headers
|
||||
self._grpc_headers = tuple(headers.items()) if headers else None
|
||||
self._http_headers = headers or None
|
||||
|
||||
def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
|
||||
if self._protocol == "grpc":
|
||||
return GRPCSpanExporter(
|
||||
endpoint=self._endpoint or None,
|
||||
headers=self._grpc_headers,
|
||||
insecure=True,
|
||||
)
|
||||
trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else ""
|
||||
return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers)
|
||||
|
||||
def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
|
||||
if self._protocol == "grpc":
|
||||
return GRPCMetricExporter(
|
||||
endpoint=self._endpoint or None,
|
||||
headers=self._grpc_headers,
|
||||
insecure=True,
|
||||
)
|
||||
metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else ""
|
||||
return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers)
|
||||
|
||||
|
||||
class EnterpriseExporter:
|
||||
"""Shared OTEL exporter for all enterprise telemetry.
|
||||
|
||||
``export_span`` creates spans with optional real timestamps, deterministic
|
||||
span/trace IDs, and cross-workflow parent linking.
|
||||
``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy.
|
||||
"""
|
||||
|
||||
def __init__(self, config: object) -> None:
|
||||
endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "")
|
||||
headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "")
|
||||
protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
|
||||
service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
|
||||
sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
|
||||
self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
|
||||
|
||||
resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
sampler = ParentBasedTraceIdRatio(sampling_rate)
|
||||
id_generator = CorrelationIdGenerator()
|
||||
self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)
|
||||
|
||||
headers = _parse_otlp_headers(headers_raw)
|
||||
factory = _ExporterFactory(protocol, endpoint, headers)
|
||||
|
||||
trace_exporter = factory.create_trace_exporter()
|
||||
self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
|
||||
self._tracer = self._tracer_provider.get_tracer("dify.enterprise")
|
||||
|
||||
metric_exporter = factory.create_metric_exporter()
|
||||
self._meter_provider = MeterProvider(
|
||||
resource=resource,
|
||||
metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
|
||||
)
|
||||
meter = self._meter_provider.get_meter("dify.enterprise")
|
||||
self._counters = {
|
||||
EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
|
||||
EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
|
||||
EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
|
||||
EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
|
||||
EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
|
||||
EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
|
||||
EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
|
||||
"dify.dataset.retrievals.total", unit="{retrieval}"
|
||||
),
|
||||
EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
|
||||
EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
|
||||
EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
|
||||
}
|
||||
self._histograms = {
|
||||
EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
|
||||
EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
|
||||
EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
|
||||
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
|
||||
"dify.message.time_to_first_token", unit="s"
|
||||
),
|
||||
EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
|
||||
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
|
||||
"dify.prompt_generation.duration", unit="s"
|
||||
),
|
||||
}
|
||||
|
||||
def export_span(
|
||||
self,
|
||||
name: str,
|
||||
attributes: dict[str, Any],
|
||||
correlation_id: str | None = None,
|
||||
span_id_source: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
trace_correlation_override: str | None = None,
|
||||
parent_span_id_source: str | None = None,
|
||||
) -> None:
|
||||
"""Export an OTEL span with optional deterministic IDs and real timestamps.
|
||||
|
||||
Args:
|
||||
name: Span operation name.
|
||||
attributes: Span attributes dict.
|
||||
correlation_id: Source for trace_id derivation (groups spans in one trace).
|
||||
span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
|
||||
start_time: Real span start time. When None, uses current time.
|
||||
end_time: Real span end time. When None, span ends immediately.
|
||||
trace_correlation_override: Override trace_id source (for cross-workflow linking).
|
||||
When set, trace_id is derived from this instead of ``correlation_id``.
|
||||
parent_span_id_source: Override parent span_id source (for cross-workflow linking).
|
||||
When set, parent span_id is derived from this value. When None and
|
||||
``correlation_id`` is set, parent is the workflow root span.
|
||||
"""
|
||||
effective_trace_correlation = trace_correlation_override or correlation_id
|
||||
set_correlation_id(effective_trace_correlation)
|
||||
set_span_id_source(span_id_source)
|
||||
|
||||
try:
|
||||
parent_context: Context | None = None
|
||||
# A span is the "root" of its correlation group when span_id_source == correlation_id
|
||||
# (i.e. a workflow root span). All other spans are children.
|
||||
if parent_span_id_source:
|
||||
# Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
|
||||
parent_span_id = compute_deterministic_span_id(parent_span_id_source)
|
||||
parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0
|
||||
if parent_trace_id:
|
||||
parent_span_context = SpanContext(
|
||||
trace_id=parent_trace_id,
|
||||
span_id=parent_span_id,
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
|
||||
elif correlation_id and correlation_id != span_id_source:
|
||||
# Child span: parent is the correlation-group root (workflow root span)
|
||||
parent_span_id = compute_deterministic_span_id(correlation_id)
|
||||
parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id))
|
||||
parent_span_context = SpanContext(
|
||||
trace_id=parent_trace_id,
|
||||
span_id=parent_span_id,
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
|
||||
|
||||
span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
|
||||
span_end_on_exit = end_time is None
|
||||
|
||||
with self._tracer.start_as_current_span(
|
||||
name,
|
||||
context=parent_context,
|
||||
start_time=span_start_time,
|
||||
end_on_exit=span_end_on_exit,
|
||||
) as span:
|
||||
for key, value in attributes.items():
|
||||
if value is not None:
|
||||
span.set_attribute(key, value)
|
||||
if end_time is not None:
|
||||
span.end(end_time=_datetime_to_ns(end_time))
|
||||
except Exception:
|
||||
logger.exception("Failed to export span %s", name)
|
||||
finally:
|
||||
set_correlation_id(None)
|
||||
set_span_id_source(None)
|
||||
|
||||
def increment_counter(
|
||||
self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
|
||||
) -> None:
|
||||
counter = self._counters.get(name)
|
||||
if counter:
|
||||
counter.add(value, cast(Attributes, labels))
|
||||
|
||||
def record_histogram(
|
||||
self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
|
||||
) -> None:
|
||||
histogram = self._histograms.get(name)
|
||||
if histogram:
|
||||
histogram.record(value, cast(Attributes, labels))
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tracer_provider.shutdown()
|
||||
self._meter_provider.shutdown()
|
||||
@@ -1,199 +0,0 @@
|
||||
"""Telemetry gateway routing and dispatch.
|
||||
|
||||
Maps ``TelemetryCase`` → ``CaseRoute`` (signal type + CE eligibility)
|
||||
and dispatches events to either the trace pipeline or the metric/log
|
||||
Celery queue.
|
||||
|
||||
Singleton lifecycle is managed by ``ext_enterprise_telemetry.init_app()``
|
||||
which creates the instance during single-threaded Flask app startup.
|
||||
Access via ``ext_enterprise_telemetry.get_gateway()``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
|
||||
|
||||
CASE_TO_TRACE_TASK: dict[TelemetryCase, TraceTaskName] = {
|
||||
TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
|
||||
TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
|
||||
TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
|
||||
}
|
||||
|
||||
CASE_ROUTING: dict[TelemetryCase, CaseRoute] = {
|
||||
TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
|
||||
TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
|
||||
TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
|
||||
TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
|
||||
TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
|
||||
}
|
||||
|
||||
|
||||
def _is_enterprise_telemetry_enabled() -> bool:
|
||||
try:
|
||||
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
|
||||
|
||||
return is_enterprise_telemetry_enabled()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _should_drop_ee_only_event(route: CaseRoute) -> bool:
|
||||
"""Return True when the event is enterprise-only and EE telemetry is disabled."""
|
||||
return not route.ce_eligible and not _is_enterprise_telemetry_enabled()
|
||||
|
||||
|
||||
class TelemetryGateway:
|
||||
"""Routes telemetry events to the trace pipeline or the metric/log Celery queue.
|
||||
|
||||
Stateless — instantiated once during ``ext_enterprise_telemetry.init_app()``
|
||||
and shared for the lifetime of the process.
|
||||
"""
|
||||
|
||||
def emit(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
route = CASE_ROUTING.get(case)
|
||||
if route is None:
|
||||
logger.warning("Unknown telemetry case: %s, dropping event", case)
|
||||
return
|
||||
|
||||
if _should_drop_ee_only_event(route):
|
||||
logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Gateway routing: case=%s, signal_type=%s, ce_eligible=%s",
|
||||
case,
|
||||
route.signal_type,
|
||||
route.ce_eligible,
|
||||
)
|
||||
|
||||
if route.signal_type is SignalType.TRACE:
|
||||
self._emit_trace(case, context, payload, route, trace_manager)
|
||||
else:
|
||||
self._emit_metric_log(case, context, payload)
|
||||
|
||||
def _emit_trace(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
route: CaseRoute,
|
||||
trace_manager: TraceQueueManager | None,
|
||||
) -> None:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
|
||||
from core.ops.ops_trace_manager import TraceTask
|
||||
|
||||
trace_task_name = CASE_TO_TRACE_TASK.get(case)
|
||||
if trace_task_name is None:
|
||||
logger.warning("No TraceTaskName mapping for case: %s", case)
|
||||
return
|
||||
|
||||
queue_manager = trace_manager or LocalTraceQueueManager(
|
||||
app_id=context.get("app_id"),
|
||||
user_id=context.get("user_id"),
|
||||
)
|
||||
|
||||
queue_manager.add_trace_task(TraceTask(trace_task_name, **payload))
|
||||
logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
|
||||
|
||||
def _emit_metric_log(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
tenant_id = context.get("tenant_id", "")
|
||||
event_id = str(uuid.uuid4())
|
||||
|
||||
payload_for_envelope, payload_ref = self._handle_payload_sizing(payload, tenant_id, event_id)
|
||||
|
||||
envelope = TelemetryEnvelope(
|
||||
case=case,
|
||||
tenant_id=tenant_id,
|
||||
event_id=event_id,
|
||||
payload=payload_for_envelope,
|
||||
metadata={"payload_ref": payload_ref} if payload_ref else None,
|
||||
)
|
||||
|
||||
process_enterprise_telemetry.delay(envelope.model_dump_json())
|
||||
logger.debug(
|
||||
"Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
|
||||
case,
|
||||
tenant_id,
|
||||
event_id,
|
||||
)
|
||||
|
||||
def _handle_payload_sizing(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
tenant_id: str,
|
||||
event_id: str,
|
||||
) -> tuple[dict[str, Any], str | None]:
|
||||
try:
|
||||
payload_json = json.dumps(payload)
|
||||
payload_size = len(payload_json.encode("utf-8"))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
|
||||
return payload, None
|
||||
|
||||
if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
|
||||
return payload, None
|
||||
|
||||
storage_key = f"telemetry/{tenant_id}/{event_id}.json"
|
||||
try:
|
||||
storage.save(storage_key, payload_json.encode("utf-8"))
|
||||
logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
|
||||
return {}, storage_key
|
||||
except Exception:
|
||||
logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
|
||||
return payload, None
|
||||
|
||||
|
||||
def emit(
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
"""Module-level convenience wrapper.
|
||||
|
||||
Fetches the gateway singleton from the extension; no-ops when
|
||||
enterprise telemetry is disabled (gateway is ``None``).
|
||||
"""
|
||||
from extensions.ext_enterprise_telemetry import get_gateway
|
||||
|
||||
gateway = get_gateway()
|
||||
if gateway is not None:
|
||||
gateway.emit(case, context, payload, trace_manager)
|
||||
@@ -1,76 +0,0 @@
|
||||
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
|
||||
|
||||
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
|
||||
When a span_id_source is set, the span_id is derived deterministically
|
||||
from that value, enabling any span to reference another as parent
|
||||
without depending on span creation order.
|
||||
"""
|
||||
|
||||
import random
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.sdk.trace.id_generator import IdGenerator
|
||||
|
||||
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
|
||||
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
|
||||
|
||||
|
||||
def set_correlation_id(correlation_id: str | None) -> None:
|
||||
_correlation_id_context.set(correlation_id)
|
||||
|
||||
|
||||
def get_correlation_id() -> str | None:
|
||||
return _correlation_id_context.get()
|
||||
|
||||
|
||||
def set_span_id_source(source_id: str | None) -> None:
|
||||
"""Set the source for deterministic span_id generation.
|
||||
|
||||
When set, ``generate_span_id()`` derives the span_id from this value
|
||||
(lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
|
||||
root spans or ``node_execution_id`` for node spans.
|
||||
"""
|
||||
_span_id_source_context.set(source_id)
|
||||
|
||||
|
||||
def compute_deterministic_span_id(source_id: str) -> int:
|
||||
"""Derive a deterministic span_id from any UUID string.
|
||||
|
||||
Uses the lower 64 bits of the UUID, guaranteeing non-zero output
|
||||
(OTEL requires span_id != 0).
|
||||
"""
|
||||
span_id = cast(int, uuid.UUID(source_id).int) & ((1 << 64) - 1)
|
||||
return span_id if span_id != 0 else 1
|
||||
|
||||
|
||||
class CorrelationIdGenerator(IdGenerator):
|
||||
"""ID generator that derives trace_id and optionally span_id from context.
|
||||
|
||||
- trace_id: always derived from correlation_id (groups all spans in one trace)
|
||||
- span_id: derived from span_id_source when set (enables deterministic
|
||||
parent-child linking), otherwise random
|
||||
"""
|
||||
|
||||
def generate_trace_id(self) -> int:
|
||||
correlation_id = _correlation_id_context.get()
|
||||
if correlation_id:
|
||||
try:
|
||||
return cast(int, uuid.UUID(correlation_id).int)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
return random.getrandbits(128)
|
||||
|
||||
def generate_span_id(self) -> int:
|
||||
source = _span_id_source_context.get()
|
||||
if source:
|
||||
try:
|
||||
return compute_deterministic_span_id(source)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == 0:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
@@ -1,373 +0,0 @@
|
||||
"""Enterprise metric/log event handler.
|
||||
|
||||
This module processes metric and log telemetry events after they've been
|
||||
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
|
||||
idempotency checking, and payload rehydration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseMetricHandler:
|
||||
"""Handler for enterprise metric and log telemetry events.
|
||||
|
||||
Processes envelopes from the enterprise_telemetry queue, routing each
|
||||
case to the appropriate handler method. Implements idempotency checking
|
||||
and payload rehydration with fallback.
|
||||
"""
|
||||
|
||||
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
|
||||
"""Increment a diagnostic counter for operational monitoring.
|
||||
|
||||
Args:
|
||||
counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
|
||||
labels: Optional labels for the counter.
|
||||
"""
|
||||
try:
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
return
|
||||
|
||||
full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
|
||||
logger.debug(
|
||||
"Diagnostic counter: %s, labels=%s",
|
||||
full_counter_name,
|
||||
labels or {},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
|
||||
|
||||
def handle(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Main entry point for processing telemetry envelopes.
|
||||
|
||||
Args:
|
||||
envelope: The telemetry envelope to process.
|
||||
"""
|
||||
# Check for duplicate events
|
||||
if self._is_duplicate(envelope):
|
||||
logger.debug(
|
||||
"Skipping duplicate event: tenant_id=%s, event_id=%s",
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
self._increment_diagnostic_counter("deduped_total")
|
||||
return
|
||||
|
||||
# Route to appropriate handler based on case
|
||||
case = envelope.case
|
||||
if case == TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
elif case == TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
elif case == TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
elif case == TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
elif case == TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
elif case == TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
elif case == TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
elif case == TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
elif case == TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
elif case == TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
elif case == TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
|
||||
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
|
||||
"""Check if this event has already been processed.
|
||||
|
||||
Uses Redis with TTL for deduplication. Returns True if duplicate,
|
||||
False if first time seeing this event.
|
||||
|
||||
Args:
|
||||
envelope: The telemetry envelope to check.
|
||||
|
||||
Returns:
|
||||
True if this event_id has been seen before, False otherwise.
|
||||
"""
|
||||
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
|
||||
|
||||
try:
|
||||
# Atomic set-if-not-exists with 1h TTL
|
||||
# Returns True if key was set (first time), None if already exists (duplicate)
|
||||
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
|
||||
return was_set is None
|
||||
except Exception:
|
||||
# Fail open: if Redis is unavailable, process the event
|
||||
# (prefer occasional duplicate over lost data)
|
||||
logger.warning(
|
||||
"Redis unavailable for deduplication check, processing event anyway: %s",
|
||||
envelope.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
|
||||
"""Rehydrate payload from reference or fallback.
|
||||
|
||||
Attempts to resolve payload_ref to full data. If that fails,
|
||||
falls back to payload_fallback. If both fail, emits a degraded
|
||||
event marker.
|
||||
|
||||
Args:
|
||||
envelope: The telemetry envelope containing payload data.
|
||||
|
||||
Returns:
|
||||
The rehydrated payload dictionary.
|
||||
"""
|
||||
# For now, payload is directly in the envelope
|
||||
# Future: implement payload_ref resolution from storage
|
||||
payload = envelope.payload
|
||||
|
||||
if not payload and envelope.payload_fallback:
|
||||
import pickle
|
||||
|
||||
try:
|
||||
payload = pickle.loads(envelope.payload_fallback) # noqa: S301
|
||||
logger.debug("Used payload_fallback for event_id=%s", envelope.event_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to deserialize payload_fallback for event_id=%s",
|
||||
envelope.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if not payload:
|
||||
# Both ref and fallback failed - emit degraded event
|
||||
logger.error(
|
||||
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
|
||||
envelope.event_id,
|
||||
envelope.tenant_id,
|
||||
envelope.case,
|
||||
)
|
||||
# Emit degraded event marker
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
|
||||
attributes={
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event_id": envelope.event_id,
|
||||
"dify.case": envelope.case,
|
||||
"rehydration_failed": True,
|
||||
},
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
self._increment_diagnostic_counter("rehydration_failed_total")
|
||||
return {}
|
||||
|
||||
return payload
|
||||
|
||||
# Stub methods for each metric/log case
|
||||
# These will be implemented in later tasks with actual emission logic
|
||||
|
||||
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle app created event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
attrs = {
|
||||
"dify.app.id": payload.get("app_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.app.mode": payload.get("mode"),
|
||||
}
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.APP_CREATED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.APP_CREATED,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
"mode": str(payload.get("mode", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle app updated event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
attrs = {
|
||||
"dify.app.id": payload.get("app_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
}
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.APP_UPDATED,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle app deleted event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
attrs = {
|
||||
"dify.app.id": payload.get("app_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
}
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.APP_DELETED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.APP_DELETED,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle feedback created event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
include_content = exporter.include_content
|
||||
attrs: dict = {
|
||||
"dify.message.id": payload.get("message_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.app_id": payload.get("app_id"),
|
||||
"dify.conversation.id": payload.get("conversation_id"),
|
||||
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
|
||||
"dify.feedback.rating": payload.get("rating"),
|
||||
"dify.feedback.from_source": payload.get("from_source"),
|
||||
}
|
||||
if include_content:
|
||||
attrs["dify.feedback.content"] = payload.get("content")
|
||||
|
||||
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
user_id=str(user_id or ""),
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.FEEDBACK,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
"rating": str(payload.get("rating", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle message run event (stub)."""
|
||||
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle tool execution event (stub)."""
|
||||
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle moderation check event (stub)."""
|
||||
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle suggested question event (stub)."""
|
||||
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle dataset retrieval event (stub)."""
|
||||
logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle generate name event (stub)."""
|
||||
logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle prompt generation event (stub)."""
|
||||
logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Structured-log emitter for enterprise telemetry events.
|
||||
|
||||
Emits structured JSON log lines correlated with OTEL traces via trace_id.
|
||||
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
logger = logging.getLogger("dify.telemetry")
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def compute_trace_id_hex(uuid_str: str | None) -> str:
|
||||
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
|
||||
|
||||
Returns empty string when *uuid_str* is ``None`` or invalid.
|
||||
"""
|
||||
if not uuid_str:
|
||||
return ""
|
||||
normalized = uuid_str.strip().lower()
|
||||
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
|
||||
return normalized
|
||||
try:
|
||||
return f"{uuid.UUID(normalized).int:032x}"
|
||||
except (ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def compute_span_id_hex(uuid_str: str | None) -> str:
|
||||
if not uuid_str:
|
||||
return ""
|
||||
normalized = uuid_str.strip().lower()
|
||||
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
|
||||
return normalized
|
||||
try:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
return f"{compute_deterministic_span_id(normalized):016x}"
|
||||
except (ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
|
||||
def emit_telemetry_log(
|
||||
*,
|
||||
event_name: str | EnterpriseTelemetryEvent,
|
||||
attributes: dict[str, Any],
|
||||
signal: str = "metric_only",
|
||||
trace_id_source: str | None = None,
|
||||
span_id_source: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit a structured log line for a telemetry event.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
event_name:
|
||||
Canonical event name, e.g. ``"dify.workflow.run"``.
|
||||
attributes:
|
||||
All event-specific attributes (already built by the caller).
|
||||
signal:
|
||||
``"metric_only"`` for events with no span, ``"span_detail"``
|
||||
for detail logs accompanying a slim span.
|
||||
trace_id_source:
|
||||
A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex
|
||||
trace_id for cross-signal correlation.
|
||||
tenant_id:
|
||||
Tenant identifier (for the ``IdentityContextFilter``).
|
||||
user_id:
|
||||
User identifier (for the ``IdentityContextFilter``).
|
||||
"""
|
||||
if not logger.isEnabledFor(logging.INFO):
|
||||
return
|
||||
attrs = {
|
||||
"dify.event.name": event_name,
|
||||
"dify.event.signal": signal,
|
||||
**attributes,
|
||||
}
|
||||
|
||||
extra: dict[str, Any] = {"attributes": attrs}
|
||||
|
||||
trace_id_hex = compute_trace_id_hex(trace_id_source)
|
||||
if trace_id_hex:
|
||||
extra["trace_id"] = trace_id_hex
|
||||
span_id_hex = compute_span_id_hex(span_id_source)
|
||||
if span_id_hex:
|
||||
extra["span_id"] = span_id_hex
|
||||
if tenant_id:
|
||||
extra["tenant_id"] = tenant_id
|
||||
if user_id:
|
||||
extra["user_id"] = user_id
|
||||
|
||||
logger.info("telemetry.%s", signal, extra=extra)
|
||||
|
||||
|
||||
def emit_metric_only_event(
|
||||
*,
|
||||
event_name: str | EnterpriseTelemetryEvent,
|
||||
attributes: dict[str, Any],
|
||||
trace_id_source: str | None = None,
|
||||
span_id_source: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
emit_telemetry_log(
|
||||
event_name=event_name,
|
||||
attributes=attributes,
|
||||
signal="metric_only",
|
||||
trace_id_source=trace_id_source,
|
||||
span_id_source=span_id_source,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -3,12 +3,6 @@ from blinker import signal
|
||||
# sender: app
|
||||
app_was_created = signal("app-was-created")
|
||||
|
||||
# sender: app
|
||||
app_was_deleted = signal("app-was-deleted")
|
||||
|
||||
# sender: app
|
||||
app_was_updated = signal("app-was-updated")
|
||||
|
||||
# sender: app, kwargs: app_model_config
|
||||
app_model_config_was_updated = signal("app-model-config-was-updated")
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from blinker import signal
|
||||
|
||||
# sender: MessageFeedback, kwargs: tenant_id
|
||||
feedback_was_created = signal("feedback-was-created")
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Flask extension for enterprise telemetry lifecycle management.
|
||||
|
||||
Initializes the EnterpriseExporter and TelemetryGateway singletons during
|
||||
``create_app()`` (single-threaded), registers blinker event handlers,
|
||||
and hooks atexit for graceful shutdown.
|
||||
|
||||
Skipped entirely when ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED``
|
||||
are false (``is_enabled()`` gate).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_app import DifyApp
|
||||
from enterprise.telemetry.exporter import EnterpriseExporter
|
||||
from enterprise.telemetry.gateway import TelemetryGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_exporter: EnterpriseExporter | None = None
|
||||
_gateway: TelemetryGateway | None = None
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
global _exporter, _gateway
|
||||
|
||||
if not is_enabled():
|
||||
return
|
||||
|
||||
from enterprise.telemetry.exporter import EnterpriseExporter
|
||||
from enterprise.telemetry.gateway import TelemetryGateway
|
||||
|
||||
_exporter = EnterpriseExporter(dify_config)
|
||||
_gateway = TelemetryGateway()
|
||||
atexit.register(_exporter.shutdown)
|
||||
|
||||
# Import to trigger @signal.connect decorator registration
|
||||
import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport]
|
||||
|
||||
logger.info("Enterprise telemetry initialized")
|
||||
|
||||
|
||||
def get_enterprise_exporter() -> EnterpriseExporter | None:
|
||||
return _exporter
|
||||
|
||||
|
||||
def get_gateway() -> TelemetryGateway | None:
|
||||
return _gateway
|
||||
@@ -21,15 +21,3 @@ class DifySpanAttributes:
|
||||
|
||||
INVOKE_FROM = "dify.invoke_from"
|
||||
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
|
||||
|
||||
INVOKED_BY = "dify.invoked_by"
|
||||
"""Invoked by, e.g. end_user, account, user."""
|
||||
|
||||
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
"""Number of input tokens (prompt tokens) used."""
|
||||
|
||||
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
"""Number of output tokens (completion tokens) generated."""
|
||||
|
||||
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
"""Total number of tokens used."""
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
@@ -340,8 +340,6 @@ class AppService:
|
||||
db.session.delete(app)
|
||||
db.session.commit()
|
||||
|
||||
app_was_deleted.send(app)
|
||||
|
||||
# clean up web app settings
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
|
||||
|
||||
@@ -7,10 +7,9 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from events.feedback_event import feedback_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
@@ -180,9 +179,6 @@ class MessageService:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if feedback and rating:
|
||||
feedback_was_created.send(feedback, tenant_id=app_model.tenant_id)
|
||||
|
||||
return feedback
|
||||
|
||||
@classmethod
|
||||
@@ -298,15 +294,10 @@ class MessageService:
|
||||
questions: list[str] = list(questions_sequence)
|
||||
|
||||
# get tracing instance
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
context=TelemetryContext(tenant_id=app_model.tenant_id, app_id=app_model.id),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"suggested_question": questions,
|
||||
"timer": timer,
|
||||
},
|
||||
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
@@ -6,8 +5,6 @@ from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, TraceAppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpsService:
|
||||
@classmethod
|
||||
@@ -138,13 +135,12 @@ class OpsService:
|
||||
return trace_config_data.to_dict()
|
||||
|
||||
@classmethod
|
||||
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
|
||||
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
|
||||
"""
|
||||
Create tracing app config
|
||||
:param app_id: app id
|
||||
:param tracing_provider: tracing provider
|
||||
:param tracing_config: tracing config
|
||||
:param account_id: account id of the user creating the config
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
@@ -211,19 +207,15 @@ class OpsService:
|
||||
db.session.add(trace_config_data)
|
||||
db.session.commit()
|
||||
|
||||
# Log the creation with modifier information
|
||||
logger.info("Trace config created: app_id=%s, provider=%s, created_by=%s", app_id, tracing_provider, account_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
|
||||
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
|
||||
"""
|
||||
Update tracing app config
|
||||
:param app_id: app id
|
||||
:param tracing_provider: tracing provider
|
||||
:param tracing_config: tracing config
|
||||
:param account_id: account id of the user updating the config
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
@@ -259,18 +251,14 @@ class OpsService:
|
||||
current_trace_config.tracing_config = tracing_config
|
||||
db.session.commit()
|
||||
|
||||
# Log the update with modifier information
|
||||
logger.info("Trace config updated: app_id=%s, provider=%s, updated_by=%s", app_id, tracing_provider, account_id)
|
||||
|
||||
return current_trace_config.to_dict()
|
||||
|
||||
@classmethod
|
||||
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str, account_id: str):
|
||||
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
|
||||
"""
|
||||
Delete tracing app config
|
||||
:param app_id: app id
|
||||
:param tracing_provider: tracing provider
|
||||
:param account_id: account id of the user deleting the config
|
||||
:return:
|
||||
"""
|
||||
trace_config = (
|
||||
@@ -282,9 +270,6 @@ class OpsService:
|
||||
if not trace_config:
|
||||
return None
|
||||
|
||||
# Log the deletion with modifier information
|
||||
logger.info("Trace config deleted: app_id=%s, provider=%s, deleted_by=%s", app_id, tracing_provider, account_id)
|
||||
|
||||
db.session.delete(trace_config)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
from extensions.ext_database import db
|
||||
@@ -648,7 +647,6 @@ class WorkflowService:
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id)
|
||||
node_type = Workflow.get_node_type_from_node_config(node_config)
|
||||
node_data = node_config.get("data", {})
|
||||
workflow_execution_id: str | None = None
|
||||
if node_type.is_start_node:
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
@@ -674,13 +672,10 @@ class WorkflowService:
|
||||
node_type=node_type,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
|
||||
|
||||
else:
|
||||
workflow_execution_id = str(uuid.uuid4())
|
||||
system_variable = SystemVariable(workflow_execution_id=workflow_execution_id)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_variable,
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
@@ -734,13 +729,6 @@ class WorkflowService:
|
||||
with Session(db.engine) as session:
|
||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution,
|
||||
outputs=outputs,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
@@ -796,20 +784,19 @@ class WorkflowService:
|
||||
Returns:
|
||||
WorkflowNodeExecution: The execution result
|
||||
"""
|
||||
created_at = naive_utc_now()
|
||||
node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn)
|
||||
finished_at = naive_utc_now()
|
||||
|
||||
# Create base node execution
|
||||
node_execution = WorkflowNodeExecution(
|
||||
id=node.execution_id or str(uuid.uuid4()),
|
||||
id=str(uuid.uuid4()),
|
||||
workflow_id="", # Single-step execution has no workflow ID
|
||||
index=1,
|
||||
node_id=node_id,
|
||||
node_type=node.node_type,
|
||||
title=node.title,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
created_at=naive_utc_now(),
|
||||
finished_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Populate execution result data
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Celery worker for enterprise metric/log telemetry events.
|
||||
|
||||
This module defines the Celery task that processes telemetry envelopes
|
||||
from the enterprise_telemetry queue. It deserializes envelopes and
|
||||
dispatches them to the EnterpriseMetricHandler.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="enterprise_telemetry")
|
||||
def process_enterprise_telemetry(envelope_json: str) -> None:
|
||||
"""Process enterprise metric/log telemetry envelope.
|
||||
|
||||
This task is enqueued by the TelemetryGateway for metric/log-only
|
||||
events. It deserializes the envelope and dispatches to the handler.
|
||||
|
||||
Best-effort processing: logs errors but never raises, to avoid
|
||||
failing user requests due to telemetry issues.
|
||||
|
||||
Args:
|
||||
envelope_json: JSON-serialized TelemetryEnvelope.
|
||||
"""
|
||||
try:
|
||||
# Deserialize envelope
|
||||
envelope_dict = json.loads(envelope_json)
|
||||
envelope = TelemetryEnvelope.model_validate(envelope_dict)
|
||||
|
||||
# Process through handler
|
||||
handler = EnterpriseMetricHandler()
|
||||
handler.handle(envelope)
|
||||
|
||||
logger.debug(
|
||||
"Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
envelope.case,
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort: log and drop on error, never fail user request
|
||||
logger.warning(
|
||||
"Failed to process enterprise telemetry envelope, dropping event",
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -39,24 +39,12 @@ def process_trace_tasks(file_info):
|
||||
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
|
||||
|
||||
try:
|
||||
trace_type = trace_info_info_map.get(trace_info_type)
|
||||
if trace_type:
|
||||
trace_info = trace_type(**trace_info)
|
||||
|
||||
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
|
||||
|
||||
if is_ee_telemetry_enabled():
|
||||
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
|
||||
|
||||
try:
|
||||
EnterpriseOtelTrace().trace(trace_info)
|
||||
except Exception:
|
||||
logger.warning("Enterprise trace failed for app_id: %s", app_id, exc_info=True)
|
||||
|
||||
if trace_instance:
|
||||
with current_app.app_context():
|
||||
trace_type = trace_info_info_map.get(trace_info_type)
|
||||
if trace_type:
|
||||
trace_info = trace_type(**trace_info)
|
||||
trace_instance.trace(trace_info)
|
||||
|
||||
logger.info("Processing trace tasks success, app_id: %s", app_id)
|
||||
except Exception as e:
|
||||
logger.info("error:\n\n\n%s\n\n\n\n", e)
|
||||
@@ -64,12 +52,4 @@ def process_trace_tasks(file_info):
|
||||
redis_client.incr(failed_key)
|
||||
logger.info("Processing trace tasks failed, app_id: %s", app_id)
|
||||
finally:
|
||||
try:
|
||||
storage.delete(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete trace file %s for app_id %s: %s",
|
||||
file_path,
|
||||
app_id,
|
||||
e,
|
||||
)
|
||||
storage.delete(file_path)
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
"""Unit tests for TraceQueueManager telemetry guard.
|
||||
|
||||
This test suite verifies that TraceQueueManager correctly drops trace tasks
|
||||
when telemetry is disabled, proving Bug 1 from code review is a false positive.
|
||||
|
||||
The guard logic moved from persistence.py to TraceQueueManager.add_trace_task()
|
||||
at line 1282 of ops_trace_manager.py:
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
|
||||
Tasks are only enqueued if EITHER:
|
||||
- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR
|
||||
- A third-party trace instance (Langfuse, etc.) is configured
|
||||
|
||||
When BOTH are false, tasks are silently dropped (correct behavior).
|
||||
"""
|
||||
|
||||
import queue
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_queue_manager_and_task(monkeypatch):
|
||||
"""Fixture to provide TraceQueueManager and TraceTask with delayed imports."""
|
||||
module_name = "core.ops.ops_trace_manager"
|
||||
if module_name not in sys.modules:
|
||||
ops_stub = types.ModuleType(module_name)
|
||||
|
||||
class StubTraceTask:
|
||||
def __init__(self, trace_type):
|
||||
self.trace_type = trace_type
|
||||
self.app_id = None
|
||||
|
||||
class StubTraceQueueManager:
|
||||
def __init__(self, app_id=None):
|
||||
self.app_id = app_id
|
||||
from core.telemetry import is_enterprise_telemetry_enabled
|
||||
|
||||
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
|
||||
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
|
||||
|
||||
def add_trace_task(self, trace_task):
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
from core.ops.ops_trace_manager import trace_manager_queue
|
||||
|
||||
trace_manager_queue.put(trace_task)
|
||||
|
||||
class StubOpsTraceManager:
|
||||
@staticmethod
|
||||
def get_ops_trace_instance(app_id):
|
||||
return None
|
||||
|
||||
ops_stub.TraceQueueManager = StubTraceQueueManager
|
||||
ops_stub.TraceTask = StubTraceTask
|
||||
ops_stub.OpsTraceManager = StubOpsTraceManager
|
||||
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
||||
ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"])
|
||||
TraceQueueManager = ops_module.TraceQueueManager
|
||||
TraceTask = ops_module.TraceTask
|
||||
|
||||
return TraceQueueManager, TraceTask, TraceTaskName
|
||||
|
||||
|
||||
class TestTraceQueueManagerTelemetryGuard:
|
||||
"""Test TraceQueueManager's telemetry guard in add_trace_task()."""
|
||||
|
||||
def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task):
|
||||
"""Verify task is NOT enqueued when telemetry disabled and no trace instance.
|
||||
|
||||
This is the core guard: when _enterprise_telemetry_enabled=False AND
|
||||
trace_instance=None, the task should be silently dropped.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=False),
|
||||
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_not_called()
|
||||
|
||||
def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task):
|
||||
"""Verify task IS enqueued when enterprise telemetry is enabled.
|
||||
|
||||
When _enterprise_telemetry_enabled=True, the task should be enqueued
|
||||
regardless of trace_instance state.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True),
|
||||
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "test-app-id"
|
||||
|
||||
def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task):
|
||||
"""Verify task IS enqueued when third-party trace instance is configured.
|
||||
|
||||
When trace_instance is not None (e.g., Langfuse configured), the task
|
||||
should be enqueued even if enterprise telemetry is disabled.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
mock_trace_instance = MagicMock()
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=False),
|
||||
patch(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
|
||||
),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "test-app-id"
|
||||
|
||||
def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task):
|
||||
"""Verify task IS enqueued when both telemetry and trace instance are enabled.
|
||||
|
||||
When both _enterprise_telemetry_enabled=True AND trace_instance is set,
|
||||
the task should definitely be enqueued.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
mock_trace_instance = MagicMock()
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True),
|
||||
patch(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
|
||||
),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "test-app-id"
|
||||
|
||||
def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task):
|
||||
"""Verify app_id is set on the task before enqueuing.
|
||||
|
||||
The guard logic sets trace_task.app_id = self.app_id before calling
|
||||
trace_manager_queue.put(trace_task). This test verifies that behavior.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True),
|
||||
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="expected-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "expected-app-id"
|
||||
@@ -1,181 +0,0 @@
|
||||
"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.events import TelemetryContext, TelemetryEvent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_test_setup(monkeypatch):
|
||||
module_name = "core.ops.ops_trace_manager"
|
||||
ops_stub = types.ModuleType(module_name)
|
||||
|
||||
class StubTraceTask:
|
||||
def __init__(self, trace_type, **kwargs):
|
||||
self.trace_type = trace_type
|
||||
self.app_id = None
|
||||
self.kwargs = kwargs
|
||||
|
||||
class StubTraceQueueManager:
|
||||
def __init__(self, app_id=None, user_id=None):
|
||||
self.app_id = app_id
|
||||
self.user_id = user_id
|
||||
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
|
||||
|
||||
def add_trace_task(self, trace_task):
|
||||
trace_task.app_id = self.app_id
|
||||
from core.ops.ops_trace_manager import trace_manager_queue
|
||||
|
||||
trace_manager_queue.put(trace_task)
|
||||
|
||||
class StubOpsTraceManager:
|
||||
@staticmethod
|
||||
def get_ops_trace_instance(app_id):
|
||||
return None
|
||||
|
||||
ops_stub.TraceQueueManager = StubTraceQueueManager
|
||||
ops_stub.TraceTask = StubTraceTask
|
||||
ops_stub.OpsTraceManager = StubOpsTraceManager
|
||||
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from core.telemetry import emit
|
||||
|
||||
return emit, ops_stub.trace_manager_queue
|
||||
|
||||
|
||||
class TestTelemetryEmit:
|
||||
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_enterprise_trace_creates_trace_task(self, _mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
|
||||
def test_emit_community_trace_enqueued(self, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
|
||||
def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_not_called()
|
||||
|
||||
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, _mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
enterprise_only_traces = [
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE,
|
||||
]
|
||||
|
||||
for trace_name in enterprise_only_traces:
|
||||
mock_queue.reset_mock()
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=trace_name,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.trace_type == trace_name
|
||||
|
||||
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_passes_name_directly_to_trace_task(self, _mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={"extra": "data"},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
assert isinstance(called_task.trace_type, TraceTaskName)
|
||||
|
||||
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_with_provided_trace_manager(self, _mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
mock_trace_manager = MagicMock()
|
||||
mock_trace_manager.add_trace_task = MagicMock()
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event, trace_manager=mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
called_task = mock_trace_manager.add_trace_task.call_args[0][0]
|
||||
assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE
|
||||
@@ -1,252 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.telemetry import is_enterprise_telemetry_enabled
|
||||
from enterprise.telemetry.contracts import TelemetryCase
|
||||
from enterprise.telemetry.gateway import TelemetryGateway
|
||||
|
||||
|
||||
class TestTelemetryCoreExports:
|
||||
def test_is_enterprise_telemetry_enabled_exported(self) -> None:
|
||||
from core.telemetry import is_enterprise_telemetry_enabled as exported_func
|
||||
|
||||
assert callable(exported_func)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager():
|
||||
mock_module = MagicMock()
|
||||
mock_trace_task_class = MagicMock()
|
||||
mock_trace_task_class.return_value = MagicMock()
|
||||
mock_module.TraceTask = mock_trace_task_class
|
||||
mock_module.TraceQueueManager = MagicMock()
|
||||
|
||||
mock_trace_entity = MagicMock()
|
||||
mock_trace_task_name = MagicMock()
|
||||
mock_trace_task_name.return_value = "workflow"
|
||||
mock_trace_entity.TraceTaskName = mock_trace_task_name
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
|
||||
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
|
||||
):
|
||||
yield mock_module, mock_trace_entity
|
||||
|
||||
|
||||
class TestGatewayIntegrationTraceRouting:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_ce_eligible_trace_routed_to_trace_manager(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True):
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_ce_eligible_trace_routed_when_ee_disabled(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_enterprise_only_trace_dropped_when_ee_disabled(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_enterprise_only_trace_routed_when_ee_enabled(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestGatewayIntegrationMetricRouting:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
def test_metric_case_routes_to_celery_task(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
|
||||
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc", "name": "My App"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.APP_CREATED
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.payload["app_id"] == "app-abc"
|
||||
|
||||
def test_tool_execution_metric_routed(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
|
||||
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
|
||||
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
|
||||
payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"}
|
||||
|
||||
gateway.emit(TelemetryCase.TOOL_EXECUTION, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.TOOL_EXECUTION
|
||||
|
||||
def test_moderation_check_metric_routed(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
|
||||
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
|
||||
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
|
||||
payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}}
|
||||
|
||||
gateway.emit(TelemetryCase.MODERATION_CHECK, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.MODERATION_CHECK
|
||||
|
||||
|
||||
class TestGatewayIntegrationCEEligibility:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_workflow_run_is_ce_eligible(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_message_run_is_ce_eligible(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"message_id": "msg-abc", "conversation_id": "conv-123"}
|
||||
|
||||
gateway.emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_node_execution_not_ce_eligible(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_draft_node_execution_not_ce_eligible(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_execution_data": {}}
|
||||
|
||||
gateway.emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_prompt_generation_not_ce_eligible(
|
||||
self,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"operation_type": "generate", "instruction": "test"}
|
||||
|
||||
gateway.emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
|
||||
class TestIsEnterpriseTelemetryEnabled:
|
||||
def test_returns_false_when_exporter_import_fails(self) -> None:
|
||||
with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}):
|
||||
result = is_enterprise_telemetry_enabled()
|
||||
assert result is False
|
||||
|
||||
def test_function_is_callable(self) -> None:
|
||||
assert callable(is_enterprise_telemetry_enabled)
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Unit tests for telemetry gateway contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
|
||||
from enterprise.telemetry.gateway import CASE_ROUTING
|
||||
|
||||
|
||||
class TestTelemetryCase:
|
||||
"""Tests for TelemetryCase enum."""
|
||||
|
||||
def test_all_cases_defined(self) -> None:
|
||||
"""Verify all 14 telemetry cases are defined."""
|
||||
expected_cases = {
|
||||
"WORKFLOW_RUN",
|
||||
"NODE_EXECUTION",
|
||||
"DRAFT_NODE_EXECUTION",
|
||||
"MESSAGE_RUN",
|
||||
"TOOL_EXECUTION",
|
||||
"MODERATION_CHECK",
|
||||
"SUGGESTED_QUESTION",
|
||||
"DATASET_RETRIEVAL",
|
||||
"GENERATE_NAME",
|
||||
"PROMPT_GENERATION",
|
||||
"APP_CREATED",
|
||||
"APP_UPDATED",
|
||||
"APP_DELETED",
|
||||
"FEEDBACK_CREATED",
|
||||
}
|
||||
actual_cases = {case.name for case in TelemetryCase}
|
||||
assert actual_cases == expected_cases
|
||||
|
||||
def test_case_values(self) -> None:
|
||||
"""Verify case enum values are correct."""
|
||||
assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run"
|
||||
assert TelemetryCase.NODE_EXECUTION.value == "node_execution"
|
||||
assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution"
|
||||
assert TelemetryCase.MESSAGE_RUN.value == "message_run"
|
||||
assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution"
|
||||
assert TelemetryCase.MODERATION_CHECK.value == "moderation_check"
|
||||
assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question"
|
||||
assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval"
|
||||
assert TelemetryCase.GENERATE_NAME.value == "generate_name"
|
||||
assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation"
|
||||
assert TelemetryCase.APP_CREATED.value == "app_created"
|
||||
assert TelemetryCase.APP_UPDATED.value == "app_updated"
|
||||
assert TelemetryCase.APP_DELETED.value == "app_deleted"
|
||||
assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created"
|
||||
|
||||
|
||||
class TestCaseRoute:
|
||||
"""Tests for CaseRoute model."""
|
||||
|
||||
def test_valid_trace_route(self) -> None:
|
||||
"""Verify valid trace route creation."""
|
||||
route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
|
||||
assert route.signal_type == SignalType.TRACE
|
||||
assert route.ce_eligible is True
|
||||
|
||||
def test_valid_metric_log_route(self) -> None:
|
||||
"""Verify valid metric_log route creation."""
|
||||
route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
|
||||
assert route.signal_type == SignalType.METRIC_LOG
|
||||
assert route.ce_eligible is False
|
||||
|
||||
def test_invalid_signal_type(self) -> None:
|
||||
"""Verify invalid signal_type is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
CaseRoute(signal_type="invalid", ce_eligible=True)
|
||||
|
||||
|
||||
class TestTelemetryEnvelope:
|
||||
"""Tests for TelemetryEnvelope model."""
|
||||
|
||||
def test_valid_envelope_minimal(self) -> None:
|
||||
"""Verify valid minimal envelope creation."""
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
assert envelope.case == TelemetryCase.WORKFLOW_RUN
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.event_id == "event-456"
|
||||
assert envelope.payload == {"key": "value"}
|
||||
assert envelope.payload_fallback is None
|
||||
assert envelope.metadata is None
|
||||
|
||||
def test_valid_envelope_full(self) -> None:
|
||||
"""Verify valid envelope with all fields."""
|
||||
metadata = {"source": "api"}
|
||||
fallback = b"fallback data"
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.MESSAGE_RUN,
|
||||
tenant_id="tenant-789",
|
||||
event_id="event-012",
|
||||
payload={"message": "hello"},
|
||||
payload_fallback=fallback,
|
||||
metadata=metadata,
|
||||
)
|
||||
assert envelope.case == TelemetryCase.MESSAGE_RUN
|
||||
assert envelope.tenant_id == "tenant-789"
|
||||
assert envelope.event_id == "event-012"
|
||||
assert envelope.payload == {"message": "hello"}
|
||||
assert envelope.payload_fallback == fallback
|
||||
assert envelope.metadata == metadata
|
||||
|
||||
def test_missing_required_case(self) -> None:
|
||||
"""Verify missing case field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
def test_missing_required_tenant_id(self) -> None:
|
||||
"""Verify missing tenant_id field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
def test_missing_required_event_id(self) -> None:
|
||||
"""Verify missing event_id field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
def test_missing_required_payload(self) -> None:
|
||||
"""Verify missing payload field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
)
|
||||
|
||||
def test_payload_fallback_within_limit(self) -> None:
|
||||
"""Verify payload_fallback within 64KB limit is accepted."""
|
||||
fallback = b"x" * 65536
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
payload_fallback=fallback,
|
||||
)
|
||||
assert envelope.payload_fallback == fallback
|
||||
|
||||
def test_payload_fallback_exceeds_limit(self) -> None:
|
||||
"""Verify payload_fallback exceeding 64KB is rejected."""
|
||||
fallback = b"x" * 65537
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
payload_fallback=fallback,
|
||||
)
|
||||
assert "64KB" in str(exc_info.value)
|
||||
|
||||
def test_payload_fallback_none(self) -> None:
|
||||
"""Verify payload_fallback can be None."""
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
payload_fallback=None,
|
||||
)
|
||||
assert envelope.payload_fallback is None
|
||||
|
||||
|
||||
class TestCaseRouting:
|
||||
"""Tests for CASE_ROUTING table."""
|
||||
|
||||
def test_all_cases_routed(self) -> None:
|
||||
"""Verify all 14 cases have routing entries."""
|
||||
assert len(CASE_ROUTING) == 14
|
||||
for case in TelemetryCase:
|
||||
assert case in CASE_ROUTING
|
||||
|
||||
def test_trace_ce_eligible_cases(self) -> None:
|
||||
"""Verify trace cases with CE eligibility."""
|
||||
ce_eligible_trace_cases = {
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
}
|
||||
for case in ce_eligible_trace_cases:
|
||||
route = CASE_ROUTING[case]
|
||||
assert route.signal_type == SignalType.TRACE
|
||||
assert route.ce_eligible is True
|
||||
|
||||
def test_trace_enterprise_only_cases(self) -> None:
|
||||
"""Verify trace cases that are enterprise-only."""
|
||||
enterprise_only_trace_cases = {
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
}
|
||||
for case in enterprise_only_trace_cases:
|
||||
route = CASE_ROUTING[case]
|
||||
assert route.signal_type == SignalType.TRACE
|
||||
assert route.ce_eligible is False
|
||||
|
||||
def test_metric_log_cases(self) -> None:
|
||||
"""Verify metric/log-only cases."""
|
||||
metric_log_cases = {
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
TelemetryCase.TOOL_EXECUTION,
|
||||
TelemetryCase.MODERATION_CHECK,
|
||||
TelemetryCase.SUGGESTED_QUESTION,
|
||||
TelemetryCase.DATASET_RETRIEVAL,
|
||||
TelemetryCase.GENERATE_NAME,
|
||||
}
|
||||
for case in metric_log_cases:
|
||||
route = CASE_ROUTING[case]
|
||||
assert route.signal_type == SignalType.METRIC_LOG
|
||||
assert route.ce_eligible is False
|
||||
|
||||
def test_routing_table_completeness(self) -> None:
|
||||
"""Verify routing table covers all cases with correct types."""
|
||||
trace_cases = {
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
}
|
||||
metric_log_cases = {
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
TelemetryCase.TOOL_EXECUTION,
|
||||
TelemetryCase.MODERATION_CHECK,
|
||||
TelemetryCase.SUGGESTED_QUESTION,
|
||||
TelemetryCase.DATASET_RETRIEVAL,
|
||||
TelemetryCase.GENERATE_NAME,
|
||||
}
|
||||
|
||||
all_cases = trace_cases | metric_log_cases
|
||||
assert len(all_cases) == 14
|
||||
assert all_cases == set(TelemetryCase)
|
||||
|
||||
for case in trace_cases:
|
||||
assert CASE_ROUTING[case].signal_type == SignalType.TRACE
|
||||
|
||||
for case in metric_log_cases:
|
||||
assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG
|
||||
@@ -1,134 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry import event_handlers
|
||||
from enterprise.telemetry.contracts import TelemetryCase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_exporter():
|
||||
with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock:
|
||||
exporter = MagicMock()
|
||||
mock.return_value = exporter
|
||||
yield exporter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def test_handle_app_created_calls_task(mock_exporter, mock_task):
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
sender.mode = "chat"
|
||||
|
||||
event_handlers._handle_app_created(sender)
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[0][0]
|
||||
assert "app_created" in call_args
|
||||
assert "tenant-456" in call_args
|
||||
assert "app-123" in call_args
|
||||
assert "chat" in call_args
|
||||
|
||||
|
||||
def test_handle_app_created_no_exporter(mock_task):
|
||||
with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None):
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
|
||||
event_handlers._handle_app_created(sender)
|
||||
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
|
||||
def test_handle_app_updated_calls_task(mock_exporter, mock_task):
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
|
||||
event_handlers._handle_app_updated(sender)
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[0][0]
|
||||
assert "app_updated" in call_args
|
||||
assert "tenant-456" in call_args
|
||||
assert "app-123" in call_args
|
||||
|
||||
|
||||
def test_handle_app_deleted_calls_task(mock_exporter, mock_task):
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
|
||||
event_handlers._handle_app_deleted(sender)
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[0][0]
|
||||
assert "app_deleted" in call_args
|
||||
assert "tenant-456" in call_args
|
||||
assert "app-123" in call_args
|
||||
|
||||
|
||||
def test_handle_feedback_created_calls_task(mock_exporter, mock_task):
|
||||
sender = MagicMock()
|
||||
sender.message_id = "msg-123"
|
||||
sender.app_id = "app-456"
|
||||
sender.conversation_id = "conv-789"
|
||||
sender.from_end_user_id = "user-001"
|
||||
sender.from_account_id = None
|
||||
sender.rating = "like"
|
||||
sender.from_source = "api"
|
||||
sender.content = "Great response!"
|
||||
|
||||
event_handlers._handle_feedback_created(sender, tenant_id="tenant-456")
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[0][0]
|
||||
assert "feedback_created" in call_args
|
||||
assert "tenant-456" in call_args
|
||||
assert "msg-123" in call_args
|
||||
assert "app-456" in call_args
|
||||
assert "conv-789" in call_args
|
||||
assert "user-001" in call_args
|
||||
assert "like" in call_args
|
||||
assert "api" in call_args
|
||||
assert "Great response!" in call_args
|
||||
|
||||
|
||||
def test_handle_feedback_created_no_exporter(mock_task):
|
||||
with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None):
|
||||
sender = MagicMock()
|
||||
sender.message_id = "msg-123"
|
||||
|
||||
event_handlers._handle_feedback_created(sender, tenant_id="tenant-456")
|
||||
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
|
||||
def test_handlers_create_valid_envelopes(mock_exporter, mock_task):
|
||||
import json
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
sender.mode = "chat"
|
||||
|
||||
event_handlers._handle_app_created(sender)
|
||||
|
||||
call_args = mock_task.delay.call_args[0][0]
|
||||
envelope_dict = json.loads(call_args)
|
||||
envelope = TelemetryEnvelope(**envelope_dict)
|
||||
|
||||
assert envelope.case == TelemetryCase.APP_CREATED
|
||||
assert envelope.tenant_id == "tenant-456"
|
||||
assert envelope.event_id
|
||||
assert envelope.payload["app_id"] == "app-123"
|
||||
assert envelope.payload["mode"] == "chat"
|
||||
@@ -1,301 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope
|
||||
from enterprise.telemetry.gateway import (
|
||||
CASE_ROUTING,
|
||||
CASE_TO_TRACE_TASK,
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES,
|
||||
TelemetryGateway,
|
||||
emit,
|
||||
)
|
||||
|
||||
|
||||
class TestCaseRoutingTable:
|
||||
def test_all_cases_have_routing(self) -> None:
|
||||
for case in TelemetryCase:
|
||||
assert case in CASE_ROUTING, f"Missing routing for {case}"
|
||||
|
||||
def test_trace_cases(self) -> None:
|
||||
trace_cases = [
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
]
|
||||
for case in trace_cases:
|
||||
assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace"
|
||||
|
||||
def test_metric_log_cases(self) -> None:
|
||||
metric_log_cases = [
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
TelemetryCase.TOOL_EXECUTION,
|
||||
TelemetryCase.MODERATION_CHECK,
|
||||
TelemetryCase.SUGGESTED_QUESTION,
|
||||
TelemetryCase.DATASET_RETRIEVAL,
|
||||
TelemetryCase.GENERATE_NAME,
|
||||
]
|
||||
for case in metric_log_cases:
|
||||
assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log"
|
||||
|
||||
def test_ce_eligible_cases(self) -> None:
|
||||
ce_eligible_cases = [TelemetryCase.WORKFLOW_RUN, TelemetryCase.MESSAGE_RUN]
|
||||
for case in ce_eligible_cases:
|
||||
assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible"
|
||||
|
||||
def test_enterprise_only_cases(self) -> None:
|
||||
enterprise_only_cases = [
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
]
|
||||
for case in enterprise_only_cases:
|
||||
assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only"
|
||||
|
||||
def test_trace_cases_have_task_name_mapping(self) -> None:
|
||||
trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE]
|
||||
for case in trace_cases:
|
||||
assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager():
|
||||
mock_module = MagicMock()
|
||||
mock_trace_task_class = MagicMock()
|
||||
mock_trace_task_class.return_value = MagicMock()
|
||||
mock_module.TraceTask = mock_trace_task_class
|
||||
mock_module.TraceQueueManager = MagicMock()
|
||||
|
||||
mock_trace_entity = MagicMock()
|
||||
mock_trace_task_name = MagicMock()
|
||||
mock_trace_task_name.return_value = "workflow"
|
||||
mock_trace_entity.TraceTaskName = mock_trace_task_name
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
|
||||
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
|
||||
):
|
||||
yield mock_module, mock_trace_entity
|
||||
|
||||
|
||||
class TestTelemetryGatewayTraceRouting:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_trace_case_routes_to_trace_manager(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_ce_eligible_trace_enqueued_when_ee_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_enterprise_only_trace_dropped_when_ee_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_enterprise_only_trace_enqueued_when_ee_enabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestTelemetryGatewayMetricLogRouting:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_metric_case_routes_to_celery_task(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc", "name": "My App"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.APP_CREATED
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.payload["app_id"] == "app-abc"
|
||||
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_envelope_has_unique_event_id(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
assert mock_delay.call_count == 2
|
||||
envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0])
|
||||
envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0])
|
||||
assert envelope1.event_id != envelope2.event_id
|
||||
|
||||
|
||||
class TestTelemetryGatewayPayloadSizing:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_small_payload_inlined(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"key": "small_value"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == payload
|
||||
assert envelope.metadata is None
|
||||
|
||||
@patch("enterprise.telemetry.gateway.storage")
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_large_payload_stored(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_storage: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
|
||||
payload = {"key": large_value}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_storage.save.assert_called_once()
|
||||
storage_key = mock_storage.save.call_args[0][0]
|
||||
assert storage_key.startswith("telemetry/tenant-123/")
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == {}
|
||||
assert envelope.metadata is not None
|
||||
assert envelope.metadata["payload_ref"] == storage_key
|
||||
|
||||
@patch("enterprise.telemetry.gateway.storage")
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_large_payload_fallback_on_storage_error(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_storage: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
mock_storage.save.side_effect = Exception("Storage failure")
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
|
||||
payload = {"key": large_value}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == payload
|
||||
assert envelope.metadata is None
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
@patch("extensions.ext_enterprise_telemetry.get_gateway")
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_function_uses_gateway(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
mock_get_gateway: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
mock_gateway = TelemetryGateway()
|
||||
mock_get_gateway.return_value = mock_gateway
|
||||
mock_trace_manager = MagicMock()
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
with patch.object(mock_gateway, "emit") as mock_emit:
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
mock_emit.assert_called_once_with(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
|
||||
class TestTraceTaskNameMapping:
|
||||
def test_workflow_run_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
def test_message_run_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE
|
||||
|
||||
def test_node_execution_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE
|
||||
|
||||
def test_draft_node_execution_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
|
||||
def test_prompt_generation_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE
|
||||
@@ -1,474 +0,0 @@
|
||||
"""Unit tests for EnterpriseMetricHandler."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
with patch("enterprise.telemetry.metric_handler.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_envelope():
|
||||
return TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-123",
|
||||
payload={"app_id": "app-123", "name": "Test App"},
|
||||
)
|
||||
|
||||
|
||||
def test_dispatch_app_created(sample_envelope, mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_created") as mock_handler:
|
||||
handler.handle(sample_envelope)
|
||||
mock_handler.assert_called_once_with(sample_envelope)
|
||||
|
||||
|
||||
def test_dispatch_app_updated(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_UPDATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-456",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_updated") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_app_deleted(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_DELETED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-789",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_deleted") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_feedback_created(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-abc",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_feedback_created") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_message_run(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.MESSAGE_RUN,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-msg",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_message_run") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_tool_execution(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.TOOL_EXECUTION,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-tool",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_tool_execution") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_moderation_check(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.MODERATION_CHECK,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-mod",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_moderation_check") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_suggested_question(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.SUGGESTED_QUESTION,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-sq",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_suggested_question") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_dataset_retrieval(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.DATASET_RETRIEVAL,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-ds",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_dataset_retrieval") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_generate_name(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.GENERATE_NAME,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-gn",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_generate_name") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_prompt_generation(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.PROMPT_GENERATION,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-pg",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_prompt_generation") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_all_known_cases_have_handlers(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
handler = EnterpriseMetricHandler()
|
||||
|
||||
for case in TelemetryCase:
|
||||
envelope = TelemetryEnvelope(
|
||||
case=case,
|
||||
tenant_id="test-tenant",
|
||||
event_id=f"test-{case.value}",
|
||||
payload={},
|
||||
)
|
||||
handler.handle(envelope)
|
||||
|
||||
|
||||
def test_idempotency_duplicate(sample_envelope, mock_redis):
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_created") as mock_handler:
|
||||
handler.handle(sample_envelope)
|
||||
mock_handler.assert_not_called()
|
||||
|
||||
|
||||
def test_idempotency_first_seen(sample_envelope, mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
is_dup = handler._is_duplicate(sample_envelope)
|
||||
|
||||
assert is_dup is False
|
||||
mock_redis.set.assert_called_once_with(
|
||||
"telemetry:dedup:test-tenant:test-event-123",
|
||||
b"1",
|
||||
nx=True,
|
||||
ex=3600,
|
||||
)
|
||||
|
||||
|
||||
def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog):
|
||||
mock_redis.set.side_effect = Exception("Redis unavailable")
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
is_dup = handler._is_duplicate(sample_envelope)
|
||||
|
||||
assert is_dup is False
|
||||
assert "Redis unavailable for deduplication check" in caplog.text
|
||||
|
||||
|
||||
def test_rehydration_uses_payload(sample_envelope):
|
||||
handler = EnterpriseMetricHandler()
|
||||
payload = handler._rehydrate(sample_envelope)
|
||||
|
||||
assert payload == {"app_id": "app-123", "name": "Test App"}
|
||||
|
||||
|
||||
def test_rehydration_fallback():
|
||||
import pickle
|
||||
|
||||
fallback_data = {"fallback": "data"}
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-fb",
|
||||
payload={},
|
||||
payload_fallback=pickle.dumps(fallback_data),
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
payload = handler._rehydrate(envelope)
|
||||
|
||||
assert payload == fallback_data
|
||||
|
||||
|
||||
def test_rehydration_emits_degraded_event_on_failure():
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-fail",
|
||||
payload={},
|
||||
payload_fallback=None,
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit:
|
||||
payload = handler._rehydrate(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
assert payload == {}
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
|
||||
assert call_args[1]["attributes"]["rehydration_failed"] is True
|
||||
|
||||
|
||||
def test_on_app_created_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"app_id": "app-789", "mode": "chat"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_app_created(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
mock_emit.assert_called_once_with(
|
||||
event_name=EnterpriseTelemetryEvent.APP_CREATED,
|
||||
attributes={
|
||||
"dify.app.id": "app-789",
|
||||
"dify.tenant_id": "tenant-123",
|
||||
"dify.app.mode": "chat",
|
||||
},
|
||||
tenant_id="tenant-123",
|
||||
)
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
call_args = mock_exporter.increment_counter.call_args
|
||||
assert call_args[0][0] == EnterpriseTelemetryCounter.APP_CREATED
|
||||
assert call_args[0][1] == 1
|
||||
assert call_args[0][2]["tenant_id"] == "tenant-123"
|
||||
assert call_args[0][2]["app_id"] == "app-789"
|
||||
assert call_args[0][2]["mode"] == "chat"
|
||||
|
||||
|
||||
def test_on_app_updated_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_UPDATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"app_id": "app-789"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_app_updated(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
mock_emit.assert_called_once_with(
|
||||
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
|
||||
attributes={
|
||||
"dify.app.id": "app-789",
|
||||
"dify.tenant_id": "tenant-123",
|
||||
},
|
||||
tenant_id="tenant-123",
|
||||
)
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
call_args = mock_exporter.increment_counter.call_args
|
||||
assert call_args[0][0] == EnterpriseTelemetryCounter.APP_UPDATED
|
||||
assert call_args[0][1] == 1
|
||||
assert call_args[0][2]["tenant_id"] == "tenant-123"
|
||||
assert call_args[0][2]["app_id"] == "app-789"
|
||||
|
||||
|
||||
def test_on_app_deleted_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_DELETED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"app_id": "app-789"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_app_deleted(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
mock_emit.assert_called_once_with(
|
||||
event_name=EnterpriseTelemetryEvent.APP_DELETED,
|
||||
attributes={
|
||||
"dify.app.id": "app-789",
|
||||
"dify.tenant_id": "tenant-123",
|
||||
},
|
||||
tenant_id="tenant-123",
|
||||
)
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
call_args = mock_exporter.increment_counter.call_args
|
||||
assert call_args[0][0] == EnterpriseTelemetryCounter.APP_DELETED
|
||||
assert call_args[0][1] == 1
|
||||
assert call_args[0][2]["tenant_id"] == "tenant-123"
|
||||
assert call_args[0][2]["app_id"] == "app-789"
|
||||
|
||||
|
||||
def test_on_feedback_created_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={
|
||||
"message_id": "msg-001",
|
||||
"app_id": "app-789",
|
||||
"conversation_id": "conv-123",
|
||||
"from_end_user_id": "user-456",
|
||||
"from_account_id": None,
|
||||
"rating": "like",
|
||||
"from_source": "api",
|
||||
"content": "Great!",
|
||||
},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_exporter.include_content = True
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_feedback_created(envelope)
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == "dify.feedback.created"
|
||||
assert call_args[1]["attributes"]["dify.message.id"] == "msg-001"
|
||||
assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!"
|
||||
assert call_args[1]["tenant_id"] == "tenant-123"
|
||||
assert call_args[1]["user_id"] == "user-456"
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
counter_args = mock_exporter.increment_counter.call_args
|
||||
assert counter_args[0][2]["app_id"] == "app-789"
|
||||
assert counter_args[0][2]["rating"] == "like"
|
||||
|
||||
|
||||
def test_on_feedback_created_without_content(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={
|
||||
"message_id": "msg-001",
|
||||
"app_id": "app-789",
|
||||
"conversation_id": "conv-123",
|
||||
"from_end_user_id": "user-456",
|
||||
"from_account_id": None,
|
||||
"rating": "like",
|
||||
"from_source": "api",
|
||||
"content": "Great!",
|
||||
},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_exporter.include_content = False
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_feedback_created(envelope)
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert "dify.feedback.content" not in call_args[1]["attributes"]
|
||||
@@ -1,69 +0,0 @@
|
||||
"""Unit tests for enterprise telemetry Celery task."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_envelope_json():
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-123",
|
||||
payload={"app_id": "app-123"},
|
||||
)
|
||||
return envelope.model_dump_json()
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_success(sample_envelope_json):
|
||||
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
|
||||
mock_handler = MagicMock()
|
||||
mock_handler_class.return_value = mock_handler
|
||||
|
||||
process_enterprise_telemetry(sample_envelope_json)
|
||||
|
||||
mock_handler.handle.assert_called_once()
|
||||
call_args = mock_handler.handle.call_args[0][0]
|
||||
assert isinstance(call_args, TelemetryEnvelope)
|
||||
assert call_args.case == TelemetryCase.APP_CREATED
|
||||
assert call_args.tenant_id == "test-tenant"
|
||||
assert call_args.event_id == "test-event-123"
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_invalid_json(caplog):
|
||||
invalid_json = "not valid json"
|
||||
|
||||
process_enterprise_telemetry(invalid_json)
|
||||
|
||||
assert "Failed to process enterprise telemetry envelope" in caplog.text
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog):
|
||||
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.handle.side_effect = Exception("Handler error")
|
||||
mock_handler_class.return_value = mock_handler
|
||||
|
||||
process_enterprise_telemetry(sample_envelope_json)
|
||||
|
||||
assert "Failed to process enterprise telemetry envelope" in caplog.text
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_validation_error(caplog):
|
||||
invalid_envelope = json.dumps(
|
||||
{
|
||||
"case": "INVALID_CASE",
|
||||
"tenant_id": "test-tenant",
|
||||
"event_id": "test-event",
|
||||
"payload": {},
|
||||
}
|
||||
)
|
||||
|
||||
process_enterprise_telemetry(invalid_envelope)
|
||||
|
||||
assert "Failed to process enterprise telemetry envelope" in caplog.text
|
||||
@@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
|
||||
</div>
|
||||
))}
|
||||
{
|
||||
showSummaryIndexSetting && (
|
||||
showSummaryIndexSetting && IS_CE_EDITION && (
|
||||
<div className="mt-3">
|
||||
<SummaryIndexSetting
|
||||
entry="create-document"
|
||||
|
||||
@@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
|
||||
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import RadioCard from '@/app/components/base/radio-card'
|
||||
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import FileList from '../../assets/file-list-3-fill.svg'
|
||||
import Note from '../../assets/note-mod.svg'
|
||||
@@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
|
||||
</div>
|
||||
))}
|
||||
{
|
||||
showSummaryIndexSetting && (
|
||||
showSummaryIndexSetting && IS_CE_EDITION && (
|
||||
<div className="mt-3">
|
||||
<SummaryIndexSetting
|
||||
entry="create-document"
|
||||
|
||||
@@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { DataSourceType, DocumentActionType } from '@/models/datasets'
|
||||
import {
|
||||
useDocumentArchive,
|
||||
@@ -263,10 +264,14 @@ const Operations = ({
|
||||
<span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
|
||||
</div>
|
||||
)}
|
||||
<div className={s.actionItem} onClick={() => onOperate('summary')}>
|
||||
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
|
||||
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
|
||||
</div>
|
||||
{
|
||||
IS_CE_EDITION && (
|
||||
<div className={s.actionItem} onClick={() => onOperate('summary')}>
|
||||
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
|
||||
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<Divider className="my-1" />
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -7,6 +7,7 @@ import Button from '@/app/components/base/button'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const i18nPrefix = 'batchAction'
|
||||
@@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({
|
||||
<span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
|
||||
</Button>
|
||||
)}
|
||||
{onBatchSummary && (
|
||||
{onBatchSummary && IS_CE_EDITION && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="gap-x-0.5 px-3"
|
||||
|
||||
@@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import { useDocLink } from '@/context/i18n'
|
||||
@@ -359,7 +360,7 @@ const Form = () => {
|
||||
{
|
||||
indexMethod === IndexingType.QUALIFIED
|
||||
&& [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
|
||||
&& (
|
||||
&& IS_CE_EDITION && (
|
||||
<>
|
||||
<Divider
|
||||
type="horizontal"
|
||||
|
||||
@@ -1,27 +1,19 @@
|
||||
'use client'
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import type { App } from '@/types/app'
|
||||
import * as React from 'react'
|
||||
import { useMemo, useRef } from 'react'
|
||||
import { useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import AppInputsForm from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form'
|
||||
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
import { useAppDetail } from '@/service/use-apps'
|
||||
import { useFileUploadConfig } from '@/service/use-common'
|
||||
import { useAppWorkflow } from '@/service/use-workflow'
|
||||
import { AppModeEnum, Resolution } from '@/types/app'
|
||||
|
||||
import { useAppInputsFormSchema } from '@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type Props = {
|
||||
value?: {
|
||||
app_id: string
|
||||
inputs: Record<string, any>
|
||||
inputs: Record<string, unknown>
|
||||
}
|
||||
appDetail: App
|
||||
onFormChange: (value: Record<string, any>) => void
|
||||
onFormChange: (value: Record<string, unknown>) => void
|
||||
}
|
||||
|
||||
const AppInputsPanel = ({
|
||||
@@ -30,155 +22,33 @@ const AppInputsPanel = ({
|
||||
onFormChange,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const inputsRef = useRef<any>(value?.inputs || {})
|
||||
const isBasicApp = appDetail.mode !== AppModeEnum.ADVANCED_CHAT && appDetail.mode !== AppModeEnum.WORKFLOW
|
||||
const { data: fileUploadConfig } = useFileUploadConfig()
|
||||
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
|
||||
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(isBasicApp ? '' : appDetail.id)
|
||||
const isLoading = isAppLoading || isWorkflowLoading
|
||||
const inputsRef = useRef<Record<string, unknown>>(value?.inputs || {})
|
||||
|
||||
const basicAppFileConfig = useMemo(() => {
|
||||
let fileConfig: FileUpload
|
||||
if (isBasicApp)
|
||||
fileConfig = currentApp?.model_config?.file_upload as FileUpload
|
||||
else
|
||||
fileConfig = currentWorkflow?.features?.file_upload as FileUpload
|
||||
return {
|
||||
image: {
|
||||
detail: fileConfig?.image?.detail || Resolution.high,
|
||||
enabled: !!fileConfig?.image?.enabled,
|
||||
number_limits: fileConfig?.image?.number_limits || 3,
|
||||
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
},
|
||||
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
|
||||
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
|
||||
allowed_file_extensions: fileConfig?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
|
||||
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods || fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
|
||||
}
|
||||
}, [currentApp?.model_config?.file_upload, currentWorkflow?.features?.file_upload, isBasicApp])
|
||||
const { inputFormSchema, isLoading } = useAppInputsFormSchema({ appDetail })
|
||||
|
||||
const inputFormSchema = useMemo(() => {
|
||||
if (!currentApp)
|
||||
return []
|
||||
let inputFormSchema = []
|
||||
if (isBasicApp) {
|
||||
inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => {
|
||||
if (item.paragraph) {
|
||||
return {
|
||||
...item.paragraph,
|
||||
type: 'paragraph',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
if (item.number) {
|
||||
return {
|
||||
...item.number,
|
||||
type: 'number',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
if (item.checkbox) {
|
||||
return {
|
||||
...item.checkbox,
|
||||
type: 'checkbox',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
if (item.select) {
|
||||
return {
|
||||
...item.select,
|
||||
type: 'select',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
|
||||
if (item['file-list']) {
|
||||
return {
|
||||
...item['file-list'],
|
||||
type: 'file-list',
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
if (item.file) {
|
||||
return {
|
||||
...item.file,
|
||||
type: 'file',
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
if (item.json_object) {
|
||||
return {
|
||||
...item.json_object,
|
||||
type: 'json_object',
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...item['text-input'],
|
||||
type: 'text-input',
|
||||
required: false,
|
||||
}
|
||||
}) || []
|
||||
}
|
||||
else {
|
||||
const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any
|
||||
inputFormSchema = startNode?.data.variables.map((variable: any) => {
|
||||
if (variable.type === InputVarType.multiFiles) {
|
||||
return {
|
||||
...variable,
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
if (variable.type === InputVarType.singleFile) {
|
||||
return {
|
||||
...variable,
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
return {
|
||||
...variable,
|
||||
required: false,
|
||||
}
|
||||
}) || []
|
||||
}
|
||||
if ((currentApp.mode === AppModeEnum.COMPLETION || currentApp.mode === AppModeEnum.WORKFLOW) && basicAppFileConfig.enabled) {
|
||||
inputFormSchema.push({
|
||||
label: 'Image Upload',
|
||||
variable: '#image#',
|
||||
type: InputVarType.singleFile,
|
||||
required: false,
|
||||
...basicAppFileConfig,
|
||||
fileUploadConfig,
|
||||
})
|
||||
}
|
||||
return inputFormSchema || []
|
||||
}, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
|
||||
|
||||
const handleFormChange = (value: Record<string, any>) => {
|
||||
inputsRef.current = value
|
||||
onFormChange(value)
|
||||
const handleFormChange = (newValue: Record<string, unknown>) => {
|
||||
inputsRef.current = newValue
|
||||
onFormChange(newValue)
|
||||
}
|
||||
|
||||
const hasInputs = inputFormSchema.length > 0
|
||||
|
||||
return (
|
||||
<div className={cn('flex max-h-[240px] flex-col rounded-b-2xl border-t border-divider-subtle pb-4')}>
|
||||
{isLoading && <div className="pt-3"><Loading type="app" /></div>}
|
||||
{!isLoading && (
|
||||
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">{t('appSelector.params', { ns: 'app' })}</div>
|
||||
)}
|
||||
{!isLoading && !inputFormSchema.length && (
|
||||
<div className="flex h-16 flex-col items-center justify-center">
|
||||
<div className="system-sm-regular text-text-tertiary">{t('appSelector.noParams', { ns: 'app' })}</div>
|
||||
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">
|
||||
{t('appSelector.params', { ns: 'app' })}
|
||||
</div>
|
||||
)}
|
||||
{!isLoading && !!inputFormSchema.length && (
|
||||
{!isLoading && !hasInputs && (
|
||||
<div className="flex h-16 flex-col items-center justify-center">
|
||||
<div className="system-sm-regular text-text-tertiary">
|
||||
{t('appSelector.noParams', { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{!isLoading && hasInputs && (
|
||||
<div className="grow overflow-y-auto">
|
||||
<AppInputsForm
|
||||
inputs={value?.inputs || {}}
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
'use client'
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import type { FileUploadConfigResponse } from '@/models/common'
|
||||
import type { App } from '@/types/app'
|
||||
import type { FetchWorkflowDraftResponse } from '@/types/workflow'
|
||||
import { useMemo } from 'react'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
import { useAppDetail } from '@/service/use-apps'
|
||||
import { useFileUploadConfig } from '@/service/use-common'
|
||||
import { useAppWorkflow } from '@/service/use-workflow'
|
||||
import { AppModeEnum, Resolution } from '@/types/app'
|
||||
|
||||
const BASIC_INPUT_TYPE_MAP: Record<string, string> = {
|
||||
'paragraph': 'paragraph',
|
||||
'number': 'number',
|
||||
'checkbox': 'checkbox',
|
||||
'select': 'select',
|
||||
'file-list': 'file-list',
|
||||
'file': 'file',
|
||||
'json_object': 'json_object',
|
||||
}
|
||||
|
||||
const FILE_INPUT_TYPES = new Set(['file-list', 'file'])
|
||||
|
||||
const WORKFLOW_FILE_VAR_TYPES = new Set([InputVarType.multiFiles, InputVarType.singleFile])
|
||||
|
||||
type InputSchemaItem = {
|
||||
label?: string
|
||||
variable?: string
|
||||
type: string
|
||||
required: boolean
|
||||
fileUploadConfig?: FileUploadConfigResponse
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
function isBasicAppMode(mode: string): boolean {
|
||||
return mode !== AppModeEnum.ADVANCED_CHAT && mode !== AppModeEnum.WORKFLOW
|
||||
}
|
||||
|
||||
function supportsImageUpload(mode: string): boolean {
|
||||
return mode === AppModeEnum.COMPLETION || mode === AppModeEnum.WORKFLOW
|
||||
}
|
||||
|
||||
function buildFileConfig(fileConfig: FileUpload | undefined) {
|
||||
return {
|
||||
image: {
|
||||
detail: fileConfig?.image?.detail || Resolution.high,
|
||||
enabled: !!fileConfig?.image?.enabled,
|
||||
number_limits: fileConfig?.image?.number_limits || 3,
|
||||
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
},
|
||||
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
|
||||
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
|
||||
allowed_file_extensions: fileConfig?.allowed_file_extensions
|
||||
|| [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
|
||||
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods
|
||||
|| fileConfig?.image?.transfer_methods
|
||||
|| ['local_file', 'remote_url'],
|
||||
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
|
||||
}
|
||||
}
|
||||
|
||||
function mapBasicAppInputItem(
|
||||
item: Record<string, unknown>,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem | null {
|
||||
for (const [key, type] of Object.entries(BASIC_INPUT_TYPE_MAP)) {
|
||||
if (!item[key])
|
||||
continue
|
||||
|
||||
const inputData = item[key] as Record<string, unknown>
|
||||
const needsFileConfig = FILE_INPUT_TYPES.has(key)
|
||||
|
||||
return {
|
||||
...inputData,
|
||||
type,
|
||||
required: false,
|
||||
...(needsFileConfig && { fileUploadConfig }),
|
||||
}
|
||||
}
|
||||
|
||||
const textInput = item['text-input'] as Record<string, unknown> | undefined
|
||||
if (!textInput)
|
||||
return null
|
||||
|
||||
return {
|
||||
...textInput,
|
||||
type: 'text-input',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
|
||||
function mapWorkflowVariable(
|
||||
variable: Record<string, unknown>,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem {
|
||||
const needsFileConfig = WORKFLOW_FILE_VAR_TYPES.has(variable.type as InputVarType)
|
||||
|
||||
return {
|
||||
...variable,
|
||||
type: variable.type as string,
|
||||
required: false,
|
||||
...(needsFileConfig && { fileUploadConfig }),
|
||||
}
|
||||
}
|
||||
|
||||
function createImageUploadSchema(
|
||||
basicFileConfig: ReturnType<typeof buildFileConfig>,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem {
|
||||
return {
|
||||
label: 'Image Upload',
|
||||
variable: '#image#',
|
||||
type: InputVarType.singleFile,
|
||||
required: false,
|
||||
...basicFileConfig,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
function buildBasicAppSchema(
|
||||
currentApp: App,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem[] {
|
||||
const userInputForm = currentApp.model_config?.user_input_form as Array<Record<string, unknown>> | undefined
|
||||
if (!userInputForm)
|
||||
return []
|
||||
|
||||
return userInputForm
|
||||
.filter((item: Record<string, unknown>) => !item.external_data_tool)
|
||||
.map((item: Record<string, unknown>) => mapBasicAppInputItem(item, fileUploadConfig))
|
||||
.filter((item): item is InputSchemaItem => item !== null)
|
||||
}
|
||||
|
||||
function buildWorkflowSchema(
|
||||
workflow: FetchWorkflowDraftResponse,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem[] {
|
||||
const startNode = workflow.graph?.nodes.find(
|
||||
node => node.data.type === BlockEnum.Start,
|
||||
) as { data: { variables: Array<Record<string, unknown>> } } | undefined
|
||||
|
||||
if (!startNode?.data.variables)
|
||||
return []
|
||||
|
||||
return startNode.data.variables.map(
|
||||
variable => mapWorkflowVariable(variable, fileUploadConfig),
|
||||
)
|
||||
}
|
||||
|
||||
type UseAppInputsFormSchemaParams = {
|
||||
appDetail: App
|
||||
}
|
||||
|
||||
type UseAppInputsFormSchemaResult = {
|
||||
inputFormSchema: InputSchemaItem[]
|
||||
isLoading: boolean
|
||||
fileUploadConfig?: FileUploadConfigResponse
|
||||
}
|
||||
|
||||
export function useAppInputsFormSchema({
|
||||
appDetail,
|
||||
}: UseAppInputsFormSchemaParams): UseAppInputsFormSchemaResult {
|
||||
const isBasicApp = isBasicAppMode(appDetail.mode)
|
||||
|
||||
const { data: fileUploadConfig } = useFileUploadConfig()
|
||||
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
|
||||
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(
|
||||
isBasicApp ? '' : appDetail.id,
|
||||
)
|
||||
|
||||
const isLoading = isAppLoading || isWorkflowLoading
|
||||
|
||||
const inputFormSchema = useMemo(() => {
|
||||
if (!currentApp)
|
||||
return []
|
||||
|
||||
if (!isBasicApp && !currentWorkflow)
|
||||
return []
|
||||
|
||||
// Build base schema based on app type
|
||||
// Note: currentWorkflow is guaranteed to be defined here due to the early return above
|
||||
const baseSchema = isBasicApp
|
||||
? buildBasicAppSchema(currentApp, fileUploadConfig)
|
||||
: buildWorkflowSchema(currentWorkflow!, fileUploadConfig)
|
||||
|
||||
if (!supportsImageUpload(currentApp.mode))
|
||||
return baseSchema
|
||||
|
||||
const rawFileConfig = isBasicApp
|
||||
? currentApp.model_config?.file_upload as FileUpload
|
||||
: currentWorkflow?.features?.file_upload as FileUpload
|
||||
|
||||
const basicFileConfig = buildFileConfig(rawFileConfig)
|
||||
|
||||
if (!basicFileConfig.enabled)
|
||||
return baseSchema
|
||||
|
||||
return [
|
||||
...baseSchema,
|
||||
createImageUploadSchema(basicFileConfig, fileUploadConfig),
|
||||
]
|
||||
}, [currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
|
||||
|
||||
return {
|
||||
inputFormSchema,
|
||||
isLoading,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import Toast from '@/app/components/base/toast'
|
||||
import { PluginSource } from '../types'
|
||||
import DetailHeader from './detail-header'
|
||||
|
||||
// Use vi.hoisted for mock functions used in vi.mock factories
|
||||
const {
|
||||
mockSetShowUpdatePluginModal,
|
||||
mockRefreshModelProviders,
|
||||
|
||||
@@ -1,416 +1,2 @@
|
||||
import type { PluginDetail } from '../types'
|
||||
import {
|
||||
RiArrowLeftRightLine,
|
||||
RiBugLine,
|
||||
RiCloseLine,
|
||||
RiHardDrive3Line,
|
||||
} from '@remixicon/react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { Github } from '@/app/components/base/icons/src/public/common'
|
||||
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
|
||||
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
|
||||
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
|
||||
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
|
||||
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
|
||||
import { API_PREFIX } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useGetLanguage, useLocale } from '@/context/i18n'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { uninstallPlugin } from '@/service/plugins'
|
||||
import { useAllToolProviders, useInvalidateAllToolProviders } from '@/service/use-tools'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import { AutoUpdateLine } from '../../base/icons/src/vender/system'
|
||||
import Verified from '../base/badges/verified'
|
||||
import DeprecationNotice from '../base/deprecation-notice'
|
||||
import Icon from '../card/base/card-icon'
|
||||
import Description from '../card/base/description'
|
||||
import OrgInfo from '../card/base/org-info'
|
||||
import Title from '../card/base/title'
|
||||
import { useGitHubReleases } from '../install-plugin/hooks'
|
||||
import useReferenceSetting from '../plugin-page/use-reference-setting'
|
||||
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
|
||||
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils'
|
||||
import { PluginCategoryEnum, PluginSource } from '../types'
|
||||
|
||||
const i18nPrefix = 'action'
|
||||
|
||||
type Props = {
|
||||
detail: PluginDetail
|
||||
isReadmeView?: boolean
|
||||
onHide?: () => void
|
||||
onUpdate?: (isDelete?: boolean) => void
|
||||
}
|
||||
|
||||
const DetailHeader = ({
|
||||
detail,
|
||||
isReadmeView = false,
|
||||
onHide,
|
||||
onUpdate,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { userProfile: { timezone } } = useAppContext()
|
||||
|
||||
const { theme } = useTheme()
|
||||
const locale = useGetLanguage()
|
||||
const currentLocale = useLocale()
|
||||
const { checkForUpdates, fetchReleases } = useGitHubReleases()
|
||||
const { setShowUpdatePluginModal } = useModalContext()
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
|
||||
const {
|
||||
id,
|
||||
source,
|
||||
tenant_id,
|
||||
version,
|
||||
latest_unique_identifier,
|
||||
latest_version,
|
||||
meta,
|
||||
plugin_id,
|
||||
status,
|
||||
deprecated_reason,
|
||||
alternative_plugin_id,
|
||||
} = detail
|
||||
|
||||
const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
|
||||
const isTool = category === PluginCategoryEnum.tool
|
||||
const providerBriefInfo = tool?.identity
|
||||
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
|
||||
const { data: collectionList = [] } = useAllToolProviders(isTool)
|
||||
const provider = useMemo(() => {
|
||||
return collectionList.find(collection => collection.name === providerKey)
|
||||
}, [collectionList, providerKey])
|
||||
const isFromGitHub = source === PluginSource.github
|
||||
const isFromMarketplace = source === PluginSource.marketplace
|
||||
|
||||
const [isShow, setIsShow] = useState(false)
|
||||
const [targetVersion, setTargetVersion] = useState({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
const hasNewVersion = useMemo(() => {
|
||||
if (isFromMarketplace)
|
||||
return !!latest_version && latest_version !== version
|
||||
|
||||
return false
|
||||
}, [isFromMarketplace, latest_version, version])
|
||||
|
||||
const iconFileName = theme === 'dark' && icon_dark ? icon_dark : icon
|
||||
const iconSrc = iconFileName
|
||||
? (iconFileName.startsWith('http') ? iconFileName : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${iconFileName}`)
|
||||
: ''
|
||||
|
||||
const detailUrl = useMemo(() => {
|
||||
if (isFromGitHub)
|
||||
return `https://github.com/${meta!.repo}`
|
||||
if (isFromMarketplace)
|
||||
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: currentLocale, theme })
|
||||
return ''
|
||||
}, [author, isFromGitHub, isFromMarketplace, meta, name, theme])
|
||||
|
||||
const [isShowUpdateModal, {
|
||||
setTrue: showUpdateModal,
|
||||
setFalse: hideUpdateModal,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const { referenceSetting } = useReferenceSetting()
|
||||
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
|
||||
const isAutoUpgradeEnabled = useMemo(() => {
|
||||
if (!enable_marketplace)
|
||||
return false
|
||||
if (!autoUpgradeInfo || !isFromMarketplace)
|
||||
return false
|
||||
if (autoUpgradeInfo.strategy_setting === 'disabled')
|
||||
return false
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
|
||||
return true
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
|
||||
return true
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
|
||||
return true
|
||||
return false
|
||||
}, [autoUpgradeInfo, plugin_id, isFromMarketplace])
|
||||
|
||||
const [isDowngrade, setIsDowngrade] = useState(false)
|
||||
const handleUpdate = async (isDowngrade?: boolean) => {
|
||||
if (isFromMarketplace) {
|
||||
setIsDowngrade(!!isDowngrade)
|
||||
showUpdateModal()
|
||||
return
|
||||
}
|
||||
|
||||
const owner = meta!.repo.split('/')[0] || author
|
||||
const repo = meta!.repo.split('/')[1] || name
|
||||
const fetchedReleases = await fetchReleases(owner, repo)
|
||||
if (fetchedReleases.length === 0)
|
||||
return
|
||||
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version)
|
||||
Toast.notify(toastProps)
|
||||
if (needUpdate) {
|
||||
setShowUpdatePluginModal({
|
||||
onSaveCallback: () => {
|
||||
onUpdate?.()
|
||||
},
|
||||
payload: {
|
||||
type: PluginSource.github,
|
||||
category: detail.declaration.category,
|
||||
github: {
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
repo: meta!.repo,
|
||||
version: meta!.version,
|
||||
package: meta!.package,
|
||||
releases: fetchedReleases,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleUpdatedFromMarketplace = () => {
|
||||
onUpdate?.()
|
||||
hideUpdateModal()
|
||||
}
|
||||
|
||||
const [isShowPluginInfo, {
|
||||
setTrue: showPluginInfo,
|
||||
setFalse: hidePluginInfo,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const [isShowDeleteConfirm, {
|
||||
setTrue: showDeleteConfirm,
|
||||
setFalse: hideDeleteConfirm,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const [deleting, {
|
||||
setTrue: showDeleting,
|
||||
setFalse: hideDeleting,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const handleDelete = useCallback(async () => {
|
||||
showDeleting()
|
||||
const res = await uninstallPlugin(id)
|
||||
hideDeleting()
|
||||
if (res.success) {
|
||||
hideDeleteConfirm()
|
||||
onUpdate?.(true)
|
||||
if (PluginCategoryEnum.model.includes(category))
|
||||
refreshModelProviders()
|
||||
if (PluginCategoryEnum.tool.includes(category))
|
||||
invalidateAllToolProviders()
|
||||
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
|
||||
}
|
||||
}, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name])
|
||||
|
||||
return (
|
||||
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
|
||||
<div className="flex">
|
||||
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
|
||||
<Icon src={iconSrc} />
|
||||
</div>
|
||||
<div className="ml-3 w-0 grow">
|
||||
<div className="flex h-5 items-center">
|
||||
<Title title={label[locale]} />
|
||||
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
|
||||
{!!version && (
|
||||
<PluginVersionPicker
|
||||
disabled={!isFromMarketplace || isReadmeView}
|
||||
isShow={isShow}
|
||||
onShowChange={setIsShow}
|
||||
pluginID={plugin_id}
|
||||
currentVersion={version}
|
||||
onSelect={(state) => {
|
||||
setTargetVersion(state)
|
||||
handleUpdate(state.isDowngrade)
|
||||
}}
|
||||
trigger={(
|
||||
<Badge
|
||||
className={cn(
|
||||
'mx-1',
|
||||
isShow && 'bg-state-base-hover',
|
||||
(isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
|
||||
)}
|
||||
uppercase={false}
|
||||
text={(
|
||||
<>
|
||||
<div>{isFromGitHub ? meta!.version : version}</div>
|
||||
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
|
||||
</>
|
||||
)}
|
||||
hasRedCornerMark={hasNewVersion}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
{/* Auto update info */}
|
||||
{isAutoUpgradeEnabled && !isReadmeView && (
|
||||
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
|
||||
{/* add a a div to fix tooltip hover not show problem */}
|
||||
<div>
|
||||
<Badge className="mr-1 cursor-pointer px-1">
|
||||
<AutoUpdateLine className="size-3" />
|
||||
</Badge>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{(hasNewVersion || isFromGitHub) && (
|
||||
<Button
|
||||
variant="secondary-accent"
|
||||
size="small"
|
||||
className="!h-5"
|
||||
onClick={() => {
|
||||
if (isFromMarketplace) {
|
||||
setTargetVersion({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
}
|
||||
handleUpdate()
|
||||
}}
|
||||
>
|
||||
{t('detailPanel.operation.update', { ns: 'plugin' })}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-1 flex h-4 items-center justify-between">
|
||||
<div className="mt-0.5 flex items-center">
|
||||
<OrgInfo
|
||||
packageNameClassName="w-auto"
|
||||
orgName={author}
|
||||
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
|
||||
/>
|
||||
{!!source && (
|
||||
<>
|
||||
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
|
||||
{source === PluginSource.marketplace && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.marketplace', { ns: 'plugin' })}>
|
||||
<div><BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{source === PluginSource.github && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.github', { ns: 'plugin' })}>
|
||||
<div><Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{source === PluginSource.local && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.local', { ns: 'plugin' })}>
|
||||
<div><RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{source === PluginSource.debugging && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.debugging', { ns: 'plugin' })}>
|
||||
<div><RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{!isReadmeView && (
|
||||
<div className="flex gap-1">
|
||||
<OperationDropdown
|
||||
source={source}
|
||||
onInfo={showPluginInfo}
|
||||
onCheckVersion={handleUpdate}
|
||||
onRemove={showDeleteConfirm}
|
||||
detailUrl={detailUrl}
|
||||
/>
|
||||
<ActionButton onClick={onHide}>
|
||||
<RiCloseLine className="h-4 w-4" />
|
||||
</ActionButton>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{isFromMarketplace && (
|
||||
<DeprecationNotice
|
||||
status={status}
|
||||
deprecatedReason={deprecated_reason}
|
||||
alternativePluginId={alternative_plugin_id}
|
||||
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
|
||||
className="mt-3"
|
||||
/>
|
||||
)}
|
||||
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2}></Description>}
|
||||
{
|
||||
category === PluginCategoryEnum.tool && !isReadmeView && (
|
||||
<PluginAuth
|
||||
pluginPayload={{
|
||||
provider: provider?.name || '',
|
||||
category: AuthCategory.tool,
|
||||
providerType: provider?.type || '',
|
||||
detail,
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{isShowPluginInfo && (
|
||||
<PluginInfo
|
||||
repository={isFromGitHub ? meta?.repo : ''}
|
||||
release={version}
|
||||
packageName={meta?.package || ''}
|
||||
onHide={hidePluginInfo}
|
||||
/>
|
||||
)}
|
||||
{isShowDeleteConfirm && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
|
||||
content={(
|
||||
<div>
|
||||
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
|
||||
<span className="system-md-semibold">{label[locale]}</span>
|
||||
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
|
||||
<br />
|
||||
</div>
|
||||
)}
|
||||
onCancel={hideDeleteConfirm}
|
||||
onConfirm={handleDelete}
|
||||
isLoading={deleting}
|
||||
isDisabled={deleting}
|
||||
/>
|
||||
)}
|
||||
{
|
||||
isShowUpdateModal && (
|
||||
<UpdateFromMarketplace
|
||||
pluginId={plugin_id}
|
||||
payload={{
|
||||
category: detail.declaration.category,
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
payload: detail.declaration,
|
||||
},
|
||||
targetPackageInfo: {
|
||||
id: targetVersion.unique_identifier,
|
||||
version: targetVersion.version,
|
||||
},
|
||||
}}
|
||||
onCancel={hideUpdateModal}
|
||||
onSave={handleUpdatedFromMarketplace}
|
||||
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default DetailHeader
|
||||
// Re-export from refactored module for backward compatibility
|
||||
export { default } from './detail-header/index'
|
||||
|
||||
@@ -0,0 +1,539 @@
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from '../hooks'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginSource } from '../../../types'
|
||||
import HeaderModals from './header-modals'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useGetLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/confirm', () => ({
|
||||
default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
|
||||
isShow: boolean
|
||||
title: string
|
||||
onCancel: () => void
|
||||
onConfirm: () => void
|
||||
isLoading: boolean
|
||||
}) => isShow
|
||||
? (
|
||||
<div data-testid="delete-confirm">
|
||||
<div data-testid="delete-title">{title}</div>
|
||||
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
|
||||
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
|
||||
default: ({ repository, release, packageName, onHide }: {
|
||||
repository: string
|
||||
release: string
|
||||
packageName: string
|
||||
onHide: () => void
|
||||
}) => (
|
||||
<div data-testid="plugin-info">
|
||||
<div data-testid="plugin-info-repo">{repository}</div>
|
||||
<div data-testid="plugin-info-release">{release}</div>
|
||||
<div data-testid="plugin-info-package">{packageName}</div>
|
||||
<button data-testid="plugin-info-close" onClick={onHide}>Close</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/update-plugin/from-market-place', () => ({
|
||||
default: ({ pluginId, onSave, onCancel, isShowDowngradeWarningModal }: {
|
||||
pluginId: string
|
||||
onSave: () => void
|
||||
onCancel: () => void
|
||||
isShowDowngradeWarningModal: boolean
|
||||
}) => (
|
||||
<div data-testid="update-modal">
|
||||
<div data-testid="update-plugin-id">{pluginId}</div>
|
||||
<div data-testid="update-downgrade-warning">{String(isShowDowngradeWarningModal)}</div>
|
||||
<button data-testid="update-modal-save" onClick={onSave}>Save</button>
|
||||
<button data-testid="update-modal-cancel" onClick={onCancel}>Cancel</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
|
||||
id: 'test-id',
|
||||
created_at: '2024-01-01',
|
||||
updated_at: '2024-01-02',
|
||||
name: 'Test Plugin',
|
||||
plugin_id: 'test-plugin',
|
||||
plugin_unique_identifier: 'test-uid',
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test Plugin Label' },
|
||||
description: { en_US: 'Test description' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
installation_id: 'install-1',
|
||||
tenant_id: 'tenant-1',
|
||||
endpoints_setups: 0,
|
||||
endpoints_active: 0,
|
||||
version: '1.0.0',
|
||||
latest_version: '2.0.0',
|
||||
latest_unique_identifier: 'new-uid',
|
||||
source: PluginSource.marketplace,
|
||||
meta: undefined,
|
||||
status: 'active',
|
||||
deprecated_reason: '',
|
||||
alternative_plugin_id: '',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createModalStatesMock = (overrides: Partial<ModalStates> = {}): ModalStates => ({
|
||||
isShowUpdateModal: false,
|
||||
showUpdateModal: vi.fn<() => void>(),
|
||||
hideUpdateModal: vi.fn<() => void>(),
|
||||
isShowPluginInfo: false,
|
||||
showPluginInfo: vi.fn<() => void>(),
|
||||
hidePluginInfo: vi.fn<() => void>(),
|
||||
isShowDeleteConfirm: false,
|
||||
showDeleteConfirm: vi.fn<() => void>(),
|
||||
hideDeleteConfirm: vi.fn<() => void>(),
|
||||
deleting: false,
|
||||
showDeleting: vi.fn<() => void>(),
|
||||
hideDeleting: vi.fn<() => void>(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createTargetVersion = (overrides: Partial<VersionTarget> = {}): VersionTarget => ({
|
||||
version: '2.0.0',
|
||||
unique_identifier: 'new-uid',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('HeaderModals', () => {
|
||||
let mockOnUpdatedFromMarketplace: () => void
|
||||
let mockOnDelete: () => void
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockOnUpdatedFromMarketplace = vi.fn<() => void>()
|
||||
mockOnDelete = vi.fn<() => void>()
|
||||
})
|
||||
|
||||
describe('Plugin Info Modal', () => {
|
||||
it('should not render plugin info modal when isShowPluginInfo is false', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: false })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('plugin-info')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render plugin info modal when isShowPluginInfo is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass GitHub repo to plugin info for GitHub source', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'test-pkg' },
|
||||
})
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={detail}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('owner/repo')
|
||||
})
|
||||
|
||||
it('should pass empty string for repo for non-GitHub source', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail({ source: PluginSource.marketplace })}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
|
||||
})
|
||||
|
||||
it('should call hidePluginInfo when close button is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('plugin-info-close'))
|
||||
|
||||
expect(modalStates.hidePluginInfo).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Delete Confirm Modal', () => {
|
||||
it('should not render delete confirm when isShowDeleteConfirm is false', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: false })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render delete confirm when isShowDeleteConfirm is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show correct delete title', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('delete-title')).toHaveTextContent('action.delete')
|
||||
})
|
||||
|
||||
it('should call hideDeleteConfirm when cancel is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-cancel'))
|
||||
|
||||
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onDelete when confirm is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
|
||||
expect(mockOnDelete).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should disable confirm button when deleting', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true, deleting: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('confirm-ok')).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Update Modal', () => {
|
||||
it('should not render update modal when isShowUpdateModal is false', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: false })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('update-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render update modal when isShowUpdateModal is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass plugin id to update modal', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail({ plugin_id: 'my-plugin-id' })}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-plugin-id')).toHaveTextContent('my-plugin-id')
|
||||
})
|
||||
|
||||
it('should call onUpdatedFromMarketplace when save is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('update-modal-save'))
|
||||
|
||||
expect(mockOnUpdatedFromMarketplace).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call hideUpdateModal when cancel is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('update-modal-cancel'))
|
||||
|
||||
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show downgrade warning when isDowngrade and isAutoUpgradeEnabled are true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={true}
|
||||
isAutoUpgradeEnabled={true}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('true')
|
||||
})
|
||||
|
||||
it('should not show downgrade warning when only isDowngrade is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={true}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
|
||||
})
|
||||
|
||||
it('should not show downgrade warning when only isAutoUpgradeEnabled is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={true}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple Modals', () => {
|
||||
it('should render multiple modals when multiple are open', () => {
|
||||
const modalStates = createModalStatesMock({
|
||||
isShowPluginInfo: true,
|
||||
isShowDeleteConfirm: true,
|
||||
isShowUpdateModal: true,
|
||||
})
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined target version values', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={{ version: undefined, unique_identifier: undefined }}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty meta for GitHub source', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: undefined,
|
||||
})
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={detail}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
|
||||
expect(screen.getByTestId('plugin-info-package')).toHaveTextContent('')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,107 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from '../hooks'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
|
||||
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { PluginSource } from '../../../types'
|
||||
|
||||
const i18nPrefix = 'action'
|
||||
|
||||
type HeaderModalsProps = {
|
||||
detail: PluginDetail
|
||||
modalStates: ModalStates
|
||||
targetVersion: VersionTarget
|
||||
isDowngrade: boolean
|
||||
isAutoUpgradeEnabled: boolean
|
||||
onUpdatedFromMarketplace: () => void
|
||||
onDelete: () => void
|
||||
}
|
||||
|
||||
const HeaderModals: FC<HeaderModalsProps> = ({
|
||||
detail,
|
||||
modalStates,
|
||||
targetVersion,
|
||||
isDowngrade,
|
||||
isAutoUpgradeEnabled,
|
||||
onUpdatedFromMarketplace,
|
||||
onDelete,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const locale = useGetLanguage()
|
||||
|
||||
const { source, version, meta } = detail
|
||||
const { label } = detail.declaration || detail
|
||||
const isFromGitHub = source === PluginSource.github
|
||||
|
||||
const {
|
||||
isShowUpdateModal,
|
||||
hideUpdateModal,
|
||||
isShowPluginInfo,
|
||||
hidePluginInfo,
|
||||
isShowDeleteConfirm,
|
||||
hideDeleteConfirm,
|
||||
deleting,
|
||||
} = modalStates
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Plugin Info Modal */}
|
||||
{isShowPluginInfo && (
|
||||
<PluginInfo
|
||||
repository={isFromGitHub ? meta?.repo : ''}
|
||||
release={version}
|
||||
packageName={meta?.package || ''}
|
||||
onHide={hidePluginInfo}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Delete Confirm Modal */}
|
||||
{isShowDeleteConfirm && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
|
||||
content={(
|
||||
<div>
|
||||
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
|
||||
<span className="system-md-semibold">{label[locale]}</span>
|
||||
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
|
||||
<br />
|
||||
</div>
|
||||
)}
|
||||
onCancel={hideDeleteConfirm}
|
||||
onConfirm={onDelete}
|
||||
isLoading={deleting}
|
||||
isDisabled={deleting}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Update from Marketplace Modal */}
|
||||
{isShowUpdateModal && (
|
||||
<UpdateFromMarketplace
|
||||
pluginId={detail.plugin_id}
|
||||
payload={{
|
||||
category: detail.declaration?.category ?? '',
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
payload: detail.declaration ?? undefined,
|
||||
},
|
||||
targetPackageInfo: {
|
||||
id: targetVersion.unique_identifier || '',
|
||||
version: targetVersion.version || '',
|
||||
},
|
||||
}}
|
||||
onCancel={hideUpdateModal}
|
||||
onSave={onUpdatedFromMarketplace}
|
||||
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default HeaderModals
|
||||
@@ -0,0 +1,2 @@
|
||||
export { default as HeaderModals } from './header-modals'
|
||||
export { default as PluginSourceBadge } from './plugin-source-badge'
|
||||
@@ -0,0 +1,200 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginSource } from '../../../types'
|
||||
import PluginSourceBadge from './plugin-source-badge'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
|
||||
<div data-testid="tooltip" data-content={popupContent}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('PluginSourceBadge', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Source Icon Rendering', () => {
|
||||
it('should render marketplace source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.marketplace')
|
||||
})
|
||||
|
||||
it('should render github source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.github')
|
||||
})
|
||||
|
||||
it('should render local source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.local')
|
||||
})
|
||||
|
||||
it('should render debugging source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.debugging')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Separator Rendering', () => {
|
||||
it('should render separator dot before marketplace badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
expect(separator?.textContent).toBe('·')
|
||||
})
|
||||
|
||||
it('should render separator dot before github badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
expect(separator?.textContent).toBe('·')
|
||||
})
|
||||
|
||||
it('should render separator dot before local badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render separator dot before debugging badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Tooltip Content', () => {
|
||||
it('should show marketplace tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.marketplace',
|
||||
)
|
||||
})
|
||||
|
||||
it('should show github tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.github',
|
||||
)
|
||||
})
|
||||
|
||||
it('should show local tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.local',
|
||||
)
|
||||
})
|
||||
|
||||
it('should show debugging tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.debugging',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Icon Element Structure', () => {
|
||||
it('should render icon inside tooltip for marketplace', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render icon inside tooltip for github', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render icon inside tooltip for local', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render icon inside tooltip for debugging', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Lookup Table Coverage', () => {
|
||||
it('should handle all PluginSource enum values', () => {
|
||||
const allSources = Object.values(PluginSource)
|
||||
|
||||
allSources.forEach((source) => {
|
||||
const { container } = render(<PluginSourceBadge source={source} />)
|
||||
// Should render either tooltip or nothing
|
||||
expect(container).toBeTruthy()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Invalid Source Handling', () => {
|
||||
it('should return null for unknown source type', () => {
|
||||
// Use type assertion to test invalid source value
|
||||
const invalidSource = 'unknown_source' as PluginSource
|
||||
const { container } = render(<PluginSourceBadge source={invalidSource} />)
|
||||
|
||||
// Should render nothing (empty container)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('should not render separator for invalid source', () => {
|
||||
const invalidSource = 'invalid' as PluginSource
|
||||
const { container } = render(<PluginSourceBadge source={invalidSource} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render tooltip for invalid source', () => {
|
||||
const invalidSource = '' as PluginSource
|
||||
render(<PluginSourceBadge source={invalidSource} />)
|
||||
|
||||
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,59 @@
|
||||
'use client'
|
||||
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import {
|
||||
RiBugLine,
|
||||
RiHardDrive3Line,
|
||||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Github } from '@/app/components/base/icons/src/public/common'
|
||||
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { PluginSource } from '../../../types'
|
||||
|
||||
type SourceConfig = {
|
||||
icon: ReactNode
|
||||
tipKey: string
|
||||
}
|
||||
|
||||
type PluginSourceBadgeProps = {
|
||||
source: PluginSource
|
||||
}
|
||||
|
||||
const SOURCE_CONFIG_MAP: Record<PluginSource, SourceConfig | null> = {
|
||||
[PluginSource.marketplace]: {
|
||||
icon: <BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />,
|
||||
tipKey: 'detailPanel.categoryTip.marketplace',
|
||||
},
|
||||
[PluginSource.github]: {
|
||||
icon: <Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />,
|
||||
tipKey: 'detailPanel.categoryTip.github',
|
||||
},
|
||||
[PluginSource.local]: {
|
||||
icon: <RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />,
|
||||
tipKey: 'detailPanel.categoryTip.local',
|
||||
},
|
||||
[PluginSource.debugging]: {
|
||||
icon: <RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />,
|
||||
tipKey: 'detailPanel.categoryTip.debugging',
|
||||
},
|
||||
}
|
||||
|
||||
const PluginSourceBadge: FC<PluginSourceBadgeProps> = ({ source }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const config = SOURCE_CONFIG_MAP[source]
|
||||
if (!config)
|
||||
return null
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
|
||||
<Tooltip popupContent={t(config.tipKey as never, { ns: 'plugin' })}>
|
||||
<div>{config.icon}</div>
|
||||
</Tooltip>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default PluginSourceBadge
|
||||
@@ -0,0 +1,3 @@
|
||||
export { useDetailHeaderState } from './use-detail-header-state'
|
||||
export type { ModalStates, UseDetailHeaderStateReturn, VersionPickerState, VersionTarget } from './use-detail-header-state'
|
||||
export { usePluginOperations } from './use-plugin-operations'
|
||||
@@ -0,0 +1,409 @@
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginSource } from '../../../types'
|
||||
import { useDetailHeaderState } from './use-detail-header-state'
|
||||
|
||||
let mockEnableMarketplace = true
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) =>
|
||||
selector({ systemFeatures: { enable_marketplace: mockEnableMarketplace } }),
|
||||
}))
|
||||
|
||||
let mockAutoUpgradeInfo: {
|
||||
strategy_setting: string
|
||||
upgrade_mode: string
|
||||
include_plugins: string[]
|
||||
exclude_plugins: string[]
|
||||
upgrade_time_of_day: number
|
||||
} | null = null
|
||||
|
||||
vi.mock('../../../plugin-page/use-reference-setting', () => ({
|
||||
default: () => ({
|
||||
referenceSetting: mockAutoUpgradeInfo ? { auto_upgrade: mockAutoUpgradeInfo } : null,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../../reference-setting-modal/auto-update-setting/types', () => ({
|
||||
AUTO_UPDATE_MODE: {
|
||||
update_all: 'update_all',
|
||||
partial: 'partial',
|
||||
exclude: 'exclude',
|
||||
},
|
||||
}))
|
||||
|
||||
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
|
||||
id: 'test-id',
|
||||
created_at: '2024-01-01',
|
||||
updated_at: '2024-01-02',
|
||||
name: 'Test Plugin',
|
||||
plugin_id: 'test-plugin',
|
||||
plugin_unique_identifier: 'test-uid',
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test Plugin Label' },
|
||||
description: { en_US: 'Test description' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
installation_id: 'install-1',
|
||||
tenant_id: 'tenant-1',
|
||||
endpoints_setups: 0,
|
||||
endpoints_active: 0,
|
||||
version: '1.0.0',
|
||||
latest_version: '1.0.0',
|
||||
latest_unique_identifier: 'test-uid',
|
||||
source: PluginSource.marketplace,
|
||||
meta: undefined,
|
||||
status: 'active',
|
||||
deprecated_reason: '',
|
||||
alternative_plugin_id: '',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('useDetailHeaderState', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAutoUpgradeInfo = null
|
||||
mockEnableMarketplace = true
|
||||
})
|
||||
|
||||
describe('Source Type Detection', () => {
|
||||
it('should detect marketplace source', () => {
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isFromMarketplace).toBe(true)
|
||||
expect(result.current.isFromGitHub).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect GitHub source', () => {
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isFromGitHub).toBe(true)
|
||||
expect(result.current.isFromMarketplace).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect local source', () => {
|
||||
const detail = createPluginDetail({ source: PluginSource.local })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isFromGitHub).toBe(false)
|
||||
expect(result.current.isFromMarketplace).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Version State', () => {
|
||||
it('should detect new version available for marketplace plugin', () => {
|
||||
const detail = createPluginDetail({
|
||||
version: '1.0.0',
|
||||
latest_version: '2.0.0',
|
||||
source: PluginSource.marketplace,
|
||||
})
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.hasNewVersion).toBe(true)
|
||||
})
|
||||
|
||||
it('should not detect new version when versions match', () => {
|
||||
const detail = createPluginDetail({
|
||||
version: '1.0.0',
|
||||
latest_version: '1.0.0',
|
||||
source: PluginSource.marketplace,
|
||||
})
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.hasNewVersion).toBe(false)
|
||||
})
|
||||
|
||||
it('should not detect new version for non-marketplace source', () => {
|
||||
const detail = createPluginDetail({
|
||||
version: '1.0.0',
|
||||
latest_version: '2.0.0',
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.hasNewVersion).toBe(false)
|
||||
})
|
||||
|
||||
it('should not detect new version when latest_version is empty', () => {
|
||||
const detail = createPluginDetail({
|
||||
version: '1.0.0',
|
||||
latest_version: '',
|
||||
source: PluginSource.marketplace,
|
||||
})
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.hasNewVersion).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Version Picker State', () => {
|
||||
it('should initialize version picker as hidden', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.versionPicker.isShow).toBe(false)
|
||||
})
|
||||
|
||||
it('should toggle version picker visibility', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.versionPicker.setIsShow(true)
|
||||
})
|
||||
expect(result.current.versionPicker.isShow).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.versionPicker.setIsShow(false)
|
||||
})
|
||||
expect(result.current.versionPicker.isShow).toBe(false)
|
||||
})
|
||||
|
||||
it('should update target version', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.versionPicker.setTargetVersion({
|
||||
version: '2.0.0',
|
||||
unique_identifier: 'new-uid',
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.versionPicker.targetVersion.version).toBe('2.0.0')
|
||||
expect(result.current.versionPicker.targetVersion.unique_identifier).toBe('new-uid')
|
||||
})
|
||||
|
||||
it('should set isDowngrade when provided in target version', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.versionPicker.setTargetVersion({
|
||||
version: '0.5.0',
|
||||
unique_identifier: 'old-uid',
|
||||
isDowngrade: true,
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.versionPicker.isDowngrade).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Modal States', () => {
|
||||
it('should initialize all modals as hidden', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.modalStates.isShowUpdateModal).toBe(false)
|
||||
expect(result.current.modalStates.isShowPluginInfo).toBe(false)
|
||||
expect(result.current.modalStates.isShowDeleteConfirm).toBe(false)
|
||||
expect(result.current.modalStates.deleting).toBe(false)
|
||||
})
|
||||
|
||||
it('should toggle update modal', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.showUpdateModal()
|
||||
})
|
||||
expect(result.current.modalStates.isShowUpdateModal).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.hideUpdateModal()
|
||||
})
|
||||
expect(result.current.modalStates.isShowUpdateModal).toBe(false)
|
||||
})
|
||||
|
||||
it('should toggle plugin info modal', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.showPluginInfo()
|
||||
})
|
||||
expect(result.current.modalStates.isShowPluginInfo).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.hidePluginInfo()
|
||||
})
|
||||
expect(result.current.modalStates.isShowPluginInfo).toBe(false)
|
||||
})
|
||||
|
||||
it('should toggle delete confirm modal', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.showDeleteConfirm()
|
||||
})
|
||||
expect(result.current.modalStates.isShowDeleteConfirm).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.hideDeleteConfirm()
|
||||
})
|
||||
expect(result.current.modalStates.isShowDeleteConfirm).toBe(false)
|
||||
})
|
||||
|
||||
it('should toggle deleting state', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.showDeleting()
|
||||
})
|
||||
expect(result.current.modalStates.deleting).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.modalStates.hideDeleting()
|
||||
})
|
||||
expect(result.current.modalStates.deleting).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Auto Upgrade Detection', () => {
|
||||
it('should disable auto upgrade when marketplace is disabled', () => {
|
||||
mockEnableMarketplace = false
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'update_all',
|
||||
include_plugins: [],
|
||||
exclude_plugins: [],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('should disable auto upgrade when strategy is disabled', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'disabled',
|
||||
upgrade_mode: 'update_all',
|
||||
include_plugins: [],
|
||||
exclude_plugins: [],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('should enable auto upgrade for update_all mode', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'update_all',
|
||||
include_plugins: [],
|
||||
exclude_plugins: [],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('should enable auto upgrade for partial mode when plugin is included', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'partial',
|
||||
include_plugins: ['test-plugin'],
|
||||
exclude_plugins: [],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('should disable auto upgrade for partial mode when plugin is not included', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'partial',
|
||||
include_plugins: ['other-plugin'],
|
||||
exclude_plugins: [],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('should enable auto upgrade for exclude mode when plugin is not excluded', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'exclude',
|
||||
include_plugins: [],
|
||||
exclude_plugins: ['other-plugin'],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('should disable auto upgrade for exclude mode when plugin is excluded', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'exclude',
|
||||
include_plugins: [],
|
||||
exclude_plugins: ['test-plugin'],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('should disable auto upgrade for non-marketplace source', () => {
|
||||
mockAutoUpgradeInfo = {
|
||||
strategy_setting: 'enabled',
|
||||
upgrade_mode: 'update_all',
|
||||
include_plugins: [],
|
||||
exclude_plugins: [],
|
||||
upgrade_time_of_day: 36000,
|
||||
}
|
||||
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('should disable auto upgrade when no auto upgrade info', () => {
|
||||
mockAutoUpgradeInfo = null
|
||||
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() => useDetailHeaderState(detail))
|
||||
|
||||
expect(result.current.isAutoUpgradeEnabled).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,132 @@
|
||||
'use client'
|
||||
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import useReferenceSetting from '../../../plugin-page/use-reference-setting'
|
||||
import { AUTO_UPDATE_MODE } from '../../../reference-setting-modal/auto-update-setting/types'
|
||||
import { PluginSource } from '../../../types'
|
||||
|
||||
export type VersionTarget = {
|
||||
version: string | undefined
|
||||
unique_identifier: string | undefined
|
||||
isDowngrade?: boolean
|
||||
}
|
||||
|
||||
export type ModalStates = {
|
||||
isShowUpdateModal: boolean
|
||||
showUpdateModal: () => void
|
||||
hideUpdateModal: () => void
|
||||
isShowPluginInfo: boolean
|
||||
showPluginInfo: () => void
|
||||
hidePluginInfo: () => void
|
||||
isShowDeleteConfirm: boolean
|
||||
showDeleteConfirm: () => void
|
||||
hideDeleteConfirm: () => void
|
||||
deleting: boolean
|
||||
showDeleting: () => void
|
||||
hideDeleting: () => void
|
||||
}
|
||||
|
||||
export type VersionPickerState = {
|
||||
isShow: boolean
|
||||
setIsShow: (show: boolean) => void
|
||||
targetVersion: VersionTarget
|
||||
setTargetVersion: (version: VersionTarget) => void
|
||||
isDowngrade: boolean
|
||||
setIsDowngrade: (downgrade: boolean) => void
|
||||
}
|
||||
|
||||
export type UseDetailHeaderStateReturn = {
|
||||
modalStates: ModalStates
|
||||
versionPicker: VersionPickerState
|
||||
hasNewVersion: boolean
|
||||
isAutoUpgradeEnabled: boolean
|
||||
isFromGitHub: boolean
|
||||
isFromMarketplace: boolean
|
||||
}
|
||||
|
||||
export const useDetailHeaderState = (detail: PluginDetail): UseDetailHeaderStateReturn => {
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { referenceSetting } = useReferenceSetting()
|
||||
|
||||
const {
|
||||
source,
|
||||
version,
|
||||
latest_version,
|
||||
latest_unique_identifier,
|
||||
plugin_id,
|
||||
} = detail
|
||||
|
||||
const isFromGitHub = source === PluginSource.github
|
||||
const isFromMarketplace = source === PluginSource.marketplace
|
||||
const [isShow, setIsShow] = useState(false)
|
||||
const [targetVersion, setTargetVersion] = useState<VersionTarget>({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
const [isDowngrade, setIsDowngrade] = useState(false)
|
||||
|
||||
const [isShowUpdateModal, { setTrue: showUpdateModal, setFalse: hideUpdateModal }] = useBoolean(false)
|
||||
const [isShowPluginInfo, { setTrue: showPluginInfo, setFalse: hidePluginInfo }] = useBoolean(false)
|
||||
const [isShowDeleteConfirm, { setTrue: showDeleteConfirm, setFalse: hideDeleteConfirm }] = useBoolean(false)
|
||||
const [deleting, { setTrue: showDeleting, setFalse: hideDeleting }] = useBoolean(false)
|
||||
|
||||
const hasNewVersion = useMemo(() => {
|
||||
if (isFromMarketplace)
|
||||
return !!latest_version && latest_version !== version
|
||||
return false
|
||||
}, [isFromMarketplace, latest_version, version])
|
||||
|
||||
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
|
||||
|
||||
const isAutoUpgradeEnabled = useMemo(() => {
|
||||
if (!enable_marketplace || !autoUpgradeInfo || !isFromMarketplace)
|
||||
return false
|
||||
if (autoUpgradeInfo.strategy_setting === 'disabled')
|
||||
return false
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
|
||||
return true
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
|
||||
return true
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
|
||||
return true
|
||||
return false
|
||||
}, [autoUpgradeInfo, plugin_id, isFromMarketplace, enable_marketplace])
|
||||
|
||||
const handleSetTargetVersion = useCallback((version: VersionTarget) => {
|
||||
setTargetVersion(version)
|
||||
if (version.isDowngrade !== undefined)
|
||||
setIsDowngrade(version.isDowngrade)
|
||||
}, [])
|
||||
|
||||
return {
|
||||
modalStates: {
|
||||
isShowUpdateModal,
|
||||
showUpdateModal,
|
||||
hideUpdateModal,
|
||||
isShowPluginInfo,
|
||||
showPluginInfo,
|
||||
hidePluginInfo,
|
||||
isShowDeleteConfirm,
|
||||
showDeleteConfirm,
|
||||
hideDeleteConfirm,
|
||||
deleting,
|
||||
showDeleting,
|
||||
hideDeleting,
|
||||
},
|
||||
versionPicker: {
|
||||
isShow,
|
||||
setIsShow,
|
||||
targetVersion,
|
||||
setTargetVersion: handleSetTargetVersion,
|
||||
isDowngrade,
|
||||
setIsDowngrade,
|
||||
},
|
||||
hasNewVersion,
|
||||
isAutoUpgradeEnabled,
|
||||
isFromGitHub,
|
||||
isFromMarketplace,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,549 @@
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from './use-detail-header-state'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import * as amplitude from '@/app/components/base/amplitude'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { PluginSource } from '../../../types'
|
||||
import { usePluginOperations } from './use-plugin-operations'
|
||||
|
||||
type VersionPickerMock = {
|
||||
setTargetVersion: (version: VersionTarget) => void
|
||||
setIsDowngrade: (downgrade: boolean) => void
|
||||
}
|
||||
|
||||
const {
|
||||
mockSetShowUpdatePluginModal,
|
||||
mockRefreshModelProviders,
|
||||
mockInvalidateAllToolProviders,
|
||||
mockUninstallPlugin,
|
||||
mockFetchReleases,
|
||||
mockCheckForUpdates,
|
||||
} = vi.hoisted(() => {
|
||||
return {
|
||||
mockSetShowUpdatePluginModal: vi.fn(),
|
||||
mockRefreshModelProviders: vi.fn(),
|
||||
mockInvalidateAllToolProviders: vi.fn(),
|
||||
mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })),
|
||||
mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])),
|
||||
mockCheckForUpdates: vi.fn(() => ({ needUpdate: true, toastProps: { type: 'success', message: 'Update available' } })),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: () => ({
|
||||
setShowUpdatePluginModal: mockSetShowUpdatePluginModal,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
refreshModelProviders: mockRefreshModelProviders,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/plugins', () => ({
|
||||
uninstallPlugin: mockUninstallPlugin,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-tools', () => ({
|
||||
useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
|
||||
}))
|
||||
|
||||
vi.mock('../../../install-plugin/hooks', () => ({
|
||||
useGitHubReleases: () => ({
|
||||
checkForUpdates: mockCheckForUpdates,
|
||||
fetchReleases: mockFetchReleases,
|
||||
}),
|
||||
}))
|
||||
|
||||
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
|
||||
id: 'test-id',
|
||||
created_at: '2024-01-01',
|
||||
updated_at: '2024-01-02',
|
||||
name: 'Test Plugin',
|
||||
plugin_id: 'test-plugin',
|
||||
plugin_unique_identifier: 'test-uid',
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test Plugin Label' },
|
||||
description: { en_US: 'Test description' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
installation_id: 'install-1',
|
||||
tenant_id: 'tenant-1',
|
||||
endpoints_setups: 0,
|
||||
endpoints_active: 0,
|
||||
version: '1.0.0',
|
||||
latest_version: '2.0.0',
|
||||
latest_unique_identifier: 'new-uid',
|
||||
source: PluginSource.marketplace,
|
||||
meta: undefined,
|
||||
status: 'active',
|
||||
deprecated_reason: '',
|
||||
alternative_plugin_id: '',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createModalStatesMock = (): ModalStates => ({
|
||||
isShowUpdateModal: false,
|
||||
showUpdateModal: vi.fn(),
|
||||
hideUpdateModal: vi.fn(),
|
||||
isShowPluginInfo: false,
|
||||
showPluginInfo: vi.fn(),
|
||||
hidePluginInfo: vi.fn(),
|
||||
isShowDeleteConfirm: false,
|
||||
showDeleteConfirm: vi.fn(),
|
||||
hideDeleteConfirm: vi.fn(),
|
||||
deleting: false,
|
||||
showDeleting: vi.fn(),
|
||||
hideDeleting: vi.fn(),
|
||||
})
|
||||
|
||||
const createVersionPickerMock = (): VersionPickerMock => ({
|
||||
setTargetVersion: vi.fn<(version: VersionTarget) => void>(),
|
||||
setIsDowngrade: vi.fn<(downgrade: boolean) => void>(),
|
||||
})
|
||||
|
||||
describe('usePluginOperations', () => {
|
||||
let modalStates: ModalStates
|
||||
let versionPicker: VersionPickerMock
|
||||
let mockOnUpdate: (isDelete?: boolean) => void
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
modalStates = createModalStatesMock()
|
||||
versionPicker = createVersionPickerMock()
|
||||
mockOnUpdate = vi.fn()
|
||||
vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
|
||||
vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
describe('Marketplace Update Flow', () => {
|
||||
it('should show update modal for marketplace plugin', async () => {
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(modalStates.showUpdateModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should set isDowngrade when downgrading', async () => {
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate(true)
|
||||
})
|
||||
|
||||
expect(versionPicker.setIsDowngrade).toHaveBeenCalledWith(true)
|
||||
expect(modalStates.showUpdateModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onUpdate and hide modal on successful marketplace update', () => {
|
||||
const detail = createPluginDetail({ source: PluginSource.marketplace })
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleUpdatedFromMarketplace()
|
||||
})
|
||||
|
||||
expect(mockOnUpdate).toHaveBeenCalled()
|
||||
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('GitHub Update Flow', () => {
|
||||
it('should fetch releases from GitHub', async () => {
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(mockFetchReleases).toHaveBeenCalledWith('owner', 'repo')
|
||||
})
|
||||
|
||||
it('should check for updates after fetching releases', async () => {
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(mockCheckForUpdates).toHaveBeenCalled()
|
||||
expect(Toast.notify).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show update plugin modal when update is needed', async () => {
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(mockSetShowUpdatePluginModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not show modal when no releases found', async () => {
|
||||
mockFetchReleases.mockResolvedValueOnce([])
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not show modal when no update needed', async () => {
|
||||
mockCheckForUpdates.mockReturnValueOnce({
|
||||
needUpdate: false,
|
||||
toastProps: { type: 'info', message: 'Already up to date' },
|
||||
})
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use author and name as fallback for repo parsing', async () => {
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: '/', version: 'v1.0.0', package: 'pkg' },
|
||||
declaration: {
|
||||
author: 'fallback-author',
|
||||
name: 'fallback-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test' },
|
||||
description: { en_US: 'Test' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
expect(mockFetchReleases).toHaveBeenCalledWith('fallback-author', 'fallback-name')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Delete Flow', () => {
|
||||
it('should call uninstallPlugin with correct id', async () => {
|
||||
const detail = createPluginDetail({ id: 'plugin-to-delete' })
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockUninstallPlugin).toHaveBeenCalledWith('plugin-to-delete')
|
||||
})
|
||||
|
||||
it('should show and hide deleting state during delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(modalStates.showDeleting).toHaveBeenCalled()
|
||||
expect(modalStates.hideDeleting).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onUpdate with true after successful delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockOnUpdate).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should hide delete confirm after successful delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should refresh model providers when deleting model plugin', async () => {
|
||||
const detail = createPluginDetail({
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'model',
|
||||
label: { en_US: 'Test' },
|
||||
description: { en_US: 'Test' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockRefreshModelProviders).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invalidate tool providers when deleting tool plugin', async () => {
|
||||
const detail = createPluginDetail({
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test' },
|
||||
description: { en_US: 'Test' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockInvalidateAllToolProviders).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should track plugin uninstalled event', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(amplitude.trackEvent).toHaveBeenCalledWith('plugin_uninstalled', expect.objectContaining({
|
||||
plugin_id: 'test-plugin',
|
||||
plugin_name: 'test-plugin-name',
|
||||
}))
|
||||
})
|
||||
|
||||
it('should not call onUpdate when delete fails', async () => {
|
||||
mockUninstallPlugin.mockResolvedValueOnce({ success: false })
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockOnUpdate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Optional onUpdate Callback', () => {
|
||||
it('should not throw when onUpdate is not provided for marketplace update', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(() => {
|
||||
result.current.handleUpdatedFromMarketplace()
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should not throw when onUpdate is not provided for delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
}),
|
||||
)
|
||||
|
||||
await expect(
|
||||
act(async () => {
|
||||
await result.current.handleDelete()
|
||||
}),
|
||||
).resolves.not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,143 @@
|
||||
'use client'
|
||||
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from './use-detail-header-state'
|
||||
import { useCallback } from 'react'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { uninstallPlugin } from '@/service/plugins'
|
||||
import { useInvalidateAllToolProviders } from '@/service/use-tools'
|
||||
import { useGitHubReleases } from '../../../install-plugin/hooks'
|
||||
import { PluginCategoryEnum, PluginSource } from '../../../types'
|
||||
|
||||
type UsePluginOperationsParams = {
|
||||
detail: PluginDetail
|
||||
modalStates: ModalStates
|
||||
versionPicker: {
|
||||
setTargetVersion: (version: VersionTarget) => void
|
||||
setIsDowngrade: (downgrade: boolean) => void
|
||||
}
|
||||
isFromMarketplace: boolean
|
||||
onUpdate?: (isDelete?: boolean) => void
|
||||
}
|
||||
|
||||
type UsePluginOperationsReturn = {
|
||||
handleUpdate: (isDowngrade?: boolean) => Promise<void>
|
||||
handleUpdatedFromMarketplace: () => void
|
||||
handleDelete: () => Promise<void>
|
||||
}
|
||||
|
||||
export const usePluginOperations = ({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace,
|
||||
onUpdate,
|
||||
}: UsePluginOperationsParams): UsePluginOperationsReturn => {
|
||||
const { checkForUpdates, fetchReleases } = useGitHubReleases()
|
||||
const { setShowUpdatePluginModal } = useModalContext()
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
||||
|
||||
const { id, meta, plugin_id } = detail
|
||||
const { author, category, name } = detail.declaration || detail
|
||||
|
||||
const handleUpdate = useCallback(async (isDowngrade?: boolean) => {
|
||||
if (isFromMarketplace) {
|
||||
versionPicker.setIsDowngrade(!!isDowngrade)
|
||||
modalStates.showUpdateModal()
|
||||
return
|
||||
}
|
||||
|
||||
if (!meta?.repo || !meta?.version || !meta?.package) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'Missing plugin metadata for GitHub update',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const owner = meta.repo.split('/')[0] || author
|
||||
const repo = meta.repo.split('/')[1] || name
|
||||
const fetchedReleases = await fetchReleases(owner, repo)
|
||||
if (fetchedReleases.length === 0)
|
||||
return
|
||||
|
||||
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta.version)
|
||||
Toast.notify(toastProps)
|
||||
|
||||
if (needUpdate) {
|
||||
setShowUpdatePluginModal({
|
||||
onSaveCallback: () => {
|
||||
onUpdate?.()
|
||||
},
|
||||
payload: {
|
||||
type: PluginSource.github,
|
||||
category,
|
||||
github: {
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
repo: meta.repo,
|
||||
version: meta.version,
|
||||
package: meta.package,
|
||||
releases: fetchedReleases,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}, [
|
||||
isFromMarketplace,
|
||||
meta,
|
||||
author,
|
||||
name,
|
||||
fetchReleases,
|
||||
checkForUpdates,
|
||||
setShowUpdatePluginModal,
|
||||
detail,
|
||||
onUpdate,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
])
|
||||
|
||||
const handleUpdatedFromMarketplace = useCallback(() => {
|
||||
onUpdate?.()
|
||||
modalStates.hideUpdateModal()
|
||||
}, [onUpdate, modalStates])
|
||||
|
||||
const handleDelete = useCallback(async () => {
|
||||
modalStates.showDeleting()
|
||||
const res = await uninstallPlugin(id)
|
||||
modalStates.hideDeleting()
|
||||
|
||||
if (res.success) {
|
||||
modalStates.hideDeleteConfirm()
|
||||
onUpdate?.(true)
|
||||
|
||||
if (PluginCategoryEnum.model.includes(category))
|
||||
refreshModelProviders()
|
||||
|
||||
if (PluginCategoryEnum.tool.includes(category))
|
||||
invalidateAllToolProviders()
|
||||
|
||||
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
|
||||
}
|
||||
}, [
|
||||
id,
|
||||
category,
|
||||
plugin_id,
|
||||
name,
|
||||
modalStates,
|
||||
onUpdate,
|
||||
refreshModelProviders,
|
||||
invalidateAllToolProviders,
|
||||
])
|
||||
|
||||
return {
|
||||
handleUpdate,
|
||||
handleUpdatedFromMarketplace,
|
||||
handleDelete,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
'use client'
|
||||
|
||||
import type { PluginDetail } from '../../types'
|
||||
import {
|
||||
RiArrowLeftRightLine,
|
||||
RiCloseLine,
|
||||
} from '@remixicon/react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
|
||||
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
|
||||
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
|
||||
import { API_PREFIX } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGetLanguage, useLocale } from '@/context/i18n'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { useAllToolProviders } from '@/service/use-tools'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import { AutoUpdateLine } from '../../../base/icons/src/vender/system'
|
||||
import Verified from '../../base/badges/verified'
|
||||
import DeprecationNotice from '../../base/deprecation-notice'
|
||||
import Icon from '../../card/base/card-icon'
|
||||
import Description from '../../card/base/description'
|
||||
import OrgInfo from '../../card/base/org-info'
|
||||
import Title from '../../card/base/title'
|
||||
import useReferenceSetting from '../../plugin-page/use-reference-setting'
|
||||
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../../reference-setting-modal/auto-update-setting/utils'
|
||||
import { PluginCategoryEnum, PluginSource } from '../../types'
|
||||
import { HeaderModals, PluginSourceBadge } from './components'
|
||||
import { useDetailHeaderState, usePluginOperations } from './hooks'
|
||||
|
||||
type Props = {
|
||||
detail: PluginDetail
|
||||
isReadmeView?: boolean
|
||||
onHide?: () => void
|
||||
onUpdate?: (isDelete?: boolean) => void
|
||||
}
|
||||
|
||||
const getIconSrc = (icon: string | undefined, iconDark: string | undefined, theme: string, tenantId: string): string => {
|
||||
const iconFileName = theme === 'dark' && iconDark ? iconDark : icon
|
||||
if (!iconFileName)
|
||||
return ''
|
||||
return iconFileName.startsWith('http')
|
||||
? iconFileName
|
||||
: `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenantId}&filename=${iconFileName}`
|
||||
}
|
||||
|
||||
const getDetailUrl = (
|
||||
source: PluginSource,
|
||||
meta: PluginDetail['meta'],
|
||||
author: string,
|
||||
name: string,
|
||||
locale: string,
|
||||
theme: string,
|
||||
): string => {
|
||||
if (source === PluginSource.github) {
|
||||
const repo = meta?.repo
|
||||
if (!repo)
|
||||
return ''
|
||||
return `https://github.com/${repo}`
|
||||
}
|
||||
if (source === PluginSource.marketplace)
|
||||
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: locale, theme })
|
||||
return ''
|
||||
}
|
||||
|
||||
const DetailHeader = ({
|
||||
detail,
|
||||
isReadmeView = false,
|
||||
onHide,
|
||||
onUpdate,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { userProfile: { timezone } } = useAppContext()
|
||||
const { theme } = useTheme()
|
||||
const locale = useGetLanguage()
|
||||
const currentLocale = useLocale()
|
||||
const { referenceSetting } = useReferenceSetting()
|
||||
|
||||
const {
|
||||
source,
|
||||
tenant_id,
|
||||
version,
|
||||
latest_version,
|
||||
latest_unique_identifier,
|
||||
meta,
|
||||
plugin_id,
|
||||
status,
|
||||
deprecated_reason,
|
||||
alternative_plugin_id,
|
||||
} = detail
|
||||
|
||||
const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
|
||||
|
||||
const {
|
||||
modalStates,
|
||||
versionPicker,
|
||||
hasNewVersion,
|
||||
isAutoUpgradeEnabled,
|
||||
isFromGitHub,
|
||||
isFromMarketplace,
|
||||
} = useDetailHeaderState(detail)
|
||||
|
||||
const {
|
||||
handleUpdate,
|
||||
handleUpdatedFromMarketplace,
|
||||
handleDelete,
|
||||
} = usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace,
|
||||
onUpdate,
|
||||
})
|
||||
|
||||
const isTool = category === PluginCategoryEnum.tool
|
||||
const providerBriefInfo = tool?.identity
|
||||
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
|
||||
const { data: collectionList = [] } = useAllToolProviders(isTool)
|
||||
const provider = useMemo(() => {
|
||||
return collectionList.find(collection => collection.name === providerKey)
|
||||
}, [collectionList, providerKey])
|
||||
|
||||
const iconSrc = getIconSrc(icon, icon_dark, theme, tenant_id)
|
||||
const detailUrl = getDetailUrl(source, meta, author, name, currentLocale, theme)
|
||||
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
|
||||
|
||||
const handleVersionSelect = (state: { version: string, unique_identifier: string, isDowngrade?: boolean }) => {
|
||||
versionPicker.setTargetVersion(state)
|
||||
handleUpdate(state.isDowngrade)
|
||||
}
|
||||
|
||||
const handleTriggerLatestUpdate = () => {
|
||||
if (isFromMarketplace) {
|
||||
versionPicker.setTargetVersion({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
}
|
||||
handleUpdate()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
|
||||
<div className="flex">
|
||||
{/* Plugin Icon */}
|
||||
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
|
||||
<Icon src={iconSrc} />
|
||||
</div>
|
||||
|
||||
{/* Plugin Info */}
|
||||
<div className="ml-3 w-0 grow">
|
||||
{/* Title Row */}
|
||||
<div className="flex h-5 items-center">
|
||||
<Title title={label[locale]} />
|
||||
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
|
||||
|
||||
{/* Version Picker */}
|
||||
{!!version && (
|
||||
<PluginVersionPicker
|
||||
disabled={!isFromMarketplace || isReadmeView}
|
||||
isShow={versionPicker.isShow}
|
||||
onShowChange={versionPicker.setIsShow}
|
||||
pluginID={plugin_id}
|
||||
currentVersion={version}
|
||||
onSelect={handleVersionSelect}
|
||||
trigger={(
|
||||
<Badge
|
||||
className={cn(
|
||||
'mx-1',
|
||||
versionPicker.isShow && 'bg-state-base-hover',
|
||||
(versionPicker.isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
|
||||
)}
|
||||
uppercase={false}
|
||||
text={(
|
||||
<>
|
||||
<div>{isFromGitHub ? (meta?.version ?? version ?? '') : version}</div>
|
||||
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
|
||||
</>
|
||||
)}
|
||||
hasRedCornerMark={hasNewVersion}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Auto Update Badge */}
|
||||
{isAutoUpgradeEnabled && !isReadmeView && (
|
||||
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
|
||||
<div>
|
||||
<Badge className="mr-1 cursor-pointer px-1">
|
||||
<AutoUpdateLine className="size-3" />
|
||||
</Badge>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{/* Update Button */}
|
||||
{(hasNewVersion || isFromGitHub) && (
|
||||
<Button
|
||||
variant="secondary-accent"
|
||||
size="small"
|
||||
className="!h-5"
|
||||
onClick={handleTriggerLatestUpdate}
|
||||
>
|
||||
{t('detailPanel.operation.update', { ns: 'plugin' })}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Org Info Row */}
|
||||
<div className="mb-1 flex h-4 items-center justify-between">
|
||||
<div className="mt-0.5 flex items-center">
|
||||
<OrgInfo
|
||||
packageNameClassName="w-auto"
|
||||
orgName={author}
|
||||
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
|
||||
/>
|
||||
{!!source && <PluginSourceBadge source={source} />}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Action Buttons */}
|
||||
{!isReadmeView && (
|
||||
<div className="flex gap-1">
|
||||
<OperationDropdown
|
||||
source={source}
|
||||
onInfo={modalStates.showPluginInfo}
|
||||
onCheckVersion={handleUpdate}
|
||||
onRemove={modalStates.showDeleteConfirm}
|
||||
detailUrl={detailUrl}
|
||||
/>
|
||||
<ActionButton onClick={onHide}>
|
||||
<RiCloseLine className="h-4 w-4" />
|
||||
</ActionButton>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Deprecation Notice */}
|
||||
{isFromMarketplace && (
|
||||
<DeprecationNotice
|
||||
status={status}
|
||||
deprecatedReason={deprecated_reason}
|
||||
alternativePluginId={alternative_plugin_id}
|
||||
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
|
||||
className="mt-3"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Description */}
|
||||
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2} />}
|
||||
|
||||
{/* Plugin Auth for Tools */}
|
||||
{category === PluginCategoryEnum.tool && !isReadmeView && (
|
||||
<PluginAuth
|
||||
pluginPayload={{
|
||||
provider: provider?.name || '',
|
||||
category: AuthCategory.tool,
|
||||
providerType: provider?.type || '',
|
||||
detail,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Modals */}
|
||||
<HeaderModals
|
||||
detail={detail}
|
||||
modalStates={modalStates}
|
||||
targetVersion={versionPicker.targetVersion}
|
||||
isDowngrade={versionPicker.isDowngrade}
|
||||
isAutoUpgradeEnabled={isAutoUpgradeEnabled}
|
||||
onUpdatedFromMarketplace={handleUpdatedFromMarketplace}
|
||||
onDelete={handleDelete}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default DetailHeader
|
||||
@@ -2,15 +2,10 @@ import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
// Import after mocks
|
||||
import { SupportedCreationMethods } from '@/app/components/plugins/types'
|
||||
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
|
||||
import { CommonCreateModal } from './common-modal'
|
||||
|
||||
// ============================================================================
|
||||
// Type Definitions
|
||||
// ============================================================================
|
||||
|
||||
type PluginDetail = {
|
||||
plugin_id: string
|
||||
provider: string
|
||||
@@ -33,10 +28,6 @@ type TriggerLogEntity = {
|
||||
level: 'info' | 'warn' | 'error'
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Factory Functions
|
||||
// ============================================================================
|
||||
|
||||
function createMockPluginDetail(overrides: Partial<PluginDetail> = {}): PluginDetail {
|
||||
return {
|
||||
plugin_id: 'test-plugin-id',
|
||||
@@ -74,18 +65,12 @@ function createMockLogData(logs: TriggerLogEntity[] = []): { logs: TriggerLogEnt
|
||||
return { logs }
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Setup
|
||||
// ============================================================================
|
||||
|
||||
// Mock plugin store
|
||||
const mockPluginDetail = createMockPluginDetail()
|
||||
const mockUsePluginStore = vi.fn(() => mockPluginDetail)
|
||||
vi.mock('../../store', () => ({
|
||||
usePluginStore: () => mockUsePluginStore(),
|
||||
}))
|
||||
|
||||
// Mock subscription list hook
|
||||
const mockRefetch = vi.fn()
|
||||
vi.mock('../use-subscription-list', () => ({
|
||||
useSubscriptionList: () => ({
|
||||
@@ -93,13 +78,11 @@ vi.mock('../use-subscription-list', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock service hooks
|
||||
const mockVerifyCredentials = vi.fn()
|
||||
const mockCreateBuilder = vi.fn()
|
||||
const mockBuildSubscription = vi.fn()
|
||||
const mockUpdateBuilder = vi.fn()
|
||||
|
||||
// Configurable pending states
|
||||
let mockIsVerifyingCredentials = false
|
||||
let mockIsBuilding = false
|
||||
const setMockPendingStates = (verifying: boolean, building: boolean) => {
|
||||
@@ -129,18 +112,15 @@ vi.mock('@/service/use-triggers', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock error parser
|
||||
const mockParsePluginErrorMessage = vi.fn().mockResolvedValue(null)
|
||||
vi.mock('@/utils/error-parser', () => ({
|
||||
parsePluginErrorMessage: (...args: unknown[]) => mockParsePluginErrorMessage(...args),
|
||||
}))
|
||||
|
||||
// Mock URL validation
|
||||
vi.mock('@/utils/urlValidation', () => ({
|
||||
isPrivateOrLocalAddress: vi.fn().mockReturnValue(false),
|
||||
}))
|
||||
|
||||
// Mock toast
|
||||
const mockToastNotify = vi.fn()
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
@@ -148,7 +128,6 @@ vi.mock('@/app/components/base/toast', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock Modal component
|
||||
vi.mock('@/app/components/base/modal/modal', () => ({
|
||||
default: ({
|
||||
children,
|
||||
@@ -179,7 +158,6 @@ vi.mock('@/app/components/base/modal/modal', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// Configurable form mock values
|
||||
type MockFormValuesConfig = {
|
||||
values: Record<string, unknown>
|
||||
isCheckValidated: boolean
|
||||
@@ -190,7 +168,6 @@ let mockFormValuesConfig: MockFormValuesConfig = {
|
||||
}
|
||||
let mockGetFormReturnsNull = false
|
||||
|
||||
// Separate validation configs for different forms
|
||||
let mockSubscriptionFormValidated = true
|
||||
let mockAutoParamsFormValidated = true
|
||||
let mockManualPropsFormValidated = true
|
||||
@@ -207,7 +184,6 @@ const setMockFormValidation = (subscription: boolean, autoParams: boolean, manua
|
||||
mockManualPropsFormValidated = manualProps
|
||||
}
|
||||
|
||||
// Mock BaseForm component with ref support
|
||||
vi.mock('@/app/components/base/form/components/base', async () => {
|
||||
const React = await import('react')
|
||||
|
||||
@@ -219,7 +195,6 @@ vi.mock('@/app/components/base/form/components/base', async () => {
|
||||
type MockBaseFormProps = { formSchemas: Array<{ name: string }>, onChange?: () => void }
|
||||
|
||||
function MockBaseFormInner({ formSchemas, onChange }: MockBaseFormProps, ref: React.ForwardedRef<MockFormRef>) {
|
||||
// Determine which form this is based on schema
|
||||
const isSubscriptionForm = formSchemas.some((s: { name: string }) => s.name === 'subscription_name')
|
||||
const isAutoParamsForm = formSchemas.some((s: { name: string }) =>
|
||||
['repo_name', 'branch', 'repo', 'text_field', 'dynamic_field', 'bool_field', 'text_input_field', 'unknown_field', 'count'].includes(s.name),
|
||||
@@ -265,12 +240,10 @@ vi.mock('@/app/components/base/form/components/base', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
// Mock EncryptedBottom component
|
||||
vi.mock('@/app/components/base/encrypted-bottom', () => ({
|
||||
EncryptedBottom: () => <div data-testid="encrypted-bottom">Encrypted</div>,
|
||||
}))
|
||||
|
||||
// Mock LogViewer component
|
||||
vi.mock('../log-viewer', () => ({
|
||||
default: ({ logs }: { logs: TriggerLogEntity[] }) => (
|
||||
<div data-testid="log-viewer">
|
||||
@@ -281,7 +254,6 @@ vi.mock('../log-viewer', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// Mock debounce
|
||||
vi.mock('es-toolkit/compat', () => ({
|
||||
debounce: (fn: (...args: unknown[]) => unknown) => {
|
||||
const debouncedFn = (...args: unknown[]) => fn(...args)
|
||||
@@ -290,10 +262,6 @@ vi.mock('es-toolkit/compat', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
// ============================================================================
|
||||
// Test Suites
|
||||
// ============================================================================
|
||||
|
||||
describe('CommonCreateModal', () => {
|
||||
const defaultProps = {
|
||||
onClose: vi.fn(),
|
||||
@@ -441,7 +409,8 @@ describe('CommonCreateModal', () => {
|
||||
})
|
||||
|
||||
it('should call onConfirm handler when confirm button is clicked', () => {
|
||||
render(<CommonCreateModal {...defaultProps} />)
|
||||
// Provide builder so the guard passes and credentials check is reached
|
||||
render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-confirm'))
|
||||
|
||||
@@ -1243,13 +1212,22 @@ describe('CommonCreateModal', () => {
|
||||
|
||||
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} />)
|
||||
|
||||
// Wait for createBuilder to complete and state to update
|
||||
await waitFor(() => {
|
||||
expect(mockCreateBuilder).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Allow React to process the state update from createBuilder
|
||||
await act(async () => {})
|
||||
|
||||
const input = screen.getByTestId('form-field-webhook_url')
|
||||
fireEvent.change(input, { target: { value: 'https://example.com/webhook' } })
|
||||
|
||||
// Wait for updateBuilder to be called, then check the toast
|
||||
await waitFor(() => {
|
||||
expect(mockUpdateBuilder).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
@@ -1450,7 +1428,8 @@ describe('CommonCreateModal', () => {
|
||||
})
|
||||
mockUsePluginStore.mockReturnValue(detailWithCredentials)
|
||||
|
||||
render(<CommonCreateModal {...defaultProps} />)
|
||||
// Provide builder so the guard passes and credentials check is reached
|
||||
render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-confirm'))
|
||||
|
||||
|
||||
@@ -1,32 +1,19 @@
|
||||
'use client'
|
||||
import type { FormRefObject } from '@/app/components/base/form/types'
|
||||
import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers'
|
||||
import { RiLoader2Line } from '@remixicon/react'
|
||||
import { debounce } from 'es-toolkit/compat'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
// import { CopyFeedbackNew } from '@/app/components/base/copy-feedback'
|
||||
import { EncryptedBottom } from '@/app/components/base/encrypted-bottom'
|
||||
import { BaseForm } from '@/app/components/base/form/components/base'
|
||||
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||
import Modal from '@/app/components/base/modal/modal'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { SupportedCreationMethods } from '@/app/components/plugins/types'
|
||||
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
|
||||
import {
|
||||
useBuildTriggerSubscription,
|
||||
useCreateTriggerSubscriptionBuilder,
|
||||
useTriggerSubscriptionBuilderLogs,
|
||||
useUpdateTriggerSubscriptionBuilder,
|
||||
useVerifyAndUpdateTriggerSubscriptionBuilder,
|
||||
} from '@/service/use-triggers'
|
||||
import { parsePluginErrorMessage } from '@/utils/error-parser'
|
||||
import { isPrivateOrLocalAddress } from '@/utils/urlValidation'
|
||||
import { usePluginStore } from '../../store'
|
||||
import LogViewer from '../log-viewer'
|
||||
import { useSubscriptionList } from '../use-subscription-list'
|
||||
ConfigurationStepContent,
|
||||
MultiSteps,
|
||||
VerifyStepContent,
|
||||
} from './components/modal-steps'
|
||||
import {
|
||||
ApiKeyStep,
|
||||
MODAL_TITLE_KEY_MAP,
|
||||
useCommonModalState,
|
||||
} from './hooks/use-common-modal-state'
|
||||
|
||||
type Props = {
|
||||
onClose: () => void
|
||||
@@ -34,316 +21,33 @@ type Props = {
|
||||
builder?: TriggerSubscriptionBuilder
|
||||
}
|
||||
|
||||
const CREDENTIAL_TYPE_MAP: Record<SupportedCreationMethods, TriggerCredentialTypeEnum> = {
|
||||
[SupportedCreationMethods.APIKEY]: TriggerCredentialTypeEnum.ApiKey,
|
||||
[SupportedCreationMethods.OAUTH]: TriggerCredentialTypeEnum.Oauth2,
|
||||
[SupportedCreationMethods.MANUAL]: TriggerCredentialTypeEnum.Unauthorized,
|
||||
}
|
||||
|
||||
const MODAL_TITLE_KEY_MAP: Record<
|
||||
SupportedCreationMethods,
|
||||
'modal.apiKey.title' | 'modal.oauth.title' | 'modal.manual.title'
|
||||
> = {
|
||||
[SupportedCreationMethods.APIKEY]: 'modal.apiKey.title',
|
||||
[SupportedCreationMethods.OAUTH]: 'modal.oauth.title',
|
||||
[SupportedCreationMethods.MANUAL]: 'modal.manual.title',
|
||||
}
|
||||
|
||||
enum ApiKeyStep {
|
||||
Verify = 'verify',
|
||||
Configuration = 'configuration',
|
||||
}
|
||||
|
||||
const defaultFormValues = { values: {}, isCheckValidated: false }
|
||||
|
||||
const normalizeFormType = (type: FormTypeEnum | string): FormTypeEnum => {
|
||||
if (Object.values(FormTypeEnum).includes(type as FormTypeEnum))
|
||||
return type as FormTypeEnum
|
||||
|
||||
switch (type) {
|
||||
case 'string':
|
||||
case 'text':
|
||||
return FormTypeEnum.textInput
|
||||
case 'password':
|
||||
case 'secret':
|
||||
return FormTypeEnum.secretInput
|
||||
case 'number':
|
||||
case 'integer':
|
||||
return FormTypeEnum.textNumber
|
||||
case 'boolean':
|
||||
return FormTypeEnum.boolean
|
||||
default:
|
||||
return FormTypeEnum.textInput
|
||||
}
|
||||
}
|
||||
|
||||
const StatusStep = ({ isActive, text }: { isActive: boolean, text: string }) => {
|
||||
return (
|
||||
<div className={`system-2xs-semibold-uppercase flex items-center gap-1 ${isActive
|
||||
? 'text-state-accent-solid'
|
||||
: 'text-text-tertiary'}`}
|
||||
>
|
||||
{/* Active indicator dot */}
|
||||
{isActive && (
|
||||
<div className="h-1 w-1 rounded-full bg-state-accent-solid"></div>
|
||||
)}
|
||||
{text}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const MultiSteps = ({ currentStep }: { currentStep: ApiKeyStep }) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div className="mb-6 flex w-1/3 items-center gap-2">
|
||||
<StatusStep isActive={currentStep === ApiKeyStep.Verify} text={t('modal.steps.verify', { ns: 'pluginTrigger' })} />
|
||||
<div className="h-px w-3 shrink-0 bg-divider-deep"></div>
|
||||
<StatusStep isActive={currentStep === ApiKeyStep.Configuration} text={t('modal.steps.configuration', { ns: 'pluginTrigger' })} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export const CommonCreateModal = ({ onClose, createType, builder }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const detail = usePluginStore(state => state.detail)
|
||||
const { refetch } = useSubscriptionList()
|
||||
|
||||
const [currentStep, setCurrentStep] = useState<ApiKeyStep>(createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration)
|
||||
const {
|
||||
currentStep,
|
||||
subscriptionBuilder,
|
||||
isVerifyingCredentials,
|
||||
isBuilding,
|
||||
formRefs,
|
||||
detail,
|
||||
manualPropertiesSchema,
|
||||
autoCommonParametersSchema,
|
||||
apiKeyCredentialsSchema,
|
||||
logData,
|
||||
confirmButtonText,
|
||||
handleConfirm,
|
||||
handleManualPropertiesChange,
|
||||
handleApiKeyCredentialsChange,
|
||||
} = useCommonModalState({
|
||||
createType,
|
||||
builder,
|
||||
onClose,
|
||||
})
|
||||
|
||||
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>(builder)
|
||||
const isInitializedRef = useRef(false)
|
||||
|
||||
const { mutate: verifyCredentials, isPending: isVerifyingCredentials } = useVerifyAndUpdateTriggerSubscriptionBuilder()
|
||||
const { mutateAsync: createBuilder /* isPending: isCreatingBuilder */ } = useCreateTriggerSubscriptionBuilder()
|
||||
const { mutate: buildSubscription, isPending: isBuilding } = useBuildTriggerSubscription()
|
||||
const { mutate: updateBuilder } = useUpdateTriggerSubscriptionBuilder()
|
||||
|
||||
const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || [] // manual
|
||||
const manualPropertiesFormRef = React.useRef<FormRefObject>(null)
|
||||
|
||||
const subscriptionFormRef = React.useRef<FormRefObject>(null)
|
||||
|
||||
const autoCommonParametersSchema = detail?.declaration.trigger?.subscription_constructor?.parameters || [] // apikey and oauth
|
||||
const autoCommonParametersFormRef = React.useRef<FormRefObject>(null)
|
||||
|
||||
const apiKeyCredentialsSchema = useMemo(() => {
|
||||
const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || []
|
||||
return rawSchema.map(schema => ({
|
||||
...schema,
|
||||
tooltip: schema.help,
|
||||
}))
|
||||
}, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema])
|
||||
const apiKeyCredentialsFormRef = React.useRef<FormRefObject>(null)
|
||||
|
||||
const { data: logData } = useTriggerSubscriptionBuilderLogs(
|
||||
detail?.provider || '',
|
||||
subscriptionBuilder?.id || '',
|
||||
{
|
||||
enabled: createType === SupportedCreationMethods.MANUAL,
|
||||
refetchInterval: 3000,
|
||||
},
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
const initializeBuilder = async () => {
|
||||
isInitializedRef.current = true
|
||||
try {
|
||||
const response = await createBuilder({
|
||||
provider: detail?.provider || '',
|
||||
credential_type: CREDENTIAL_TYPE_MAP[createType],
|
||||
})
|
||||
setSubscriptionBuilder(response.subscription_builder)
|
||||
}
|
||||
catch (error) {
|
||||
console.error('createBuilder error:', error)
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
}
|
||||
}
|
||||
if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider)
|
||||
initializeBuilder()
|
||||
}, [subscriptionBuilder, detail?.provider, createType, createBuilder, t])
|
||||
|
||||
useEffect(() => {
|
||||
if (subscriptionBuilder?.endpoint && subscriptionFormRef.current && currentStep === ApiKeyStep.Configuration) {
|
||||
const form = subscriptionFormRef.current.getForm()
|
||||
if (form)
|
||||
form.setFieldValue('callback_url', subscriptionBuilder.endpoint)
|
||||
if (isPrivateOrLocalAddress(subscriptionBuilder.endpoint)) {
|
||||
console.warn('callback_url is private or local address', subscriptionBuilder.endpoint)
|
||||
subscriptionFormRef.current?.setFields([{
|
||||
name: 'callback_url',
|
||||
warnings: [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })],
|
||||
}])
|
||||
}
|
||||
else {
|
||||
subscriptionFormRef.current?.setFields([{
|
||||
name: 'callback_url',
|
||||
warnings: [],
|
||||
}])
|
||||
}
|
||||
}
|
||||
}, [subscriptionBuilder?.endpoint, currentStep, t])
|
||||
|
||||
const debouncedUpdate = useMemo(
|
||||
() => debounce((provider: string, builderId: string, properties: Record<string, unknown>) => {
|
||||
updateBuilder(
|
||||
{
|
||||
provider,
|
||||
subscriptionBuilderId: builderId,
|
||||
properties,
|
||||
},
|
||||
{
|
||||
onError: async (error: unknown) => {
|
||||
const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' })
|
||||
console.error('Failed to update subscription builder:', error)
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: errorMessage,
|
||||
})
|
||||
},
|
||||
},
|
||||
)
|
||||
}, 500),
|
||||
[updateBuilder, t],
|
||||
)
|
||||
|
||||
const handleManualPropertiesChange = useCallback(() => {
|
||||
if (!subscriptionBuilder || !detail?.provider)
|
||||
return
|
||||
|
||||
const formValues = manualPropertiesFormRef.current?.getFormValues({ needCheckValidatedValues: false }) || { values: {}, isCheckValidated: true }
|
||||
|
||||
debouncedUpdate(detail.provider, subscriptionBuilder.id, formValues.values)
|
||||
}, [subscriptionBuilder, detail?.provider, debouncedUpdate])
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
debouncedUpdate.cancel()
|
||||
}
|
||||
}, [debouncedUpdate])
|
||||
|
||||
const handleVerify = () => {
|
||||
const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || defaultFormValues
|
||||
const credentials = apiKeyCredentialsFormValues.values
|
||||
|
||||
if (!Object.keys(credentials).length) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'Please fill in all required credentials',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyCredentialsFormRef.current?.setFields([{
|
||||
name: Object.keys(credentials)[0],
|
||||
errors: [],
|
||||
}])
|
||||
|
||||
verifyCredentials(
|
||||
{
|
||||
provider: detail?.provider || '',
|
||||
subscriptionBuilderId: subscriptionBuilder?.id || '',
|
||||
credentials,
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
setCurrentStep(ApiKeyStep.Configuration)
|
||||
},
|
||||
onError: async (error: unknown) => {
|
||||
const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' })
|
||||
apiKeyCredentialsFormRef.current?.setFields([{
|
||||
name: Object.keys(credentials)[0],
|
||||
errors: [errorMessage],
|
||||
}])
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
const handleCreate = () => {
|
||||
if (!subscriptionBuilder) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'Subscription builder not found',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({})
|
||||
if (!subscriptionFormValues?.isCheckValidated)
|
||||
return
|
||||
|
||||
const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string
|
||||
|
||||
const params: BuildTriggerSubscriptionPayload = {
|
||||
provider: detail?.provider || '',
|
||||
subscriptionBuilderId: subscriptionBuilder.id,
|
||||
name: subscriptionNameValue,
|
||||
}
|
||||
|
||||
if (createType !== SupportedCreationMethods.MANUAL) {
|
||||
if (autoCommonParametersSchema.length > 0) {
|
||||
const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || defaultFormValues
|
||||
if (!autoCommonParametersFormValues?.isCheckValidated)
|
||||
return
|
||||
params.parameters = autoCommonParametersFormValues.values
|
||||
}
|
||||
}
|
||||
else if (manualPropertiesSchema.length > 0) {
|
||||
const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || defaultFormValues
|
||||
if (!manualFormValues?.isCheckValidated)
|
||||
return
|
||||
}
|
||||
|
||||
buildSubscription(
|
||||
params,
|
||||
{
|
||||
onSuccess: () => {
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('subscription.createSuccess', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
onClose()
|
||||
refetch?.()
|
||||
},
|
||||
onError: async (error: unknown) => {
|
||||
const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' })
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: errorMessage,
|
||||
})
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
const handleConfirm = () => {
|
||||
if (currentStep === ApiKeyStep.Verify)
|
||||
handleVerify()
|
||||
else
|
||||
handleCreate()
|
||||
}
|
||||
|
||||
const handleApiKeyCredentialsChange = () => {
|
||||
apiKeyCredentialsFormRef.current?.setFields([{
|
||||
name: apiKeyCredentialsSchema[0].name,
|
||||
errors: [],
|
||||
}])
|
||||
}
|
||||
|
||||
const confirmButtonText = useMemo(() => {
|
||||
if (currentStep === ApiKeyStep.Verify)
|
||||
return isVerifyingCredentials ? t('modal.common.verifying', { ns: 'pluginTrigger' }) : t('modal.common.verify', { ns: 'pluginTrigger' })
|
||||
|
||||
return isBuilding ? t('modal.common.creating', { ns: 'pluginTrigger' }) : t('modal.common.create', { ns: 'pluginTrigger' })
|
||||
}, [currentStep, isVerifyingCredentials, isBuilding, t])
|
||||
const isApiKeyType = createType === SupportedCreationMethods.APIKEY
|
||||
const isVerifyStep = currentStep === ApiKeyStep.Verify
|
||||
const isConfigurationStep = currentStep === ApiKeyStep.Configuration
|
||||
|
||||
return (
|
||||
<Modal
|
||||
@@ -353,121 +57,36 @@ export const CommonCreateModal = ({ onClose, createType, builder }: Props) => {
|
||||
onCancel={onClose}
|
||||
onConfirm={handleConfirm}
|
||||
disabled={isVerifyingCredentials || isBuilding}
|
||||
bottomSlot={currentStep === ApiKeyStep.Verify ? <EncryptedBottom /> : null}
|
||||
bottomSlot={isVerifyStep ? <EncryptedBottom /> : null}
|
||||
size={createType === SupportedCreationMethods.MANUAL ? 'md' : 'sm'}
|
||||
containerClassName="min-h-[360px]"
|
||||
clickOutsideNotClose
|
||||
>
|
||||
{createType === SupportedCreationMethods.APIKEY && <MultiSteps currentStep={currentStep} />}
|
||||
{currentStep === ApiKeyStep.Verify && (
|
||||
<>
|
||||
{apiKeyCredentialsSchema.length > 0 && (
|
||||
<div className="mb-4">
|
||||
<BaseForm
|
||||
formSchemas={apiKeyCredentialsSchema}
|
||||
ref={apiKeyCredentialsFormRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
preventDefaultSubmit={true}
|
||||
formClassName="space-y-4"
|
||||
onChange={handleApiKeyCredentialsChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{currentStep === ApiKeyStep.Configuration && (
|
||||
<div className="max-h-[70vh]">
|
||||
<BaseForm
|
||||
formSchemas={[
|
||||
{
|
||||
name: 'subscription_name',
|
||||
label: t('modal.form.subscriptionName.label', { ns: 'pluginTrigger' }),
|
||||
placeholder: t('modal.form.subscriptionName.placeholder', { ns: 'pluginTrigger' }),
|
||||
type: FormTypeEnum.textInput,
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: 'callback_url',
|
||||
label: t('modal.form.callbackUrl.label', { ns: 'pluginTrigger' }),
|
||||
placeholder: t('modal.form.callbackUrl.placeholder', { ns: 'pluginTrigger' }),
|
||||
type: FormTypeEnum.textInput,
|
||||
required: false,
|
||||
default: subscriptionBuilder?.endpoint || '',
|
||||
disabled: true,
|
||||
tooltip: t('modal.form.callbackUrl.tooltip', { ns: 'pluginTrigger' }),
|
||||
showCopy: true,
|
||||
},
|
||||
]}
|
||||
ref={subscriptionFormRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
formClassName="space-y-4 mb-4"
|
||||
/>
|
||||
{/* <div className='system-xs-regular mb-6 mt-[-1rem] text-text-tertiary'>
|
||||
{t('pluginTrigger.modal.form.callbackUrl.description')}
|
||||
</div> */}
|
||||
{createType !== SupportedCreationMethods.MANUAL && autoCommonParametersSchema.length > 0 && (
|
||||
<BaseForm
|
||||
formSchemas={autoCommonParametersSchema.map((schema) => {
|
||||
const normalizedType = normalizeFormType(schema.type as FormTypeEnum | string)
|
||||
return {
|
||||
...schema,
|
||||
tooltip: schema.description,
|
||||
type: normalizedType,
|
||||
dynamicSelectParams: normalizedType === FormTypeEnum.dynamicSelect
|
||||
? {
|
||||
plugin_id: detail?.plugin_id || '',
|
||||
provider: detail?.provider || '',
|
||||
action: 'provider',
|
||||
parameter: schema.name,
|
||||
credential_id: subscriptionBuilder?.id || '',
|
||||
}
|
||||
: undefined,
|
||||
fieldClassName: schema.type === FormTypeEnum.boolean ? 'flex items-center justify-between' : undefined,
|
||||
labelClassName: schema.type === FormTypeEnum.boolean ? 'mb-0' : undefined,
|
||||
}
|
||||
})}
|
||||
ref={autoCommonParametersFormRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
formClassName="space-y-4"
|
||||
/>
|
||||
)}
|
||||
{createType === SupportedCreationMethods.MANUAL && (
|
||||
<>
|
||||
{manualPropertiesSchema.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<BaseForm
|
||||
formSchemas={manualPropertiesSchema.map(schema => ({
|
||||
...schema,
|
||||
tooltip: schema.description,
|
||||
}))}
|
||||
ref={manualPropertiesFormRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
formClassName="space-y-4"
|
||||
onChange={handleManualPropertiesChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className="mb-6">
|
||||
<div className="mb-3 flex items-center gap-2">
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">
|
||||
{t('modal.manual.logs.title', { ns: 'pluginTrigger' })}
|
||||
</div>
|
||||
<div className="h-px flex-1 bg-gradient-to-r from-divider-regular to-transparent" />
|
||||
</div>
|
||||
{isApiKeyType && <MultiSteps currentStep={currentStep} />}
|
||||
|
||||
<div className="mb-1 flex items-center justify-center gap-1 rounded-lg bg-background-section p-3">
|
||||
<div className="h-3.5 w-3.5">
|
||||
<RiLoader2Line className="h-full w-full animate-spin" />
|
||||
</div>
|
||||
<div className="system-xs-regular text-text-tertiary">
|
||||
{t('modal.manual.logs.loading', { ns: 'pluginTrigger', pluginName: detail?.name || '' })}
|
||||
</div>
|
||||
</div>
|
||||
<LogViewer logs={logData?.logs || []} />
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
{isVerifyStep && (
|
||||
<VerifyStepContent
|
||||
apiKeyCredentialsSchema={apiKeyCredentialsSchema}
|
||||
apiKeyCredentialsFormRef={formRefs.apiKeyCredentialsFormRef}
|
||||
onChange={handleApiKeyCredentialsChange}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isConfigurationStep && (
|
||||
<ConfigurationStepContent
|
||||
createType={createType}
|
||||
subscriptionBuilder={subscriptionBuilder}
|
||||
subscriptionFormRef={formRefs.subscriptionFormRef}
|
||||
autoCommonParametersSchema={autoCommonParametersSchema}
|
||||
autoCommonParametersFormRef={formRefs.autoCommonParametersFormRef}
|
||||
manualPropertiesSchema={manualPropertiesSchema}
|
||||
manualPropertiesFormRef={formRefs.manualPropertiesFormRef}
|
||||
onManualPropertiesChange={handleManualPropertiesChange}
|
||||
logs={logData?.logs || []}
|
||||
pluginId={detail?.plugin_id || ''}
|
||||
pluginName={detail?.name || ''}
|
||||
provider={detail?.provider || ''}
|
||||
/>
|
||||
)}
|
||||
</Modal>
|
||||
)
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
'use client'
|
||||
import type { FormRefObject, FormSchema } from '@/app/components/base/form/types'
|
||||
import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import { RiLoader2Line } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { BaseForm } from '@/app/components/base/form/components/base'
|
||||
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||
import { SupportedCreationMethods } from '@/app/components/plugins/types'
|
||||
import LogViewer from '../../log-viewer'
|
||||
import { ApiKeyStep } from '../hooks/use-common-modal-state'
|
||||
|
||||
export type SchemaItem = Partial<FormSchema> & Record<string, unknown> & {
|
||||
name: string
|
||||
}
|
||||
|
||||
type StatusStepProps = {
|
||||
isActive: boolean
|
||||
text: string
|
||||
}
|
||||
|
||||
export const StatusStep = ({ isActive, text }: StatusStepProps) => {
|
||||
return (
|
||||
<div className={`system-2xs-semibold-uppercase flex items-center gap-1 ${isActive
|
||||
? 'text-state-accent-solid'
|
||||
: 'text-text-tertiary'}`}
|
||||
>
|
||||
{isActive && (
|
||||
<div className="h-1 w-1 rounded-full bg-state-accent-solid"></div>
|
||||
)}
|
||||
{text}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
type MultiStepsProps = {
|
||||
currentStep: ApiKeyStep
|
||||
}
|
||||
|
||||
export const MultiSteps = ({ currentStep }: MultiStepsProps) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div className="mb-6 flex w-1/3 items-center gap-2">
|
||||
<StatusStep isActive={currentStep === ApiKeyStep.Verify} text={t('modal.steps.verify', { ns: 'pluginTrigger' })} />
|
||||
<div className="h-px w-3 shrink-0 bg-divider-deep"></div>
|
||||
<StatusStep isActive={currentStep === ApiKeyStep.Configuration} text={t('modal.steps.configuration', { ns: 'pluginTrigger' })} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
type VerifyStepContentProps = {
|
||||
apiKeyCredentialsSchema: SchemaItem[]
|
||||
apiKeyCredentialsFormRef: React.RefObject<FormRefObject | null>
|
||||
onChange: () => void
|
||||
}
|
||||
|
||||
export const VerifyStepContent = ({
|
||||
apiKeyCredentialsSchema,
|
||||
apiKeyCredentialsFormRef,
|
||||
onChange,
|
||||
}: VerifyStepContentProps) => {
|
||||
if (!apiKeyCredentialsSchema.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<BaseForm
|
||||
formSchemas={apiKeyCredentialsSchema as FormSchema[]}
|
||||
ref={apiKeyCredentialsFormRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
preventDefaultSubmit={true}
|
||||
formClassName="space-y-4"
|
||||
onChange={onChange}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
type SubscriptionFormProps = {
|
||||
subscriptionFormRef: React.RefObject<FormRefObject | null>
|
||||
endpoint?: string
|
||||
}
|
||||
|
||||
export const SubscriptionForm = ({
|
||||
subscriptionFormRef,
|
||||
endpoint,
|
||||
}: SubscriptionFormProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const formSchemas = React.useMemo(() => [
|
||||
{
|
||||
name: 'subscription_name',
|
||||
label: t('modal.form.subscriptionName.label', { ns: 'pluginTrigger' }),
|
||||
placeholder: t('modal.form.subscriptionName.placeholder', { ns: 'pluginTrigger' }),
|
||||
type: FormTypeEnum.textInput,
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: 'callback_url',
|
||||
label: t('modal.form.callbackUrl.label', { ns: 'pluginTrigger' }),
|
||||
placeholder: t('modal.form.callbackUrl.placeholder', { ns: 'pluginTrigger' }),
|
||||
type: FormTypeEnum.textInput,
|
||||
required: false,
|
||||
default: endpoint || '',
|
||||
disabled: true,
|
||||
tooltip: t('modal.form.callbackUrl.tooltip', { ns: 'pluginTrigger' }),
|
||||
showCopy: true,
|
||||
},
|
||||
], [endpoint, t])
|
||||
|
||||
return (
|
||||
<BaseForm
|
||||
formSchemas={formSchemas}
|
||||
ref={subscriptionFormRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
formClassName="space-y-4 mb-4"
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
const normalizeFormType = (type: FormTypeEnum | string): FormTypeEnum => {
|
||||
if (Object.values(FormTypeEnum).includes(type as FormTypeEnum))
|
||||
return type as FormTypeEnum
|
||||
|
||||
const TYPE_MAP: Record<string, FormTypeEnum> = {
|
||||
string: FormTypeEnum.textInput,
|
||||
text: FormTypeEnum.textInput,
|
||||
password: FormTypeEnum.secretInput,
|
||||
secret: FormTypeEnum.secretInput,
|
||||
number: FormTypeEnum.textNumber,
|
||||
integer: FormTypeEnum.textNumber,
|
||||
boolean: FormTypeEnum.boolean,
|
||||
}
|
||||
|
||||
return TYPE_MAP[type] || FormTypeEnum.textInput
|
||||
}
|
||||
|
||||
type AutoParametersFormProps = {
|
||||
schemas: SchemaItem[]
|
||||
formRef: React.RefObject<FormRefObject | null>
|
||||
pluginId: string
|
||||
provider: string
|
||||
credentialId: string
|
||||
}
|
||||
|
||||
export const AutoParametersForm = ({
|
||||
schemas,
|
||||
formRef,
|
||||
pluginId,
|
||||
provider,
|
||||
credentialId,
|
||||
}: AutoParametersFormProps) => {
|
||||
const formSchemas = React.useMemo(() =>
|
||||
schemas.map((schema) => {
|
||||
const normalizedType = normalizeFormType((schema.type || FormTypeEnum.textInput) as FormTypeEnum | string)
|
||||
return {
|
||||
...schema,
|
||||
tooltip: schema.description,
|
||||
type: normalizedType,
|
||||
dynamicSelectParams: normalizedType === FormTypeEnum.dynamicSelect
|
||||
? {
|
||||
plugin_id: pluginId,
|
||||
provider,
|
||||
action: 'provider',
|
||||
parameter: schema.name,
|
||||
credential_id: credentialId,
|
||||
}
|
||||
: undefined,
|
||||
fieldClassName: normalizedType === FormTypeEnum.boolean ? 'flex items-center justify-between' : undefined,
|
||||
labelClassName: normalizedType === FormTypeEnum.boolean ? 'mb-0' : undefined,
|
||||
}
|
||||
}) as FormSchema[], [schemas, pluginId, provider, credentialId])
|
||||
|
||||
if (!schemas.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<BaseForm
|
||||
formSchemas={formSchemas}
|
||||
ref={formRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
formClassName="space-y-4"
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
type ManualPropertiesSectionProps = {
|
||||
schemas: SchemaItem[]
|
||||
formRef: React.RefObject<FormRefObject | null>
|
||||
onChange: () => void
|
||||
logs: TriggerLogEntity[]
|
||||
pluginName: string
|
||||
}
|
||||
|
||||
export const ManualPropertiesSection = ({
|
||||
schemas,
|
||||
formRef,
|
||||
onChange,
|
||||
logs,
|
||||
pluginName,
|
||||
}: ManualPropertiesSectionProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const formSchemas = React.useMemo(() =>
|
||||
schemas.map(schema => ({
|
||||
...schema,
|
||||
tooltip: schema.description,
|
||||
})) as FormSchema[], [schemas])
|
||||
|
||||
return (
|
||||
<>
|
||||
{schemas.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<BaseForm
|
||||
formSchemas={formSchemas}
|
||||
ref={formRef}
|
||||
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
|
||||
formClassName="space-y-4"
|
||||
onChange={onChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className="mb-6">
|
||||
<div className="mb-3 flex items-center gap-2">
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">
|
||||
{t('modal.manual.logs.title', { ns: 'pluginTrigger' })}
|
||||
</div>
|
||||
<div className="h-px flex-1 bg-gradient-to-r from-divider-regular to-transparent" />
|
||||
</div>
|
||||
|
||||
<div className="mb-1 flex items-center justify-center gap-1 rounded-lg bg-background-section p-3">
|
||||
<div className="h-3.5 w-3.5">
|
||||
<RiLoader2Line className="h-full w-full animate-spin" />
|
||||
</div>
|
||||
<div className="system-xs-regular text-text-tertiary">
|
||||
{t('modal.manual.logs.loading', { ns: 'pluginTrigger', pluginName })}
|
||||
</div>
|
||||
</div>
|
||||
<LogViewer logs={logs} />
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
type ConfigurationStepContentProps = {
|
||||
createType: SupportedCreationMethods
|
||||
subscriptionBuilder?: TriggerSubscriptionBuilder
|
||||
subscriptionFormRef: React.RefObject<FormRefObject | null>
|
||||
autoCommonParametersSchema: SchemaItem[]
|
||||
autoCommonParametersFormRef: React.RefObject<FormRefObject | null>
|
||||
manualPropertiesSchema: SchemaItem[]
|
||||
manualPropertiesFormRef: React.RefObject<FormRefObject | null>
|
||||
onManualPropertiesChange: () => void
|
||||
logs: TriggerLogEntity[]
|
||||
pluginId: string
|
||||
pluginName: string
|
||||
provider: string
|
||||
}
|
||||
|
||||
export const ConfigurationStepContent = ({
|
||||
createType,
|
||||
subscriptionBuilder,
|
||||
subscriptionFormRef,
|
||||
autoCommonParametersSchema,
|
||||
autoCommonParametersFormRef,
|
||||
manualPropertiesSchema,
|
||||
manualPropertiesFormRef,
|
||||
onManualPropertiesChange,
|
||||
logs,
|
||||
pluginId,
|
||||
pluginName,
|
||||
provider,
|
||||
}: ConfigurationStepContentProps) => {
|
||||
const isManualType = createType === SupportedCreationMethods.MANUAL
|
||||
|
||||
return (
|
||||
<div className="max-h-[70vh]">
|
||||
<SubscriptionForm
|
||||
subscriptionFormRef={subscriptionFormRef}
|
||||
endpoint={subscriptionBuilder?.endpoint}
|
||||
/>
|
||||
|
||||
{!isManualType && autoCommonParametersSchema.length > 0 && (
|
||||
<AutoParametersForm
|
||||
schemas={autoCommonParametersSchema}
|
||||
formRef={autoCommonParametersFormRef}
|
||||
pluginId={pluginId}
|
||||
provider={provider}
|
||||
credentialId={subscriptionBuilder?.id || ''}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isManualType && (
|
||||
<ManualPropertiesSection
|
||||
schemas={manualPropertiesSchema}
|
||||
formRef={manualPropertiesFormRef}
|
||||
onChange={onManualPropertiesChange}
|
||||
logs={logs}
|
||||
pluginName={pluginName}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,401 @@
|
||||
'use client'
|
||||
import type { SimpleDetail } from '../../../store'
|
||||
import type { SchemaItem } from '../components/modal-steps'
|
||||
import type { FormRefObject } from '@/app/components/base/form/types'
|
||||
import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers'
|
||||
import { debounce } from 'es-toolkit/compat'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { SupportedCreationMethods } from '@/app/components/plugins/types'
|
||||
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
|
||||
import {
|
||||
useBuildTriggerSubscription,
|
||||
useCreateTriggerSubscriptionBuilder,
|
||||
useTriggerSubscriptionBuilderLogs,
|
||||
useUpdateTriggerSubscriptionBuilder,
|
||||
useVerifyAndUpdateTriggerSubscriptionBuilder,
|
||||
} from '@/service/use-triggers'
|
||||
import { parsePluginErrorMessage } from '@/utils/error-parser'
|
||||
import { isPrivateOrLocalAddress } from '@/utils/urlValidation'
|
||||
import { usePluginStore } from '../../../store'
|
||||
import { useSubscriptionList } from '../../use-subscription-list'
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export enum ApiKeyStep {
|
||||
Verify = 'verify',
|
||||
Configuration = 'configuration',
|
||||
}
|
||||
|
||||
export const CREDENTIAL_TYPE_MAP: Record<SupportedCreationMethods, TriggerCredentialTypeEnum> = {
|
||||
[SupportedCreationMethods.APIKEY]: TriggerCredentialTypeEnum.ApiKey,
|
||||
[SupportedCreationMethods.OAUTH]: TriggerCredentialTypeEnum.Oauth2,
|
||||
[SupportedCreationMethods.MANUAL]: TriggerCredentialTypeEnum.Unauthorized,
|
||||
}
|
||||
|
||||
export const MODAL_TITLE_KEY_MAP: Record<
|
||||
SupportedCreationMethods,
|
||||
'modal.apiKey.title' | 'modal.oauth.title' | 'modal.manual.title'
|
||||
> = {
|
||||
[SupportedCreationMethods.APIKEY]: 'modal.apiKey.title',
|
||||
[SupportedCreationMethods.OAUTH]: 'modal.oauth.title',
|
||||
[SupportedCreationMethods.MANUAL]: 'modal.manual.title',
|
||||
}
|
||||
|
||||
type UseCommonModalStateParams = {
|
||||
createType: SupportedCreationMethods
|
||||
builder?: TriggerSubscriptionBuilder
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
type FormRefs = {
|
||||
manualPropertiesFormRef: React.RefObject<FormRefObject | null>
|
||||
subscriptionFormRef: React.RefObject<FormRefObject | null>
|
||||
autoCommonParametersFormRef: React.RefObject<FormRefObject | null>
|
||||
apiKeyCredentialsFormRef: React.RefObject<FormRefObject | null>
|
||||
}
|
||||
|
||||
type UseCommonModalStateReturn = {
|
||||
// State
|
||||
currentStep: ApiKeyStep
|
||||
subscriptionBuilder: TriggerSubscriptionBuilder | undefined
|
||||
isVerifyingCredentials: boolean
|
||||
isBuilding: boolean
|
||||
|
||||
// Form refs
|
||||
formRefs: FormRefs
|
||||
|
||||
// Computed values
|
||||
detail: SimpleDetail | undefined
|
||||
manualPropertiesSchema: SchemaItem[]
|
||||
autoCommonParametersSchema: SchemaItem[]
|
||||
apiKeyCredentialsSchema: SchemaItem[]
|
||||
logData: { logs: TriggerLogEntity[] } | undefined
|
||||
confirmButtonText: string
|
||||
|
||||
// Handlers
|
||||
handleVerify: () => void
|
||||
handleCreate: () => void
|
||||
handleConfirm: () => void
|
||||
handleManualPropertiesChange: () => void
|
||||
handleApiKeyCredentialsChange: () => void
|
||||
}
|
||||
|
||||
const DEFAULT_FORM_VALUES = { values: {}, isCheckValidated: false }
|
||||
|
||||
// ============================================================================
|
||||
// Hook Implementation
|
||||
// ============================================================================
|
||||
|
||||
export const useCommonModalState = ({
|
||||
createType,
|
||||
builder,
|
||||
onClose,
|
||||
}: UseCommonModalStateParams): UseCommonModalStateReturn => {
|
||||
const { t } = useTranslation()
|
||||
const detail = usePluginStore(state => state.detail)
|
||||
const { refetch } = useSubscriptionList()
|
||||
|
||||
// State
|
||||
const [currentStep, setCurrentStep] = useState<ApiKeyStep>(
|
||||
createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration,
|
||||
)
|
||||
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>(builder)
|
||||
const isInitializedRef = useRef(false)
|
||||
|
||||
// Form refs
|
||||
const manualPropertiesFormRef = useRef<FormRefObject>(null)
|
||||
const subscriptionFormRef = useRef<FormRefObject>(null)
|
||||
const autoCommonParametersFormRef = useRef<FormRefObject>(null)
|
||||
const apiKeyCredentialsFormRef = useRef<FormRefObject>(null)
|
||||
|
||||
// Mutations
|
||||
const { mutate: verifyCredentials, isPending: isVerifyingCredentials } = useVerifyAndUpdateTriggerSubscriptionBuilder()
|
||||
const { mutateAsync: createBuilder } = useCreateTriggerSubscriptionBuilder()
|
||||
const { mutate: buildSubscription, isPending: isBuilding } = useBuildTriggerSubscription()
|
||||
const { mutate: updateBuilder } = useUpdateTriggerSubscriptionBuilder()
|
||||
|
||||
// Schemas
|
||||
const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || []
|
||||
const autoCommonParametersSchema = detail?.declaration.trigger?.subscription_constructor?.parameters || []
|
||||
|
||||
const apiKeyCredentialsSchema = useMemo(() => {
|
||||
const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || []
|
||||
return rawSchema.map(schema => ({
|
||||
...schema,
|
||||
tooltip: schema.help,
|
||||
}))
|
||||
}, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema])
|
||||
|
||||
// Log data for manual mode
|
||||
const { data: logData } = useTriggerSubscriptionBuilderLogs(
|
||||
detail?.provider || '',
|
||||
subscriptionBuilder?.id || '',
|
||||
{
|
||||
enabled: createType === SupportedCreationMethods.MANUAL,
|
||||
refetchInterval: 3000,
|
||||
},
|
||||
)
|
||||
|
||||
// Debounced update for manual properties
|
||||
const debouncedUpdate = useMemo(
|
||||
() => debounce((provider: string, builderId: string, properties: Record<string, unknown>) => {
|
||||
updateBuilder(
|
||||
{
|
||||
provider,
|
||||
subscriptionBuilderId: builderId,
|
||||
properties,
|
||||
},
|
||||
{
|
||||
onError: async (error: unknown) => {
|
||||
const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' })
|
||||
console.error('Failed to update subscription builder:', error)
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: errorMessage,
|
||||
})
|
||||
},
|
||||
},
|
||||
)
|
||||
}, 500),
|
||||
[updateBuilder, t],
|
||||
)
|
||||
|
||||
// Initialize builder
|
||||
useEffect(() => {
|
||||
const initializeBuilder = async () => {
|
||||
isInitializedRef.current = true
|
||||
try {
|
||||
const response = await createBuilder({
|
||||
provider: detail?.provider || '',
|
||||
credential_type: CREDENTIAL_TYPE_MAP[createType],
|
||||
})
|
||||
setSubscriptionBuilder(response.subscription_builder)
|
||||
}
|
||||
catch (error) {
|
||||
console.error('createBuilder error:', error)
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
}
|
||||
}
|
||||
if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider)
|
||||
initializeBuilder()
|
||||
}, [subscriptionBuilder, detail?.provider, createType, createBuilder, t])
|
||||
|
||||
// Cleanup debounced function
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
debouncedUpdate.cancel()
|
||||
}
|
||||
}, [debouncedUpdate])
|
||||
|
||||
// Update endpoint in form when endpoint changes
|
||||
useEffect(() => {
|
||||
if (!subscriptionBuilder?.endpoint || !subscriptionFormRef.current || currentStep !== ApiKeyStep.Configuration)
|
||||
return
|
||||
|
||||
const form = subscriptionFormRef.current.getForm()
|
||||
if (form)
|
||||
form.setFieldValue('callback_url', subscriptionBuilder.endpoint)
|
||||
|
||||
const warnings = isPrivateOrLocalAddress(subscriptionBuilder.endpoint)
|
||||
? [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })]
|
||||
: []
|
||||
|
||||
subscriptionFormRef.current?.setFields([{
|
||||
name: 'callback_url',
|
||||
warnings,
|
||||
}])
|
||||
}, [subscriptionBuilder?.endpoint, currentStep, t])
|
||||
|
||||
// Handle manual properties change
|
||||
const handleManualPropertiesChange = useCallback(() => {
|
||||
if (!subscriptionBuilder || !detail?.provider)
|
||||
return
|
||||
|
||||
const formValues = manualPropertiesFormRef.current?.getFormValues({ needCheckValidatedValues: false })
|
||||
|| { values: {}, isCheckValidated: true }
|
||||
|
||||
debouncedUpdate(detail.provider, subscriptionBuilder.id, formValues.values)
|
||||
}, [subscriptionBuilder, detail?.provider, debouncedUpdate])
|
||||
|
||||
// Handle API key credentials change
|
||||
const handleApiKeyCredentialsChange = useCallback(() => {
|
||||
if (!apiKeyCredentialsSchema.length)
|
||||
return
|
||||
apiKeyCredentialsFormRef.current?.setFields([{
|
||||
name: apiKeyCredentialsSchema[0].name,
|
||||
errors: [],
|
||||
}])
|
||||
}, [apiKeyCredentialsSchema])
|
||||
|
||||
// Handle verify
|
||||
const handleVerify = useCallback(() => {
|
||||
// Guard against uninitialized state
|
||||
if (!detail?.provider || !subscriptionBuilder?.id) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'Subscription builder not initialized',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES
|
||||
const credentials = apiKeyCredentialsFormValues.values
|
||||
|
||||
if (!Object.keys(credentials).length) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'Please fill in all required credentials',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyCredentialsFormRef.current?.setFields([{
|
||||
name: Object.keys(credentials)[0],
|
||||
errors: [],
|
||||
}])
|
||||
|
||||
verifyCredentials(
|
||||
{
|
||||
provider: detail.provider,
|
||||
subscriptionBuilderId: subscriptionBuilder.id,
|
||||
credentials,
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
setCurrentStep(ApiKeyStep.Configuration)
|
||||
},
|
||||
onError: async (error: unknown) => {
|
||||
const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' })
|
||||
apiKeyCredentialsFormRef.current?.setFields([{
|
||||
name: Object.keys(credentials)[0],
|
||||
errors: [errorMessage],
|
||||
}])
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [detail?.provider, subscriptionBuilder?.id, verifyCredentials, t])
|
||||
|
||||
// Handle create
|
||||
const handleCreate = useCallback(() => {
|
||||
if (!subscriptionBuilder) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'Subscription builder not found',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({})
|
||||
if (!subscriptionFormValues?.isCheckValidated)
|
||||
return
|
||||
|
||||
const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string
|
||||
|
||||
const params: BuildTriggerSubscriptionPayload = {
|
||||
provider: detail?.provider || '',
|
||||
subscriptionBuilderId: subscriptionBuilder.id,
|
||||
name: subscriptionNameValue,
|
||||
}
|
||||
|
||||
if (createType !== SupportedCreationMethods.MANUAL) {
|
||||
if (autoCommonParametersSchema.length > 0) {
|
||||
const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES
|
||||
if (!autoCommonParametersFormValues?.isCheckValidated)
|
||||
return
|
||||
params.parameters = autoCommonParametersFormValues.values
|
||||
}
|
||||
}
|
||||
else if (manualPropertiesSchema.length > 0) {
|
||||
const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES
|
||||
if (!manualFormValues?.isCheckValidated)
|
||||
return
|
||||
}
|
||||
|
||||
buildSubscription(
|
||||
params,
|
||||
{
|
||||
onSuccess: () => {
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('subscription.createSuccess', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
onClose()
|
||||
refetch?.()
|
||||
},
|
||||
onError: async (error: unknown) => {
|
||||
const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' })
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: errorMessage,
|
||||
})
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [
|
||||
subscriptionBuilder,
|
||||
detail?.provider,
|
||||
createType,
|
||||
autoCommonParametersSchema.length,
|
||||
manualPropertiesSchema.length,
|
||||
buildSubscription,
|
||||
onClose,
|
||||
refetch,
|
||||
t,
|
||||
])
|
||||
|
||||
// Handle confirm (dispatch based on step)
|
||||
const handleConfirm = useCallback(() => {
|
||||
if (currentStep === ApiKeyStep.Verify)
|
||||
handleVerify()
|
||||
else
|
||||
handleCreate()
|
||||
}, [currentStep, handleVerify, handleCreate])
|
||||
|
||||
// Confirm button text
|
||||
const confirmButtonText = useMemo(() => {
|
||||
if (currentStep === ApiKeyStep.Verify) {
|
||||
return isVerifyingCredentials
|
||||
? t('modal.common.verifying', { ns: 'pluginTrigger' })
|
||||
: t('modal.common.verify', { ns: 'pluginTrigger' })
|
||||
}
|
||||
return isBuilding
|
||||
? t('modal.common.creating', { ns: 'pluginTrigger' })
|
||||
: t('modal.common.create', { ns: 'pluginTrigger' })
|
||||
}, [currentStep, isVerifyingCredentials, isBuilding, t])
|
||||
|
||||
return {
|
||||
currentStep,
|
||||
subscriptionBuilder,
|
||||
isVerifyingCredentials,
|
||||
isBuilding,
|
||||
formRefs: {
|
||||
manualPropertiesFormRef,
|
||||
subscriptionFormRef,
|
||||
autoCommonParametersFormRef,
|
||||
apiKeyCredentialsFormRef,
|
||||
},
|
||||
detail,
|
||||
manualPropertiesSchema,
|
||||
autoCommonParametersSchema,
|
||||
apiKeyCredentialsSchema,
|
||||
logData,
|
||||
confirmButtonText,
|
||||
handleVerify,
|
||||
handleCreate,
|
||||
handleConfirm,
|
||||
handleManualPropertiesChange,
|
||||
handleApiKeyCredentialsChange,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,719 @@
|
||||
import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
|
||||
import {
|
||||
AuthorizationStatusEnum,
|
||||
ClientTypeEnum,
|
||||
getErrorMessage,
|
||||
useOAuthClientState,
|
||||
} from './use-oauth-client-state'
|
||||
|
||||
// ============================================================================
|
||||
// Mock Factory Functions
|
||||
// ============================================================================
|
||||
|
||||
function createMockOAuthConfig(overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig {
|
||||
return {
|
||||
configured: true,
|
||||
custom_configured: false,
|
||||
custom_enabled: false,
|
||||
system_configured: true,
|
||||
redirect_uri: 'https://example.com/oauth/callback',
|
||||
params: {
|
||||
client_id: 'default-client-id',
|
||||
client_secret: 'default-client-secret',
|
||||
},
|
||||
oauth_client_schema: [
|
||||
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: { 'en-US': 'Client ID' } as unknown },
|
||||
{ name: 'client_secret', type: 'secret-input' as unknown, required: true, label: { 'en-US': 'Client Secret' } as unknown },
|
||||
] as TriggerOAuthConfig['oauth_client_schema'],
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
function createMockSubscriptionBuilder(overrides: Partial<TriggerSubscriptionBuilder> = {}): TriggerSubscriptionBuilder {
|
||||
return {
|
||||
id: 'builder-123',
|
||||
name: 'Test Builder',
|
||||
provider: 'test-provider',
|
||||
credential_type: TriggerCredentialTypeEnum.Oauth2,
|
||||
credentials: {},
|
||||
endpoint: 'https://example.com/callback',
|
||||
parameters: {},
|
||||
properties: {},
|
||||
workflows_in_use: 0,
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Setup
|
||||
// ============================================================================
|
||||
|
||||
const mockInitiateOAuth = vi.fn()
|
||||
const mockVerifyBuilder = vi.fn()
|
||||
const mockConfigureOAuth = vi.fn()
|
||||
const mockDeleteOAuth = vi.fn()
|
||||
|
||||
vi.mock('@/service/use-triggers', () => ({
|
||||
useInitiateTriggerOAuth: () => ({
|
||||
mutate: mockInitiateOAuth,
|
||||
}),
|
||||
useVerifyAndUpdateTriggerSubscriptionBuilder: () => ({
|
||||
mutate: mockVerifyBuilder,
|
||||
}),
|
||||
useConfigureTriggerOAuth: () => ({
|
||||
mutate: mockConfigureOAuth,
|
||||
}),
|
||||
useDeleteTriggerOAuth: () => ({
|
||||
mutate: mockDeleteOAuth,
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockOpenOAuthPopup = vi.fn()
|
||||
vi.mock('@/hooks/use-oauth', () => ({
|
||||
openOAuthPopup: (url: string, callback: (data: unknown) => void) => mockOpenOAuthPopup(url, callback),
|
||||
}))
|
||||
|
||||
const mockToastNotify = vi.fn()
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: (params: unknown) => mockToastNotify(params),
|
||||
},
|
||||
}))
|
||||
|
||||
// ============================================================================
|
||||
// Test Suites
|
||||
// ============================================================================
|
||||
|
||||
describe('getErrorMessage', () => {
|
||||
it('should extract message from Error instance', () => {
|
||||
const error = new Error('Test error message')
|
||||
expect(getErrorMessage(error, 'fallback')).toBe('Test error message')
|
||||
})
|
||||
|
||||
it('should extract message from object with message property', () => {
|
||||
const error = { message: 'Object error message' }
|
||||
expect(getErrorMessage(error, 'fallback')).toBe('Object error message')
|
||||
})
|
||||
|
||||
it('should return fallback when error is empty object', () => {
|
||||
expect(getErrorMessage({}, 'fallback')).toBe('fallback')
|
||||
})
|
||||
|
||||
it('should return fallback when error.message is not a string', () => {
|
||||
expect(getErrorMessage({ message: 123 }, 'fallback')).toBe('fallback')
|
||||
})
|
||||
|
||||
it('should return fallback when error.message is empty string', () => {
|
||||
expect(getErrorMessage({ message: '' }, 'fallback')).toBe('fallback')
|
||||
})
|
||||
|
||||
it('should return fallback when error is null', () => {
|
||||
expect(getErrorMessage(null, 'fallback')).toBe('fallback')
|
||||
})
|
||||
|
||||
it('should return fallback when error is undefined', () => {
|
||||
expect(getErrorMessage(undefined, 'fallback')).toBe('fallback')
|
||||
})
|
||||
|
||||
it('should return fallback when error is a primitive', () => {
|
||||
expect(getErrorMessage('string error', 'fallback')).toBe('fallback')
|
||||
expect(getErrorMessage(123, 'fallback')).toBe('fallback')
|
||||
})
|
||||
})
|
||||
|
||||
describe('useOAuthClientState', () => {
|
||||
const defaultParams = {
|
||||
oauthConfig: createMockOAuthConfig(),
|
||||
providerName: 'test-provider',
|
||||
onClose: vi.fn(),
|
||||
showOAuthCreateModal: vi.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Initial State', () => {
|
||||
it('should default to Default client type when system_configured is true', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Default)
|
||||
})
|
||||
|
||||
it('should default to Custom client type when system_configured is false', () => {
|
||||
const config = createMockOAuthConfig({ system_configured: false })
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: config,
|
||||
}))
|
||||
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
|
||||
})
|
||||
|
||||
it('should have undefined authorizationStatus initially', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
expect(result.current.authorizationStatus).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should provide clientFormRef', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
expect(result.current.clientFormRef).toBeDefined()
|
||||
expect(result.current.clientFormRef.current).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('OAuth Client Schema', () => {
|
||||
it('should compute schema with default values from params', () => {
|
||||
const config = createMockOAuthConfig({
|
||||
params: {
|
||||
client_id: 'my-client-id',
|
||||
client_secret: 'my-secret',
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: config,
|
||||
}))
|
||||
|
||||
expect(result.current.oauthClientSchema).toHaveLength(2)
|
||||
expect(result.current.oauthClientSchema[0].default).toBe('my-client-id')
|
||||
expect(result.current.oauthClientSchema[1].default).toBe('my-secret')
|
||||
})
|
||||
|
||||
it('should return empty array when oauth_client_schema is empty', () => {
|
||||
const config = createMockOAuthConfig({
|
||||
oauth_client_schema: [],
|
||||
})
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: config,
|
||||
}))
|
||||
|
||||
expect(result.current.oauthClientSchema).toEqual([])
|
||||
})
|
||||
|
||||
it('should return empty array when params is undefined', () => {
|
||||
const config = createMockOAuthConfig({
|
||||
params: undefined as unknown as TriggerOAuthConfig['params'],
|
||||
})
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: config,
|
||||
}))
|
||||
|
||||
expect(result.current.oauthClientSchema).toEqual([])
|
||||
})
|
||||
|
||||
it('should preserve original schema default when param key not found', () => {
|
||||
const config = createMockOAuthConfig({
|
||||
params: {
|
||||
client_id: 'only-client-id',
|
||||
client_secret: '', // empty
|
||||
},
|
||||
oauth_client_schema: [
|
||||
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: {} as unknown, default: 'original-default' },
|
||||
{ name: 'extra_field', type: 'text-input' as unknown, required: false, label: {} as unknown, default: 'extra-default' },
|
||||
] as TriggerOAuthConfig['oauth_client_schema'],
|
||||
})
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: config,
|
||||
}))
|
||||
|
||||
// client_id should be overridden
|
||||
expect(result.current.oauthClientSchema[0].default).toBe('only-client-id')
|
||||
// extra_field should keep original default since key not in params
|
||||
expect(result.current.oauthClientSchema[1].default).toBe('extra-default')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Confirm Button Text', () => {
|
||||
it('should show saveAndAuth text by default', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
expect(result.current.confirmButtonText).toBe('plugin.auth.saveAndAuth')
|
||||
})
|
||||
|
||||
it('should show authorizing text when status is Pending', async () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation(() => {
|
||||
// Don't resolve - stays pending
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.confirmButtonText).toBe('pluginTrigger.modal.common.authorizing')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('setClientType', () => {
|
||||
it('should update client type when called', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.setClientType(ClientTypeEnum.Custom)
|
||||
})
|
||||
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
|
||||
})
|
||||
|
||||
it('should toggle between client types', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.setClientType(ClientTypeEnum.Custom)
|
||||
})
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
|
||||
|
||||
act(() => {
|
||||
result.current.setClientType(ClientTypeEnum.Default)
|
||||
})
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Default)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleRemove', () => {
|
||||
it('should call deleteOAuth with provider name', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleRemove()
|
||||
})
|
||||
|
||||
expect(mockDeleteOAuth).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should call onClose and show success toast on success', () => {
|
||||
mockDeleteOAuth.mockImplementation((provider, { onSuccess }) => onSuccess())
|
||||
|
||||
const onClose = vi.fn()
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
onClose,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.handleRemove()
|
||||
})
|
||||
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'pluginTrigger.modal.oauth.remove.success',
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error toast with error message on failure', () => {
|
||||
mockDeleteOAuth.mockImplementation((provider, { onError }) => {
|
||||
onError(new Error('Delete failed'))
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleRemove()
|
||||
})
|
||||
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
message: 'Delete failed',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleSave', () => {
|
||||
it('should call configureOAuth with enabled: false for Default type', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(false)
|
||||
})
|
||||
|
||||
expect(mockConfigureOAuth).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: 'test-provider',
|
||||
enabled: false,
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should call configureOAuth with enabled: true for Custom type', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
|
||||
const config = createMockOAuthConfig({ system_configured: false })
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: config,
|
||||
}))
|
||||
|
||||
// Mock the form ref
|
||||
const mockFormRef = {
|
||||
getFormValues: () => ({
|
||||
values: { client_id: 'new-id', client_secret: 'new-secret' },
|
||||
isCheckValidated: true,
|
||||
}),
|
||||
}
|
||||
// @ts-expect-error - mocking ref
|
||||
result.current.clientFormRef.current = mockFormRef
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(false)
|
||||
})
|
||||
|
||||
expect(mockConfigureOAuth).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
enabled: true,
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should show success toast and call onClose when needAuth is false', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
const onClose = vi.fn()
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
onClose,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(false)
|
||||
})
|
||||
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'pluginTrigger.modal.oauth.save.success',
|
||||
})
|
||||
})
|
||||
|
||||
it('should trigger authorization when needAuth is true', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
expect(mockInitiateOAuth).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleAuthorization', () => {
|
||||
it('should set status to Pending and call initiateOAuth', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation(() => {})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Pending)
|
||||
expect(mockInitiateOAuth).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should open OAuth popup on success', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
expect(mockOpenOAuthPopup).toHaveBeenCalledWith(
|
||||
'https://oauth.example.com/authorize',
|
||||
expect.any(Function),
|
||||
)
|
||||
})
|
||||
|
||||
it('should set status to Failed and show error toast on error', () => {
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onError }) => {
|
||||
onError(new Error('OAuth failed'))
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Failed)
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
message: 'pluginTrigger.modal.oauth.authorization.authFailed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onClose and showOAuthCreateModal on callback success', () => {
|
||||
const onClose = vi.fn()
|
||||
const showOAuthCreateModal = vi.fn()
|
||||
const builder = createMockSubscriptionBuilder()
|
||||
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: builder,
|
||||
})
|
||||
})
|
||||
mockOpenOAuthPopup.mockImplementation((url, callback) => {
|
||||
callback({ success: true })
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
onClose,
|
||||
showOAuthCreateModal,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
expect(showOAuthCreateModal).toHaveBeenCalledWith(builder)
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'pluginTrigger.modal.oauth.authorization.authSuccess',
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call callbacks when OAuth callback returns falsy', () => {
|
||||
const onClose = vi.fn()
|
||||
const showOAuthCreateModal = vi.fn()
|
||||
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
mockOpenOAuthPopup.mockImplementation((url, callback) => {
|
||||
callback(null)
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
onClose,
|
||||
showOAuthCreateModal,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
expect(onClose).not.toHaveBeenCalled()
|
||||
expect(showOAuthCreateModal).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Polling Effect', () => {
|
||||
it('should start polling after authorization starts', async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
mockVerifyBuilder.mockImplementation((params, { onSuccess }) => {
|
||||
onSuccess({ verified: false })
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
// Advance timer to trigger first poll
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(3000)
|
||||
})
|
||||
|
||||
expect(mockVerifyBuilder).toHaveBeenCalled()
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should set status to Success when verified', async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
mockVerifyBuilder.mockImplementation((params, { onSuccess }) => {
|
||||
onSuccess({ verified: true })
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(3000)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Success)
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should continue polling on error', async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
mockVerifyBuilder.mockImplementation((params, { onError }) => {
|
||||
onError(new Error('Verify failed'))
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(3000)
|
||||
})
|
||||
|
||||
expect(mockVerifyBuilder).toHaveBeenCalled()
|
||||
// Status should still be Pending
|
||||
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Pending)
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should stop polling when verified', async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
|
||||
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
|
||||
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
|
||||
onSuccess({
|
||||
authorization_url: 'https://oauth.example.com/authorize',
|
||||
subscription_builder: createMockSubscriptionBuilder(),
|
||||
})
|
||||
})
|
||||
mockVerifyBuilder.mockImplementation((params, { onSuccess }) => {
|
||||
onSuccess({ verified: true })
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useOAuthClientState(defaultParams))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSave(true)
|
||||
})
|
||||
|
||||
// First poll - should verify
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(3000)
|
||||
})
|
||||
|
||||
expect(mockVerifyBuilder).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Second poll - should not happen as interval is cleared
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(3000)
|
||||
})
|
||||
|
||||
// Still only 1 call because polling stopped
|
||||
expect(mockVerifyBuilder).toHaveBeenCalledTimes(1)
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined oauthConfig', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
oauthConfig: undefined,
|
||||
}))
|
||||
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
|
||||
expect(result.current.oauthClientSchema).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle empty providerName', () => {
|
||||
const { result } = renderHook(() => useOAuthClientState({
|
||||
...defaultParams,
|
||||
providerName: '',
|
||||
}))
|
||||
|
||||
// Should not throw
|
||||
expect(result.current.clientType).toBe(ClientTypeEnum.Default)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Enum Exports', () => {
|
||||
it('should export AuthorizationStatusEnum', () => {
|
||||
expect(AuthorizationStatusEnum.Pending).toBe('pending')
|
||||
expect(AuthorizationStatusEnum.Success).toBe('success')
|
||||
expect(AuthorizationStatusEnum.Failed).toBe('failed')
|
||||
})
|
||||
|
||||
it('should export ClientTypeEnum', () => {
|
||||
expect(ClientTypeEnum.Default).toBe('default')
|
||||
expect(ClientTypeEnum.Custom).toBe('custom')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,241 @@
|
||||
'use client'
|
||||
import type { FormRefObject } from '@/app/components/base/form/types'
|
||||
import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { openOAuthPopup } from '@/hooks/use-oauth'
|
||||
import {
|
||||
useConfigureTriggerOAuth,
|
||||
useDeleteTriggerOAuth,
|
||||
useInitiateTriggerOAuth,
|
||||
useVerifyAndUpdateTriggerSubscriptionBuilder,
|
||||
} from '@/service/use-triggers'
|
||||
|
||||
export enum AuthorizationStatusEnum {
|
||||
Pending = 'pending',
|
||||
Success = 'success',
|
||||
Failed = 'failed',
|
||||
}
|
||||
|
||||
export enum ClientTypeEnum {
|
||||
Default = 'default',
|
||||
Custom = 'custom',
|
||||
}
|
||||
|
||||
const POLL_INTERVAL_MS = 3000
|
||||
|
||||
// Extract error message from various error formats
|
||||
export const getErrorMessage = (error: unknown, fallback: string): string => {
|
||||
if (error instanceof Error && error.message)
|
||||
return error.message
|
||||
if (typeof error === 'object' && error && 'message' in error) {
|
||||
const message = (error as { message?: string }).message
|
||||
if (typeof message === 'string' && message)
|
||||
return message
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
type UseOAuthClientStateParams = {
|
||||
oauthConfig?: TriggerOAuthConfig
|
||||
providerName: string
|
||||
onClose: () => void
|
||||
showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void
|
||||
}
|
||||
|
||||
type UseOAuthClientStateReturn = {
|
||||
// State
|
||||
clientType: ClientTypeEnum
|
||||
setClientType: (type: ClientTypeEnum) => void
|
||||
authorizationStatus: AuthorizationStatusEnum | undefined
|
||||
|
||||
// Refs
|
||||
clientFormRef: React.RefObject<FormRefObject | null>
|
||||
|
||||
// Computed values
|
||||
oauthClientSchema: TriggerOAuthConfig['oauth_client_schema']
|
||||
confirmButtonText: string
|
||||
|
||||
// Handlers
|
||||
handleAuthorization: () => void
|
||||
handleRemove: () => void
|
||||
handleSave: (needAuth: boolean) => void
|
||||
}
|
||||
|
||||
export const useOAuthClientState = ({
|
||||
oauthConfig,
|
||||
providerName,
|
||||
onClose,
|
||||
showOAuthCreateModal,
|
||||
}: UseOAuthClientStateParams): UseOAuthClientStateReturn => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
// State management
|
||||
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>()
|
||||
const [authorizationStatus, setAuthorizationStatus] = useState<AuthorizationStatusEnum>()
|
||||
const [clientType, setClientType] = useState<ClientTypeEnum>(
|
||||
oauthConfig?.system_configured ? ClientTypeEnum.Default : ClientTypeEnum.Custom,
|
||||
)
|
||||
|
||||
const clientFormRef = useRef<FormRefObject>(null)
|
||||
|
||||
// Mutations
|
||||
const { mutate: initiateOAuth } = useInitiateTriggerOAuth()
|
||||
const { mutate: verifyBuilder } = useVerifyAndUpdateTriggerSubscriptionBuilder()
|
||||
const { mutate: configureOAuth } = useConfigureTriggerOAuth()
|
||||
const { mutate: deleteOAuth } = useDeleteTriggerOAuth()
|
||||
|
||||
// Compute OAuth client schema with default values
|
||||
const oauthClientSchema = useMemo(() => {
|
||||
const { oauth_client_schema, params } = oauthConfig || {}
|
||||
if (!oauth_client_schema?.length || !params)
|
||||
return []
|
||||
|
||||
const paramKeys = Object.keys(params)
|
||||
return oauth_client_schema.map(schema => ({
|
||||
...schema,
|
||||
default: paramKeys.includes(schema.name) ? params[schema.name] : schema.default,
|
||||
}))
|
||||
}, [oauthConfig])
|
||||
|
||||
// Compute confirm button text based on authorization status
|
||||
const confirmButtonText = useMemo(() => {
|
||||
if (authorizationStatus === AuthorizationStatusEnum.Pending)
|
||||
return t('modal.common.authorizing', { ns: 'pluginTrigger' })
|
||||
if (authorizationStatus === AuthorizationStatusEnum.Success)
|
||||
return t('modal.oauth.authorization.waitingJump', { ns: 'pluginTrigger' })
|
||||
return t('auth.saveAndAuth', { ns: 'plugin' })
|
||||
}, [authorizationStatus, t])
|
||||
|
||||
// Authorization handler
|
||||
const handleAuthorization = useCallback(() => {
|
||||
setAuthorizationStatus(AuthorizationStatusEnum.Pending)
|
||||
initiateOAuth(providerName, {
|
||||
onSuccess: (response) => {
|
||||
setSubscriptionBuilder(response.subscription_builder)
|
||||
openOAuthPopup(response.authorization_url, (callbackData) => {
|
||||
if (!callbackData)
|
||||
return
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
onClose()
|
||||
showOAuthCreateModal(response.subscription_builder)
|
||||
})
|
||||
},
|
||||
onError: () => {
|
||||
setAuthorizationStatus(AuthorizationStatusEnum.Failed)
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
},
|
||||
})
|
||||
}, [providerName, initiateOAuth, onClose, showOAuthCreateModal, t])
|
||||
|
||||
// Remove handler
|
||||
const handleRemove = useCallback(() => {
|
||||
deleteOAuth(providerName, {
|
||||
onSuccess: () => {
|
||||
onClose()
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
},
|
||||
onError: (error: unknown) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })),
|
||||
})
|
||||
},
|
||||
})
|
||||
}, [providerName, deleteOAuth, onClose, t])
|
||||
|
||||
// Save handler
|
||||
const handleSave = useCallback((needAuth: boolean) => {
|
||||
const isCustom = clientType === ClientTypeEnum.Custom
|
||||
const params: ConfigureTriggerOAuthPayload = {
|
||||
provider: providerName,
|
||||
enabled: isCustom,
|
||||
}
|
||||
|
||||
if (isCustom && oauthClientSchema?.length) {
|
||||
const clientFormValues = clientFormRef.current?.getFormValues({}) as {
|
||||
values: TriggerOAuthClientParams
|
||||
isCheckValidated: boolean
|
||||
} | undefined
|
||||
// Handle missing ref or form values
|
||||
if (!clientFormValues || !clientFormValues.isCheckValidated)
|
||||
return
|
||||
const clientParams = { ...clientFormValues.values }
|
||||
// Preserve hidden values if unchanged
|
||||
if (clientParams.client_id === oauthConfig?.params.client_id)
|
||||
clientParams.client_id = '[__HIDDEN__]'
|
||||
if (clientParams.client_secret === oauthConfig?.params.client_secret)
|
||||
clientParams.client_secret = '[__HIDDEN__]'
|
||||
params.client_params = clientParams
|
||||
}
|
||||
|
||||
configureOAuth(params, {
|
||||
onSuccess: () => {
|
||||
if (needAuth) {
|
||||
handleAuthorization()
|
||||
return
|
||||
}
|
||||
onClose()
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
},
|
||||
})
|
||||
}, [clientType, providerName, oauthClientSchema, oauthConfig?.params, configureOAuth, handleAuthorization, onClose, t])
|
||||
|
||||
// Polling effect for authorization verification
|
||||
useEffect(() => {
|
||||
const shouldPoll = providerName
|
||||
&& subscriptionBuilder
|
||||
&& authorizationStatus === AuthorizationStatusEnum.Pending
|
||||
|
||||
if (!shouldPoll)
|
||||
return
|
||||
|
||||
const pollInterval = setInterval(() => {
|
||||
verifyBuilder(
|
||||
{
|
||||
provider: providerName,
|
||||
subscriptionBuilderId: subscriptionBuilder.id,
|
||||
},
|
||||
{
|
||||
onSuccess: (response) => {
|
||||
if (response.verified) {
|
||||
setAuthorizationStatus(AuthorizationStatusEnum.Success)
|
||||
clearInterval(pollInterval)
|
||||
}
|
||||
},
|
||||
onError: () => {
|
||||
// Continue polling on error - auth might still be in progress
|
||||
},
|
||||
},
|
||||
)
|
||||
}, POLL_INTERVAL_MS)
|
||||
|
||||
return () => clearInterval(pollInterval)
|
||||
}, [subscriptionBuilder, authorizationStatus, verifyBuilder, providerName])
|
||||
|
||||
return {
|
||||
clientType,
|
||||
setClientType,
|
||||
authorizationStatus,
|
||||
clientFormRef,
|
||||
oauthClientSchema,
|
||||
confirmButtonText,
|
||||
handleAuthorization,
|
||||
handleRemove,
|
||||
handleSave,
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,6 @@ import { SupportedCreationMethods } from '@/app/components/plugins/types'
|
||||
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
|
||||
import { CreateButtonType, CreateSubscriptionButton, DEFAULT_METHOD } from './index'
|
||||
|
||||
// ==================== Mock Setup ====================
|
||||
|
||||
// Mock shared state for portal
|
||||
let mockPortalOpenState = false
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
@@ -36,21 +33,18 @@ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock Toast
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock zustand store
|
||||
let mockStoreDetail: SimpleDetail | undefined
|
||||
vi.mock('../../store', () => ({
|
||||
usePluginStore: (selector: (state: { detail: SimpleDetail | undefined }) => SimpleDetail | undefined) =>
|
||||
selector({ detail: mockStoreDetail }),
|
||||
}))
|
||||
|
||||
// Mock subscription list hook
|
||||
const mockSubscriptions: TriggerSubscription[] = []
|
||||
const mockRefetch = vi.fn()
|
||||
vi.mock('../use-subscription-list', () => ({
|
||||
@@ -60,7 +54,6 @@ vi.mock('../use-subscription-list', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock trigger service hooks
|
||||
let mockProviderInfo: { data: TriggerProviderApiEntity | undefined } = { data: undefined }
|
||||
let mockOAuthConfig: { data: TriggerOAuthConfig | undefined, refetch: () => void } = { data: undefined, refetch: vi.fn() }
|
||||
const mockInitiateOAuth = vi.fn()
|
||||
@@ -73,14 +66,12 @@ vi.mock('@/service/use-triggers', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock OAuth popup
|
||||
vi.mock('@/hooks/use-oauth', () => ({
|
||||
openOAuthPopup: vi.fn((url: string, callback: (data?: unknown) => void) => {
|
||||
callback({ success: true, subscriptionId: 'test-subscription' })
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock child modals
|
||||
vi.mock('./common-modal', () => ({
|
||||
CommonCreateModal: ({ createType, onClose, builder }: {
|
||||
createType: SupportedCreationMethods
|
||||
@@ -128,7 +119,6 @@ vi.mock('./oauth-client', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// Mock CustomSelect
|
||||
vi.mock('@/app/components/base/select/custom', () => ({
|
||||
default: ({ options, value, onChange, CustomTrigger, CustomOption, containerProps }: {
|
||||
options: Array<{ value: string, label: string, show: boolean, extra?: React.ReactNode, tag?: React.ReactNode }>
|
||||
@@ -160,11 +150,6 @@ vi.mock('@/app/components/base/select/custom', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// ==================== Test Utilities ====================
|
||||
|
||||
/**
|
||||
* Factory function to create a TriggerProviderApiEntity with defaults
|
||||
*/
|
||||
const createProviderInfo = (overrides: Partial<TriggerProviderApiEntity> = {}): TriggerProviderApiEntity => ({
|
||||
author: 'test-author',
|
||||
name: 'test-provider',
|
||||
@@ -179,9 +164,6 @@ const createProviderInfo = (overrides: Partial<TriggerProviderApiEntity> = {}):
|
||||
...overrides,
|
||||
})
|
||||
|
||||
/**
|
||||
* Factory function to create a TriggerOAuthConfig with defaults
|
||||
*/
|
||||
const createOAuthConfig = (overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig => ({
|
||||
configured: false,
|
||||
custom_configured: false,
|
||||
@@ -196,9 +178,6 @@ const createOAuthConfig = (overrides: Partial<TriggerOAuthConfig> = {}): Trigger
|
||||
...overrides,
|
||||
})
|
||||
|
||||
/**
|
||||
* Factory function to create a SimpleDetail with defaults
|
||||
*/
|
||||
const createStoreDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail => ({
|
||||
plugin_id: 'test-plugin',
|
||||
name: 'Test Plugin',
|
||||
@@ -209,9 +188,6 @@ const createStoreDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail
|
||||
...overrides,
|
||||
})
|
||||
|
||||
/**
|
||||
* Factory function to create a TriggerSubscription with defaults
|
||||
*/
|
||||
const createSubscription = (overrides: Partial<TriggerSubscription> = {}): TriggerSubscription => ({
|
||||
id: 'test-subscription',
|
||||
name: 'Test Subscription',
|
||||
@@ -225,16 +201,10 @@ const createSubscription = (overrides: Partial<TriggerSubscription> = {}): Trigg
|
||||
...overrides,
|
||||
})
|
||||
|
||||
/**
|
||||
* Factory function to create default props
|
||||
*/
|
||||
const createDefaultProps = (overrides: Partial<Parameters<typeof CreateSubscriptionButton>[0]> = {}) => ({
|
||||
...overrides,
|
||||
})
|
||||
|
||||
/**
|
||||
* Helper to set up mock data for testing
|
||||
*/
|
||||
const setupMocks = (config: {
|
||||
providerInfo?: TriggerProviderApiEntity
|
||||
oauthConfig?: TriggerOAuthConfig
|
||||
@@ -249,8 +219,6 @@ const setupMocks = (config: {
|
||||
mockSubscriptions.push(...config.subscriptions)
|
||||
}
|
||||
|
||||
// ==================== Tests ====================
|
||||
|
||||
describe('CreateSubscriptionButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -258,7 +226,6 @@ describe('CreateSubscriptionButton', () => {
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
// ==================== Rendering Tests ====================
|
||||
describe('Rendering', () => {
|
||||
it('should render null when supportedMethods is empty', () => {
|
||||
// Arrange
|
||||
@@ -322,7 +289,6 @@ describe('CreateSubscriptionButton', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== Props Testing ====================
|
||||
describe('Props', () => {
|
||||
it('should apply default buttonType as FULL_BUTTON', () => {
|
||||
// Arrange
|
||||
@@ -355,7 +321,6 @@ describe('CreateSubscriptionButton', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== State Management ====================
|
||||
describe('State Management', () => {
|
||||
it('should show CommonCreateModal when selectedCreateInfo is set', async () => {
|
||||
// Arrange
|
||||
@@ -474,7 +439,6 @@ describe('CreateSubscriptionButton', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== Memoization Logic ====================
|
||||
describe('Memoization - buttonTextMap', () => {
|
||||
it('should display correct button text for OAUTH method', () => {
|
||||
// Arrange
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { Option } from '@/app/components/base/select/custom'
|
||||
import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import { RiAddLine, RiEqualizer2Line } from '@remixicon/react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ActionButton, ActionButtonState } from '@/app/components/base/action-button'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
@@ -18,11 +18,7 @@ import { usePluginStore } from '../../store'
|
||||
import { useSubscriptionList } from '../use-subscription-list'
|
||||
import { CommonCreateModal } from './common-modal'
|
||||
import { OAuthClientSettingsModal } from './oauth-client'
|
||||
|
||||
export enum CreateButtonType {
|
||||
FULL_BUTTON = 'full-button',
|
||||
ICON_BUTTON = 'icon-button',
|
||||
}
|
||||
import { CreateButtonType, DEFAULT_METHOD } from './types'
|
||||
|
||||
type Props = {
|
||||
className?: string
|
||||
@@ -32,8 +28,6 @@ type Props = {
|
||||
|
||||
const MAX_COUNT = 10
|
||||
|
||||
export const DEFAULT_METHOD = 'default'
|
||||
|
||||
export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BUTTON, shape = 'square' }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { subscriptions } = useSubscriptionList()
|
||||
@@ -43,7 +37,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
|
||||
const detail = usePluginStore(state => state.detail)
|
||||
|
||||
const { data: providerInfo } = useTriggerProviderInfo(detail?.provider || '')
|
||||
const supportedMethods = providerInfo?.supported_creation_methods || []
|
||||
const supportedMethods = useMemo(() => providerInfo?.supported_creation_methods || [], [providerInfo?.supported_creation_methods])
|
||||
const { data: oauthConfig, refetch: refetchOAuthConfig } = useTriggerOAuthConfig(detail?.provider || '', supportedMethods.includes(SupportedCreationMethods.OAUTH))
|
||||
const { mutate: initiateOAuth } = useInitiateTriggerOAuth()
|
||||
|
||||
@@ -63,11 +57,11 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
|
||||
}
|
||||
}, [t])
|
||||
|
||||
const onClickClientSettings = (e: React.MouseEvent<HTMLDivElement | HTMLButtonElement>) => {
|
||||
const onClickClientSettings = useCallback((e: React.MouseEvent<HTMLDivElement | HTMLButtonElement>) => {
|
||||
e.stopPropagation()
|
||||
e.preventDefault()
|
||||
showClientSettingsModal()
|
||||
}
|
||||
}, [showClientSettingsModal])
|
||||
|
||||
const allOptions = useMemo(() => {
|
||||
const showCustomBadge = oauthConfig?.custom_enabled && oauthConfig?.custom_configured
|
||||
@@ -104,7 +98,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
|
||||
show: supportedMethods.includes(SupportedCreationMethods.MANUAL),
|
||||
},
|
||||
]
|
||||
}, [t, oauthConfig, supportedMethods, methodType])
|
||||
}, [t, oauthConfig, supportedMethods, methodType, onClickClientSettings])
|
||||
|
||||
const onChooseCreateType = async (type: SupportedCreationMethods) => {
|
||||
if (type === SupportedCreationMethods.OAUTH) {
|
||||
@@ -160,7 +154,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
|
||||
<CustomSelect<Option & { show: boolean, extra?: React.ReactNode, tag?: React.ReactNode }>
|
||||
options={allOptions.filter(option => option.show)}
|
||||
value={methodType}
|
||||
onChange={value => onChooseCreateType(value as any)}
|
||||
onChange={value => onChooseCreateType(value as SupportedCreationMethods)}
|
||||
containerProps={{
|
||||
open: (methodType === DEFAULT_METHOD || (methodType === SupportedCreationMethods.OAUTH && supportedMethods.length === 1)) ? undefined : false,
|
||||
placement: 'bottom-start',
|
||||
@@ -254,3 +248,5 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export { CreateButtonType, DEFAULT_METHOD } from './types'
|
||||
|
||||
@@ -3,24 +3,14 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
|
||||
|
||||
// Import after mocks
|
||||
import { OAuthClientSettingsModal } from './oauth-client'
|
||||
|
||||
// ============================================================================
|
||||
// Type Definitions
|
||||
// ============================================================================
|
||||
|
||||
type PluginDetail = {
|
||||
plugin_id: string
|
||||
provider: string
|
||||
name: string
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Factory Functions
|
||||
// ============================================================================
|
||||
|
||||
function createMockOAuthConfig(overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig {
|
||||
return {
|
||||
configured: true,
|
||||
@@ -64,18 +54,12 @@ function createMockSubscriptionBuilder(overrides: Partial<TriggerSubscriptionBui
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Setup
|
||||
// ============================================================================
|
||||
|
||||
// Mock plugin store
|
||||
const mockPluginDetail = createMockPluginDetail()
|
||||
const mockUsePluginStore = vi.fn(() => mockPluginDetail)
|
||||
vi.mock('../../store', () => ({
|
||||
usePluginStore: () => mockUsePluginStore(),
|
||||
}))
|
||||
|
||||
// Mock service hooks
|
||||
const mockInitiateOAuth = vi.fn()
|
||||
const mockVerifyBuilder = vi.fn()
|
||||
const mockConfigureOAuth = vi.fn()
|
||||
@@ -96,13 +80,11 @@ vi.mock('@/service/use-triggers', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock OAuth popup
|
||||
const mockOpenOAuthPopup = vi.fn()
|
||||
vi.mock('@/hooks/use-oauth', () => ({
|
||||
openOAuthPopup: (url: string, callback: (data: unknown) => void) => mockOpenOAuthPopup(url, callback),
|
||||
}))
|
||||
|
||||
// Mock toast
|
||||
const mockToastNotify = vi.fn()
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
@@ -110,7 +92,6 @@ vi.mock('@/app/components/base/toast', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock clipboard API
|
||||
const mockClipboardWriteText = vi.fn()
|
||||
Object.assign(navigator, {
|
||||
clipboard: {
|
||||
@@ -118,7 +99,6 @@ Object.assign(navigator, {
|
||||
},
|
||||
})
|
||||
|
||||
// Mock Modal component
|
||||
vi.mock('@/app/components/base/modal/modal', () => ({
|
||||
default: ({
|
||||
children,
|
||||
@@ -161,24 +141,6 @@ vi.mock('@/app/components/base/modal/modal', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// Mock Button component
|
||||
vi.mock('@/app/components/base/button', () => ({
|
||||
default: ({ children, onClick, variant, className }: {
|
||||
children: React.ReactNode
|
||||
onClick?: () => void
|
||||
variant?: string
|
||||
className?: string
|
||||
}) => (
|
||||
<button
|
||||
data-testid={`button-${variant || 'default'}`}
|
||||
onClick={onClick}
|
||||
className={className}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
// Configurable form mock values
|
||||
let mockFormValues: { values: Record<string, string>, isCheckValidated: boolean } = {
|
||||
values: { client_id: 'test-client-id', client_secret: 'test-client-secret' },
|
||||
isCheckValidated: true,
|
||||
@@ -210,29 +172,6 @@ vi.mock('@/app/components/base/form/components/base', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock OptionCard component
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/option-card', () => ({
|
||||
default: ({ title, onSelect, selected, className }: {
|
||||
title: string
|
||||
onSelect: () => void
|
||||
selected: boolean
|
||||
className?: string
|
||||
}) => (
|
||||
<div
|
||||
data-testid={`option-card-${title}`}
|
||||
onClick={onSelect}
|
||||
className={`${className} ${selected ? 'selected' : ''}`}
|
||||
data-selected={selected}
|
||||
>
|
||||
{title}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
// ============================================================================
|
||||
// Test Suites
|
||||
// ============================================================================
|
||||
|
||||
describe('OAuthClientSettingsModal', () => {
|
||||
const defaultProps = {
|
||||
oauthConfig: createMockOAuthConfig(),
|
||||
@@ -244,7 +183,6 @@ describe('OAuthClientSettingsModal', () => {
|
||||
vi.clearAllMocks()
|
||||
mockUsePluginStore.mockReturnValue(mockPluginDetail)
|
||||
mockClipboardWriteText.mockResolvedValue(undefined)
|
||||
// Reset form values to default
|
||||
setMockFormValues({
|
||||
values: { client_id: 'test-client-id', client_secret: 'test-client-secret' },
|
||||
isCheckValidated: true,
|
||||
@@ -265,8 +203,8 @@ describe('OAuthClientSettingsModal', () => {
|
||||
it('should render client type selector when system_configured is true', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')).toBeInTheDocument()
|
||||
expect(screen.getByText('pluginTrigger.subscription.addType.options.oauth.default')).toBeInTheDocument()
|
||||
expect(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render client type selector when system_configured is false', () => {
|
||||
@@ -276,7 +214,7 @@ describe('OAuthClientSettingsModal', () => {
|
||||
|
||||
render(<OAuthClientSettingsModal {...defaultProps} oauthConfig={configWithoutSystemConfigured} />)
|
||||
|
||||
expect(screen.queryByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('pluginTrigger.subscription.addType.options.oauth.default')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render redirect URI info when custom client type is selected', () => {
|
||||
@@ -319,29 +257,29 @@ describe('OAuthClientSettingsModal', () => {
|
||||
it('should default to Default client type when system_configured is true', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
const defaultCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')
|
||||
expect(defaultCard).toHaveAttribute('data-selected', 'true')
|
||||
const defaultCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.default').closest('div')
|
||||
expect(defaultCard).toHaveClass('border-[1.5px]')
|
||||
})
|
||||
|
||||
it('should switch to Custom client type when Custom card is clicked', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
|
||||
fireEvent.click(customCard)
|
||||
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')
|
||||
fireEvent.click(customCard!)
|
||||
|
||||
expect(customCard).toHaveAttribute('data-selected', 'true')
|
||||
expect(customCard).toHaveClass('border-[1.5px]')
|
||||
})
|
||||
|
||||
it('should switch back to Default client type when Default card is clicked', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
|
||||
fireEvent.click(customCard)
|
||||
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')
|
||||
fireEvent.click(customCard!)
|
||||
|
||||
const defaultCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')
|
||||
fireEvent.click(defaultCard)
|
||||
const defaultCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.default').closest('div')
|
||||
fireEvent.click(defaultCard!)
|
||||
|
||||
expect(defaultCard).toHaveAttribute('data-selected', 'true')
|
||||
expect(defaultCard).toHaveClass('border-[1.5px]')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -852,8 +790,8 @@ describe('OAuthClientSettingsModal', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
// Switch to custom
|
||||
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
|
||||
fireEvent.click(customCard)
|
||||
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')
|
||||
fireEvent.click(customCard!)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-cancel'))
|
||||
|
||||
@@ -1054,7 +992,7 @@ describe('OAuthClientSettingsModal', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
// Switch to custom type
|
||||
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
|
||||
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!
|
||||
fireEvent.click(customCard)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-cancel'))
|
||||
@@ -1077,7 +1015,7 @@ describe('OAuthClientSettingsModal', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
// Switch to custom type
|
||||
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
|
||||
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-cancel'))
|
||||
|
||||
@@ -1104,7 +1042,7 @@ describe('OAuthClientSettingsModal', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
// Switch to custom type
|
||||
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
|
||||
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-cancel'))
|
||||
|
||||
@@ -1131,7 +1069,7 @@ describe('OAuthClientSettingsModal', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
// Switch to custom type
|
||||
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
|
||||
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-cancel'))
|
||||
|
||||
@@ -1158,7 +1096,7 @@ describe('OAuthClientSettingsModal', () => {
|
||||
render(<OAuthClientSettingsModal {...defaultProps} />)
|
||||
|
||||
// Switch to custom type
|
||||
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
|
||||
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
|
||||
|
||||
fireEvent.click(screen.getByTestId('modal-cancel'))
|
||||
|
||||
|
||||
@@ -1,27 +1,17 @@
|
||||
'use client'
|
||||
import type { FormRefObject } from '@/app/components/base/form/types'
|
||||
import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers'
|
||||
import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
|
||||
import {
|
||||
RiClipboardLine,
|
||||
RiInformation2Fill,
|
||||
} from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { BaseForm } from '@/app/components/base/form/components/base'
|
||||
import Modal from '@/app/components/base/modal/modal'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card'
|
||||
import { openOAuthPopup } from '@/hooks/use-oauth'
|
||||
import {
|
||||
useConfigureTriggerOAuth,
|
||||
useDeleteTriggerOAuth,
|
||||
useInitiateTriggerOAuth,
|
||||
useVerifyAndUpdateTriggerSubscriptionBuilder,
|
||||
} from '@/service/use-triggers'
|
||||
import { usePluginStore } from '../../store'
|
||||
import { ClientTypeEnum, useOAuthClientState } from './hooks/use-oauth-client-state'
|
||||
|
||||
type Props = {
|
||||
oauthConfig?: TriggerOAuthConfig
|
||||
@@ -29,169 +19,38 @@ type Props = {
|
||||
showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void
|
||||
}
|
||||
|
||||
enum AuthorizationStatusEnum {
|
||||
Pending = 'pending',
|
||||
Success = 'success',
|
||||
Failed = 'failed',
|
||||
}
|
||||
|
||||
enum ClientTypeEnum {
|
||||
Default = 'default',
|
||||
Custom = 'custom',
|
||||
}
|
||||
const CLIENT_TYPE_OPTIONS = [ClientTypeEnum.Default, ClientTypeEnum.Custom] as const
|
||||
|
||||
export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreateModal }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const detail = usePluginStore(state => state.detail)
|
||||
const { system_configured, params, oauth_client_schema } = oauthConfig || {}
|
||||
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>()
|
||||
const [authorizationStatus, setAuthorizationStatus] = useState<AuthorizationStatusEnum>()
|
||||
|
||||
const [clientType, setClientType] = useState<ClientTypeEnum>(system_configured ? ClientTypeEnum.Default : ClientTypeEnum.Custom)
|
||||
|
||||
const clientFormRef = React.useRef<FormRefObject>(null)
|
||||
|
||||
const oauthClientSchema = useMemo(() => {
|
||||
if (oauth_client_schema && oauth_client_schema.length > 0 && params) {
|
||||
const oauthConfigPramaKeys = Object.keys(params || {})
|
||||
for (const schema of oauth_client_schema) {
|
||||
if (oauthConfigPramaKeys.includes(schema.name))
|
||||
schema.default = params?.[schema.name]
|
||||
}
|
||||
return oauth_client_schema
|
||||
}
|
||||
return []
|
||||
}, [oauth_client_schema, params])
|
||||
|
||||
const providerName = detail?.provider || ''
|
||||
const { mutate: initiateOAuth } = useInitiateTriggerOAuth()
|
||||
const { mutate: verifyBuilder } = useVerifyAndUpdateTriggerSubscriptionBuilder()
|
||||
const { mutate: configureOAuth } = useConfigureTriggerOAuth()
|
||||
const { mutate: deleteOAuth } = useDeleteTriggerOAuth()
|
||||
|
||||
const confirmButtonText = useMemo(() => {
|
||||
if (authorizationStatus === AuthorizationStatusEnum.Pending)
|
||||
return t('modal.common.authorizing', { ns: 'pluginTrigger' })
|
||||
if (authorizationStatus === AuthorizationStatusEnum.Success)
|
||||
return t('modal.oauth.authorization.waitingJump', { ns: 'pluginTrigger' })
|
||||
return t('auth.saveAndAuth', { ns: 'plugin' })
|
||||
}, [authorizationStatus, t])
|
||||
const {
|
||||
clientType,
|
||||
setClientType,
|
||||
clientFormRef,
|
||||
oauthClientSchema,
|
||||
confirmButtonText,
|
||||
handleRemove,
|
||||
handleSave,
|
||||
} = useOAuthClientState({
|
||||
oauthConfig,
|
||||
providerName,
|
||||
onClose,
|
||||
showOAuthCreateModal,
|
||||
})
|
||||
|
||||
const getErrorMessage = (error: unknown, fallback: string) => {
|
||||
if (error instanceof Error && error.message)
|
||||
return error.message
|
||||
if (typeof error === 'object' && error && 'message' in error) {
|
||||
const message = (error as { message?: string }).message
|
||||
if (typeof message === 'string' && message)
|
||||
return message
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
const isCustomClient = clientType === ClientTypeEnum.Custom
|
||||
const showRemoveButton = oauthConfig?.custom_enabled && oauthConfig?.params && isCustomClient
|
||||
const showRedirectInfo = isCustomClient && oauthConfig?.redirect_uri
|
||||
const showClientForm = isCustomClient && oauthClientSchema.length > 0
|
||||
|
||||
const handleAuthorization = () => {
|
||||
setAuthorizationStatus(AuthorizationStatusEnum.Pending)
|
||||
initiateOAuth(providerName, {
|
||||
onSuccess: (response) => {
|
||||
setSubscriptionBuilder(response.subscription_builder)
|
||||
openOAuthPopup(response.authorization_url, (callbackData) => {
|
||||
if (callbackData) {
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
onClose()
|
||||
showOAuthCreateModal(response.subscription_builder)
|
||||
}
|
||||
})
|
||||
},
|
||||
onError: () => {
|
||||
setAuthorizationStatus(AuthorizationStatusEnum.Failed)
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (providerName && subscriptionBuilder && authorizationStatus === AuthorizationStatusEnum.Pending) {
|
||||
const pollInterval = setInterval(() => {
|
||||
verifyBuilder(
|
||||
{
|
||||
provider: providerName,
|
||||
subscriptionBuilderId: subscriptionBuilder.id,
|
||||
},
|
||||
{
|
||||
onSuccess: (response) => {
|
||||
if (response.verified) {
|
||||
setAuthorizationStatus(AuthorizationStatusEnum.Success)
|
||||
clearInterval(pollInterval)
|
||||
}
|
||||
},
|
||||
onError: () => {
|
||||
// Continue polling - auth might still be in progress
|
||||
},
|
||||
},
|
||||
)
|
||||
}, 3000)
|
||||
|
||||
return () => clearInterval(pollInterval)
|
||||
}
|
||||
}, [subscriptionBuilder, authorizationStatus, verifyBuilder, providerName, t])
|
||||
|
||||
const handleRemove = () => {
|
||||
deleteOAuth(providerName, {
|
||||
onSuccess: () => {
|
||||
onClose()
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
},
|
||||
onError: (error: unknown) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })),
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const handleSave = (needAuth: boolean) => {
|
||||
const isCustom = clientType === ClientTypeEnum.Custom
|
||||
const params: ConfigureTriggerOAuthPayload = {
|
||||
provider: providerName,
|
||||
enabled: isCustom,
|
||||
}
|
||||
|
||||
if (isCustom) {
|
||||
const clientFormValues = clientFormRef.current?.getFormValues({}) as { values: TriggerOAuthClientParams, isCheckValidated: boolean }
|
||||
if (!clientFormValues.isCheckValidated)
|
||||
return
|
||||
const clientParams = clientFormValues.values
|
||||
if (clientParams.client_id === oauthConfig?.params.client_id)
|
||||
clientParams.client_id = '[__HIDDEN__]'
|
||||
|
||||
if (clientParams.client_secret === oauthConfig?.params.client_secret)
|
||||
clientParams.client_secret = '[__HIDDEN__]'
|
||||
|
||||
params.client_params = clientParams
|
||||
}
|
||||
|
||||
configureOAuth(params, {
|
||||
onSuccess: () => {
|
||||
if (needAuth) {
|
||||
handleAuthorization()
|
||||
}
|
||||
else {
|
||||
onClose()
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }),
|
||||
})
|
||||
}
|
||||
},
|
||||
const handleCopyRedirectUri = () => {
|
||||
navigator.clipboard.writeText(oauthConfig?.redirect_uri || '')
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('actionMsg.copySuccessfully', { ns: 'common' }),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -208,25 +67,25 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
|
||||
onClose={onClose}
|
||||
onCancel={() => handleSave(false)}
|
||||
onConfirm={() => handleSave(true)}
|
||||
footerSlot={
|
||||
oauthConfig?.custom_enabled && oauthConfig?.params && clientType === ClientTypeEnum.Custom && (
|
||||
<div className="grow">
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="text-components-button-destructive-secondary-text"
|
||||
// disabled={disabled || doingAction || !editValues}
|
||||
onClick={handleRemove}
|
||||
>
|
||||
{t('operation.remove', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
footerSlot={showRemoveButton && (
|
||||
<div className="grow">
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="text-components-button-destructive-secondary-text"
|
||||
onClick={handleRemove}
|
||||
>
|
||||
{t('operation.remove', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
>
|
||||
<div className="system-sm-medium mb-2 text-text-secondary">{t('subscription.addType.options.oauth.clientTitle', { ns: 'pluginTrigger' })}</div>
|
||||
<div className="system-sm-medium mb-2 text-text-secondary">
|
||||
{t('subscription.addType.options.oauth.clientTitle', { ns: 'pluginTrigger' })}
|
||||
</div>
|
||||
|
||||
{oauthConfig?.system_configured && (
|
||||
<div className="mb-4 flex w-full items-start justify-between gap-2">
|
||||
{[ClientTypeEnum.Default, ClientTypeEnum.Custom].map(option => (
|
||||
{CLIENT_TYPE_OPTIONS.map(option => (
|
||||
<OptionCard
|
||||
key={option}
|
||||
title={t(`subscription.addType.options.oauth.${option}`, { ns: 'pluginTrigger' })}
|
||||
@@ -237,7 +96,8 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{clientType === ClientTypeEnum.Custom && oauthConfig?.redirect_uri && (
|
||||
|
||||
{showRedirectInfo && (
|
||||
<div className="mb-4 flex items-start gap-3 rounded-xl bg-background-section-burn p-4">
|
||||
<div className="rounded-lg border-[0.5px] border-components-card-border bg-components-card-bg p-2 shadow-xs shadow-shadow-shadow-3">
|
||||
<RiInformation2Fill className="h-5 w-5 shrink-0 text-text-accent" />
|
||||
@@ -247,18 +107,12 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
|
||||
{t('modal.oauthRedirectInfo', { ns: 'pluginTrigger' })}
|
||||
</div>
|
||||
<div className="system-sm-medium my-1.5 break-all leading-4">
|
||||
{oauthConfig.redirect_uri}
|
||||
{oauthConfig?.redirect_uri}
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
navigator.clipboard.writeText(oauthConfig.redirect_uri)
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('actionMsg.copySuccessfully', { ns: 'common' }),
|
||||
})
|
||||
}}
|
||||
onClick={handleCopyRedirectUri}
|
||||
>
|
||||
<RiClipboardLine className="mr-1 h-[14px] w-[14px]" />
|
||||
{t('operation.copy', { ns: 'common' })}
|
||||
@@ -266,7 +120,8 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{clientType === ClientTypeEnum.Custom && oauthClientSchema.length > 0 && (
|
||||
|
||||
{showClientForm && (
|
||||
<BaseForm
|
||||
formSchemas={oauthClientSchema}
|
||||
ref={clientFormRef}
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
export enum CreateButtonType {
|
||||
FULL_BUTTON = 'full-button',
|
||||
ICON_BUTTON = 'icon-button',
|
||||
}
|
||||
|
||||
export const DEFAULT_METHOD = 'default'
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
Group,
|
||||
} from '@/app/components/workflow/nodes/_base/components/layout'
|
||||
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import Split from '../_base/components/split'
|
||||
import ChunkStructure from './components/chunk-structure'
|
||||
import EmbeddingModel from './components/embedding-model'
|
||||
@@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
{
|
||||
data.indexing_technique === IndexMethodEnum.QUALIFIED
|
||||
&& [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
|
||||
&& (
|
||||
&& IS_CE_EDITION && (
|
||||
<>
|
||||
<SummaryIndexSetting
|
||||
summaryIndexSetting={data.summary_index_setting}
|
||||
|
||||
@@ -196,19 +196,19 @@ describe('useDocLink', () => {
|
||||
|
||||
const { result } = renderHook(() => useDocLink())
|
||||
const url = result.current('/api-reference/annotations/create-annotation')
|
||||
expect(url).toBe(`${defaultDocBaseUrl}/en/api-reference/annotations/create-annotation`)
|
||||
expect(url).toBe(`${defaultDocBaseUrl}/api-reference/annotations/create-annotation`)
|
||||
})
|
||||
|
||||
it('should keep original path when no translation exists for non-English locale', () => {
|
||||
vi.mocked(useTranslation).mockReturnValue({
|
||||
i18n: { language: 'ja-JP' },
|
||||
i18n: { language: 'zh-Hans' },
|
||||
} as ReturnType<typeof useTranslation>)
|
||||
vi.mocked(getDocLanguage).mockReturnValue('ja')
|
||||
vi.mocked(getDocLanguage).mockReturnValue('zh')
|
||||
|
||||
const { result } = renderHook(() => useDocLink())
|
||||
// This path has no Japanese translation
|
||||
const url = result.current('/api-reference/annotations/create-annotation')
|
||||
expect(url).toBe(`${defaultDocBaseUrl}/ja/api-reference/annotations/create-annotation`)
|
||||
expect(url).toBe(`${defaultDocBaseUrl}/api-reference/标注管理/创建标注`)
|
||||
})
|
||||
|
||||
it('should remove language prefix when translation is applied', () => {
|
||||
|
||||
@@ -35,12 +35,13 @@ export const useDocLink = (baseUrl?: string): ((path?: DocPathWithoutLang, pathM
|
||||
let targetPath = (pathMap) ? pathMap[locale] || pathUrl : pathUrl
|
||||
let languagePrefix = `/${docLanguage}`
|
||||
|
||||
// Translate API reference paths for non-English locales
|
||||
if (targetPath.startsWith('/api-reference/') && docLanguage !== 'en') {
|
||||
const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage as 'zh' | 'ja']
|
||||
if (translatedPath) {
|
||||
targetPath = translatedPath
|
||||
languagePrefix = ''
|
||||
if (targetPath.startsWith('/api-reference/')) {
|
||||
languagePrefix = ''
|
||||
if (docLanguage !== 'en') {
|
||||
const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage]
|
||||
if (translatedPath) {
|
||||
targetPath = translatedPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2445,11 +2445,6 @@
|
||||
"count": 8
|
||||
}
|
||||
},
|
||||
"app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 8
|
||||
}
|
||||
},
|
||||
"app/components/plugins/plugin-detail-panel/datasource-action-list.tsx": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
@@ -2503,14 +2498,6 @@
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/plugins/plugin-detail-panel/subscription-list/delete-confirm.tsx": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { App, AppCategory } from '@/models/explore'
|
||||
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { fetchAppList, fetchBanners, fetchInstalledAppList, getAppAccessModeByAppId, uninstallApp, updatePinStatus } from './explore'
|
||||
import { AppSourceType, fetchAppMeta, fetchAppParams } from './share'
|
||||
@@ -13,8 +14,9 @@ type ExploreAppListData = {
|
||||
}
|
||||
|
||||
export const useExploreAppList = () => {
|
||||
const locale = useLocale()
|
||||
return useQuery<ExploreAppListData>({
|
||||
queryKey: [NAME_SPACE, 'appList'],
|
||||
queryKey: [NAME_SPACE, 'appList', locale],
|
||||
queryFn: async () => {
|
||||
const { categories, recommended_apps } = await fetchAppList()
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user