mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 22:55:13 +00:00
Compare commits
1 Commits
feat/evalu
...
move-core-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d5572027c |
12
.github/workflows/pyrefly-diff.yml
vendored
12
.github/workflows/pyrefly-diff.yml
vendored
@@ -29,26 +29,20 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Prepare diagnostics extractor
|
||||
run: |
|
||||
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
|
||||
|
||||
- name: Run pyrefly on PR branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly check 2>&1 \
|
||||
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
|
||||
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
|
||||
|
||||
- name: Checkout base branch
|
||||
run: git checkout ${{ github.base_ref }}
|
||||
|
||||
- name: Run pyrefly on base branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly check 2>&1 \
|
||||
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
|
||||
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
|
||||
|
||||
- name: Compute diff
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
|
||||
@@ -29,8 +29,6 @@ ignore_imports =
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
@@ -54,6 +52,7 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
@@ -92,7 +91,7 @@ forbidden_modules =
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
core.prompt
|
||||
core.model_runtime.prompt
|
||||
core.provider_manager
|
||||
core.rag
|
||||
core.repositories
|
||||
@@ -108,11 +107,14 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
@@ -125,14 +127,16 @@ ignore_imports =
|
||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.start.entities -> core.app.app_config.entities
|
||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
@@ -144,15 +148,16 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.agent.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.model_runtime.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.question_classifier.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
|
||||
@@ -167,6 +172,7 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
@@ -174,7 +180,7 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
||||
core.workflow.nodes.llm.node -> models.model
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
@@ -184,7 +190,12 @@ ignore_imports =
|
||||
name = Model Runtime Internal Imports
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.model_runtime
|
||||
core.model_runtime.callbacks
|
||||
core.model_runtime.entities
|
||||
core.model_runtime.errors
|
||||
core.model_runtime.model_providers
|
||||
core.model_runtime.schema_validators
|
||||
core.model_runtime.utils
|
||||
forbidden_modules =
|
||||
configs
|
||||
controllers
|
||||
@@ -214,7 +225,7 @@ forbidden_modules =
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
core.prompt
|
||||
core.model_runtime.prompt
|
||||
core.provider_manager
|
||||
core.rag
|
||||
core.repositories
|
||||
|
||||
@@ -1366,32 +1366,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class EvaluationConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for evaluation runtime
|
||||
"""
|
||||
|
||||
EVALUATION_FRAMEWORK: str = Field(
|
||||
description="Evaluation framework to use (ragas/deepeval/none)",
|
||||
default="none",
|
||||
)
|
||||
|
||||
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent evaluation runs per tenant",
|
||||
default=3,
|
||||
)
|
||||
|
||||
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
|
||||
description="Maximum number of rows allowed in an evaluation dataset",
|
||||
default=1000,
|
||||
)
|
||||
|
||||
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
|
||||
description="Timeout in seconds for a single evaluation task",
|
||||
default=3600,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -1404,7 +1378,6 @@ class FeatureConfig(
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
EvaluationConfig,
|
||||
FileAccessConfig,
|
||||
FileUploadConfig,
|
||||
HttpConfig,
|
||||
|
||||
@@ -116,12 +116,6 @@ from .explore import (
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import evaluation controllers
|
||||
from .evaluation import evaluation
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
@@ -135,7 +129,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@@ -173,7 +166,6 @@ __all__ = [
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"evaluation",
|
||||
"extension",
|
||||
"external",
|
||||
"feature",
|
||||
@@ -207,8 +199,6 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippets",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Evaluation controller module
|
||||
@@ -1,605 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import UploadFile
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Valid evaluation target types
|
||||
EVALUATE_TARGET_TYPES = {"app", "snippets"}
|
||||
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
"""Query parameters for version endpoint."""
|
||||
|
||||
version: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
VersionQuery,
|
||||
)
|
||||
|
||||
|
||||
# Response field definitions
|
||||
file_info_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_fields = {
|
||||
"created_at": TimestampField,
|
||||
"created_by": fields.String,
|
||||
"test_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationTestFile",
|
||||
file_info_fields,
|
||||
)
|
||||
),
|
||||
"result_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationResultFile",
|
||||
file_info_fields,
|
||||
),
|
||||
allow_null=True,
|
||||
),
|
||||
"version": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_list_model = console_ns.model(
|
||||
"EvaluationLogList",
|
||||
{
|
||||
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
|
||||
},
|
||||
)
|
||||
|
||||
customized_matrix_fields = {
|
||||
"evaluation_workflow_id": fields.String,
|
||||
"input_fields": fields.Raw,
|
||||
"output_fields": fields.Raw,
|
||||
}
|
||||
|
||||
condition_fields = {
|
||||
"name": fields.List(fields.String),
|
||||
"comparison_operator": fields.String,
|
||||
"value": fields.String,
|
||||
}
|
||||
|
||||
judgement_conditions_fields = {
|
||||
"logical_operator": fields.String,
|
||||
"conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
|
||||
}
|
||||
|
||||
evaluation_detail_fields = {
|
||||
"evaluation_model": fields.String,
|
||||
"evaluation_model_provider": fields.String,
|
||||
"customized_matrix": fields.Nested(
|
||||
console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
"judgement_conditions": fields.Nested(
|
||||
console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
}
|
||||
|
||||
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
|
||||
|
||||
|
||||
def get_evaluation_target(view_func: Callable[P, R]):
|
||||
"""
|
||||
Decorator to resolve polymorphic evaluation target (app or snippet).
|
||||
|
||||
Validates the target_type parameter and fetches the corresponding
|
||||
model (App or CustomizedSnippet) with tenant isolation.
|
||||
"""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
target_type = kwargs.get("evaluate_target_type")
|
||||
target_id = kwargs.get("evaluate_target_id")
|
||||
|
||||
if target_type not in EVALUATE_TARGET_TYPES:
|
||||
raise NotFound(f"Invalid evaluation target type: {target_type}")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
target_id = str(target_id)
|
||||
|
||||
# Remove path parameters
|
||||
del kwargs["evaluate_target_type"]
|
||||
del kwargs["evaluate_target_id"]
|
||||
|
||||
target: Union[App, CustomizedSnippet] | None = None
|
||||
|
||||
if target_type == "app":
|
||||
target = (
|
||||
db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
|
||||
)
|
||||
elif target_type == "snippets":
|
||||
target = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not target:
|
||||
raise NotFound(f"{str(target_type)} not found")
|
||||
|
||||
kwargs["target"] = target
|
||||
kwargs["target_type"] = target_type
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_dataset_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(400, "Invalid target type or excluded app mode")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Download evaluation dataset template.
|
||||
|
||||
Generates an XLSX template based on the target's input parameters
|
||||
and streams it directly as a file attachment.
|
||||
"""
|
||||
try:
|
||||
xlsx_content, filename = EvaluationService.generate_dataset_template(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||
class EvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation configuration for the target.
|
||||
|
||||
Returns evaluation configuration including model settings,
|
||||
metrics config, and judgement conditions.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(
|
||||
session, current_tenant_id, target_type, str(target.id)
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"metrics_config": None,
|
||||
"judgement_conditions": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"metrics_config": config.metrics_config_dict,
|
||||
"judgement_conditions": config.judgement_conditions_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation configuration saved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Save evaluation configuration for the target.
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
body = request.get_json(force=True)
|
||||
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"metrics_config": config.metrics_config_dict,
|
||||
"judgement_conditions": config.judgement_conditions_dict,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
|
||||
class EvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation run history for the target.
|
||||
|
||||
Returns a paginated list of evaluation runs.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
|
||||
class EvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Expects JSON body with:
|
||||
- file_id: uploaded dataset file ID
|
||||
- evaluation_model: evaluation model name
|
||||
- evaluation_model_provider: evaluation model provider
|
||||
- default_metrics: list of default metric objects
|
||||
- customized_metrics: customized metrics object (optional)
|
||||
- judgment_config: judgment conditions config (optional)
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
# Validate and parse request body
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
# Load dataset file
|
||||
upload_file = (
|
||||
db.session.query(UploadFile)
|
||||
.filter_by(id=run_request.file_id, tenant_id=current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>"
|
||||
)
|
||||
class EvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""
|
||||
Get evaluation run detail including items.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"run": _serialize_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel"
|
||||
)
|
||||
class EvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""Cancel a running evaluation."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
return _serialize_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
|
||||
class EvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics retrieved")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get available evaluation metrics for the current framework.
|
||||
"""
|
||||
result = {}
|
||||
for category in EvaluationCategory:
|
||||
result[category.value] = EvaluationService.get_supported_metrics(category)
|
||||
return {"metrics": result}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
|
||||
class EvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated successfully")
|
||||
@console_ns.response(404, "Target or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
|
||||
"""
|
||||
Download evaluation test file or result file.
|
||||
|
||||
Looks up the specified file, verifies it belongs to the same tenant,
|
||||
and returns file info and download URL.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
|
||||
class EvaluationVersionApi(Resource):
|
||||
@console_ns.doc("get_evaluation_version_detail")
|
||||
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
|
||||
@console_ns.response(200, "Version details retrieved successfully")
|
||||
@console_ns.response(404, "Target or version not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation target version details.
|
||||
|
||||
Returns the workflow graph for the specified version.
|
||||
"""
|
||||
version = request.args.get("version")
|
||||
|
||||
if not version:
|
||||
return {"message": "version parameter is required"}, 400
|
||||
|
||||
graph = {}
|
||||
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||
graph = target.graph_dict
|
||||
|
||||
return {
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
|
||||
# ---- Serialization Helpers ----
|
||||
|
||||
|
||||
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": run.metrics_summary_dict,
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
environment_variables: list[dict[str, Any]] | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = None
|
||||
input_variables: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
@@ -1,540 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow import workflow_model
|
||||
from controllers.console.app.workflow_run import (
|
||||
workflow_run_detail_model,
|
||||
workflow_run_node_execution_list_model,
|
||||
workflow_run_node_execution_model,
|
||||
workflow_run_pagination_model,
|
||||
)
|
||||
from controllers.console.snippets.payloads import (
|
||||
PublishWorkflowPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
WorkflowRunQuery,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
WorkflowRunQuery,
|
||||
PublishWorkflowPayload,
|
||||
)
|
||||
|
||||
|
||||
class SnippetNotFoundError(Exception):
|
||||
"""Snippet not found error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_snippet(view_func: Callable[P, R]):
|
||||
"""Decorator to fetch and validate snippet access."""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("snippet_id"):
|
||||
raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_id = str(kwargs.get("snippet_id"))
|
||||
del kwargs["snippet_id"]
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=snippet_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
kwargs["snippet"] = snippet
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
class SnippetDraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_workflow")
|
||||
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get draft workflow for snippet."""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
return workflow
|
||||
|
||||
@console_ns.doc("sync_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow synced successfully")
|
||||
@console_ns.response(400, "Hash mismatch")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Sync draft workflow for snippet."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
environment_variables_list = payload.environment_variables or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = payload.conversation_variables or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.sync_draft_workflow(
|
||||
snippet=snippet,
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
input_variables=payload.input_variables,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
class SnippetDraftConfigApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_config")
|
||||
@console_ns.response(200, "Draft config retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get snippet draft workflow configuration limits."""
|
||||
return {
|
||||
"parallel_depth_limit": 3,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
class SnippetPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_published_workflow")
|
||||
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get published workflow for snippet."""
|
||||
if not snippet.is_published:
|
||||
return None
|
||||
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
return workflow
|
||||
|
||||
@console_ns.doc("publish_snippet_workflow")
|
||||
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
@console_ns.response(200, "Workflow published successfully")
|
||||
@console_ns.response(400, "No draft workflow found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Publish snippet workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = SnippetService()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
try:
|
||||
workflow = snippet_service.publish_workflow(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account=current_user,
|
||||
)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
class SnippetDefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_snippet_default_block_configs")
|
||||
@console_ns.response(200, "Default block configs retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get default block configurations for snippet workflow."""
|
||||
snippet_service = SnippetService()
|
||||
return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
class SnippetWorkflowRunsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_runs")
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""List workflow runs for snippet."""
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": query.last_id,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
snippet_service = SnippetService()
|
||||
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
class SnippetWorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_run_detail")
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""Get workflow run detail for snippet."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = SnippetService()
|
||||
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""List node executions for a workflow run."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = SnippetService()
|
||||
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
snippet=snippet,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class SnippetDraftNodeRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_node")
|
||||
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a single node in snippet draft workflow.
|
||||
|
||||
Executes a specific node with provided inputs for single-step debugging.
|
||||
Returns the node execution result including status, outputs, and timing.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
user_inputs = payload.inputs
|
||||
|
||||
# Get draft workflow for file parsing
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
snippet=snippet,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query=payload.query,
|
||||
files=files,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class SnippetDraftNodeLastRunApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
Returns the most recent execution record for the given node,
|
||||
including status, inputs, outputs, and timing information.
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
node_exec = snippet_service.get_snippet_node_last_run(
|
||||
snippet=snippet,
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("Node last run not found")
|
||||
|
||||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow iteration node for snippet.
|
||||
|
||||
Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
Returns an SSE event stream with iteration progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_iteration(
|
||||
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow loop node for snippet.
|
||||
|
||||
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
Returns an SSE event stream with loop progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_loop(
|
||||
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
class SnippetDraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""
|
||||
Run draft workflow for snippet.
|
||||
|
||||
Executes the snippet's draft workflow with the provided inputs
|
||||
and returns an SSE event stream with execution progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class SnippetWorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_snippet_workflow_task")
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
"""
|
||||
Stop a running snippet workflow task.
|
||||
|
||||
Uses both the legacy stop flag mechanism and the graph engine
|
||||
command channel for backward compatibility.
|
||||
"""
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
@@ -1,202 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.snippets.payloads import (
|
||||
CreateSnippetPayload,
|
||||
SnippetListQuery,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import SnippetType
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetListQuery,
|
||||
CreateSnippetPayload,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
|
||||
# Create namespace models for marshaling
|
||||
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets")
|
||||
class CustomizedSnippetsApi(Resource):
|
||||
@console_ns.doc("list_customized_snippets")
|
||||
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
|
||||
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List customized snippets with pagination and search."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query_params = request.args.to_dict()
|
||||
query = SnippetListQuery.model_validate(query_params)
|
||||
|
||||
snippets, total, has_more = SnippetService.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(snippets, snippet_list_fields),
|
||||
"page": query.page,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"has_more": has_more,
|
||||
}, 200
|
||||
|
||||
@console_ns.doc("create_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create a new customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_type = SnippetType(payload.type)
|
||||
except ValueError:
|
||||
snippet_type = SnippetType.NODE
|
||||
|
||||
try:
|
||||
snippet = SnippetService.create_snippet(
|
||||
tenant_id=current_tenant_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
snippet_type=snippet_type,
|
||||
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
|
||||
graph=payload.graph,
|
||||
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
|
||||
account=current_user,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
|
||||
class CustomizedSnippetDetailApi(Resource):
|
||||
@console_ns.doc("get_customized_snippet")
|
||||
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Get customized snippet details."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("update_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, snippet_id: str):
|
||||
"""Update customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "icon_info" in update_data and update_data["icon_info"] is not None:
|
||||
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
snippet = session.merge(snippet)
|
||||
snippet = SnippetService.update_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("delete_customized_snippet")
|
||||
@console_ns.response(204, "Snippet deleted successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, snippet_id: str):
|
||||
"""Delete customized snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.delete_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return "", 204
|
||||
@@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
@@ -32,7 +32,7 @@ from core.model_runtime.entities import (
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
|
||||
@@ -17,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
|
||||
@@ -21,7 +21,7 @@ from core.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.file import file_manager
|
||||
|
||||
@@ -5,7 +5,7 @@ from core.app.app_config.entities import (
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
|
||||
@@ -2,12 +2,12 @@ from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.file import FileUploadConfig
|
||||
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
|
||||
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@@ -90,7 +90,61 @@ class PromptTemplateEntity(BaseModel):
|
||||
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(WorkflowVariableEntity):
|
||||
class VariableEntityType(StrEnum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, v: Any) -> str:
|
||||
return v or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict | None) -> dict | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
@@ -260,7 +314,7 @@ class AppConfig(BaseModel):
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: AppAdditionalFeatures | None = None
|
||||
variables: list[WorkflowVariableEntity] = []
|
||||
variables: list[VariableEntity] = []
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.model_runtime.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
@@ -11,14 +12,13 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
|
||||
@@ -33,10 +33,14 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
|
||||
ChatModelMessage,
|
||||
CompletionModelPromptTemplate,
|
||||
MemoryConfig,
|
||||
)
|
||||
from core.model_runtime.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -27,7 +27,7 @@ from core.app.entities.task_entities import (
|
||||
CompletionAppStreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
|
||||
@@ -1,5 +1 @@
|
||||
"""LLM-related application services."""
|
||||
|
||||
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||
|
||||
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
@@ -52,10 +52,10 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from core.workflow.file.enums import FileTransferMethod
|
||||
@@ -157,7 +157,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@@ -170,7 +170,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@@ -283,7 +283,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
self._task_state.llm_result.message.get_text_content()
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
@@ -397,7 +397,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
||||
|
||||
from .llm_quota import LLMQuotaLayer
|
||||
from .observability import ObservabilityLayer
|
||||
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
|
||||
__all__ = [
|
||||
"LLMQuotaLayer",
|
||||
"ObservabilityLayer",
|
||||
"PersistenceWorkflowInfo",
|
||||
"WorkflowPersistenceLayer",
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
"""
|
||||
LLM quota deduction layer for GraphEngine.
|
||||
|
||||
This layer centralizes model-quota deduction outside node implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class LLMQuotaLayer(GraphEngineLayer):
|
||||
"""Graph layer that applies LLM quota deduction after node execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
_ = event
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
_ = error
|
||||
|
||||
@override
|
||||
def on_node_run_start(self, node: Node) -> None:
|
||||
if self._abort_sent:
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||
|
||||
@override
|
||||
def on_node_run_end(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
return
|
||||
|
||||
try:
|
||||
deduct_llm_quota(
|
||||
tenant_id=node.tenant_id,
|
||||
model_instance=model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
|
||||
except Exception:
|
||||
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
|
||||
|
||||
@staticmethod
|
||||
def _set_stop_event(node: Node) -> None:
|
||||
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
|
||||
try:
|
||||
self.command_channel.send_command(
|
||||
AbortCommand(
|
||||
command_type=CommandType.ABORT,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
self._abort_sent = True
|
||||
except Exception:
|
||||
logger.exception("Failed to send quota abort command")
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case NodeType.LLM:
|
||||
return cast("LLMNode", node).model_instance
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
return cast("ParameterExtractorNode", node).model_instance
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
return cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
@@ -1,8 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
@@ -13,16 +11,14 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, SystemVariableKey
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -33,9 +29,11 @@ from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import PromptMessageMemory
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
@@ -43,34 +41,12 @@ from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
app_id: str,
|
||||
node_data_memory: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory or not conversation_id:
|
||||
return None
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
||||
class DefaultWorkflowCodeExecutor:
|
||||
def execute(
|
||||
self,
|
||||
@@ -245,7 +221,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
@@ -254,12 +229,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
@@ -268,7 +241,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
@@ -323,14 +295,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
return None
|
||||
|
||||
node_memory = MemoryConfig.model_validate(raw_memory_config)
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
conversation_id = (
|
||||
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
|
||||
)
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
return llm_utils.fetch_memory(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
app_id=self.graph_init_params.app_id,
|
||||
node_data_memory=node_memory,
|
||||
model_instance=model_instance,
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
|
||||
|
||||
class BaseEvaluationInstance(ABC):
|
||||
"""Abstract base class for evaluation framework adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate LLM outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate retrieval quality using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate agent outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate workflow outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
"""Return the list of supported metric names for a given evaluation category."""
|
||||
...
|
||||
@@ -1,25 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EvaluationFrameworkEnum(StrEnum):
|
||||
RAGAS = "ragas"
|
||||
DEEPEVAL = "deepeval"
|
||||
CUSTOMIZED = "customized"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class BaseEvaluationConfig(BaseModel):
|
||||
"""Base configuration for evaluation frameworks."""
|
||||
pass
|
||||
|
||||
|
||||
class RagasConfig(BaseEvaluationConfig):
|
||||
"""RAGAS-specific configuration."""
|
||||
pass
|
||||
|
||||
|
||||
class CustomizedEvaluatorConfig(BaseEvaluationConfig):
|
||||
"""Configuration for the customized workflow-based evaluator."""
|
||||
pass
|
||||
@@ -1,94 +0,0 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
|
||||
|
||||
|
||||
class EvaluationCategory(StrEnum):
|
||||
LLM = "llm"
|
||||
RETRIEVAL = "knowledge_retrieval"
|
||||
AGENT = "agent"
|
||||
WORKFLOW = "workflow"
|
||||
RETRIEVAL_TEST = "retrieval_test"
|
||||
|
||||
|
||||
class EvaluationMetric(BaseModel):
|
||||
name: str
|
||||
score: float
|
||||
details: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EvaluationItemInput(BaseModel):
|
||||
index: int
|
||||
inputs: dict[str, Any]
|
||||
expected_output: str | None = None
|
||||
context: list[str] | None = None
|
||||
|
||||
|
||||
class EvaluationItemResult(BaseModel):
|
||||
index: int
|
||||
actual_output: str | None = None
|
||||
metrics: list[EvaluationMetric] = Field(default_factory=list)
|
||||
judgment: JudgmentResult | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
|
||||
@property
|
||||
def overall_score(self) -> float | None:
|
||||
if not self.metrics:
|
||||
return None
|
||||
scores = [m.score for m in self.metrics]
|
||||
return sum(scores) / len(scores)
|
||||
|
||||
|
||||
class NodeInfo(BaseModel):
|
||||
node_id: str
|
||||
type: str
|
||||
title: str
|
||||
|
||||
|
||||
class DefaultMetric(BaseModel):
|
||||
metric: str
|
||||
node_info_list: list[NodeInfo]
|
||||
|
||||
|
||||
class CustomizedMetricOutputField(BaseModel):
|
||||
variable: str
|
||||
value_type: str
|
||||
|
||||
|
||||
class CustomizedMetrics(BaseModel):
|
||||
evaluation_workflow_id: str
|
||||
input_fields: dict[str, str]
|
||||
output_fields: list[CustomizedMetricOutputField]
|
||||
|
||||
|
||||
class EvaluationConfigData(BaseModel):
|
||||
"""Structured data for saving evaluation configuration."""
|
||||
evaluation_model: str = ""
|
||||
evaluation_model_provider: str = ""
|
||||
default_metrics: list[DefaultMetric] = Field(default_factory=list)
|
||||
customized_metrics: CustomizedMetrics | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
|
||||
|
||||
class EvaluationRunRequest(EvaluationConfigData):
|
||||
"""Request body for starting an evaluation run."""
|
||||
file_id: str
|
||||
|
||||
|
||||
class EvaluationRunData(BaseModel):
|
||||
"""Serializable data for Celery task."""
|
||||
evaluation_run_id: str
|
||||
tenant_id: str
|
||||
target_type: str
|
||||
target_id: str
|
||||
evaluation_category: EvaluationCategory
|
||||
evaluation_model_provider: str
|
||||
evaluation_model: str
|
||||
default_metrics: list[dict[str, Any]] = Field(default_factory=list)
|
||||
customized_metrics: dict[str, Any] | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
items: list[EvaluationItemInput]
|
||||
@@ -1,84 +0,0 @@
|
||||
"""Judgment condition entities for evaluation metric assessment.
|
||||
|
||||
Typical usage:
|
||||
judgment_config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(metric_name="faithfulness", comparison_operator=">", value="0.8"),
|
||||
JudgmentCondition(metric_name="answer_relevancy", comparison_operator="≥", value="0.7"),
|
||||
],
|
||||
)
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.utils.condition.entities import SupportedComparisonOperator
|
||||
|
||||
|
||||
class JudgmentCondition(BaseModel):
|
||||
"""A single judgment condition that checks one metric value.
|
||||
|
||||
Attributes:
|
||||
metric_name: The name of the evaluation metric to check
|
||||
(must match an EvaluationMetric.name in the results).
|
||||
comparison_operator: The comparison operator to apply
|
||||
(reuses the same operator set as workflow condition branches).
|
||||
value: The expected/threshold value to compare against.
|
||||
For numeric operators (>, <, =, etc.), this should be a numeric string.
|
||||
For string operators (contains, is, etc.), this should be a string.
|
||||
For unary operators (empty, null, etc.), this can be None.
|
||||
"""
|
||||
|
||||
metric_name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
class JudgmentConfig(BaseModel):
|
||||
"""A group of judgment conditions combined with a logical operator.
|
||||
|
||||
Attributes:
|
||||
logical_operator: How to combine condition results — "and" requires
|
||||
all conditions to pass, "or" requires at least one.
|
||||
conditions: The list of individual conditions to evaluate.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
conditions: list[JudgmentCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class JudgmentConditionResult(BaseModel):
|
||||
"""Result of evaluating a single judgment condition.
|
||||
|
||||
Attributes:
|
||||
metric_name: Which metric was checked.
|
||||
comparison_operator: The operator that was applied.
|
||||
expected_value: The threshold/expected value from the condition config.
|
||||
actual_value: The actual metric value that was evaluated.
|
||||
passed: Whether this individual condition passed.
|
||||
error: Error message if the condition evaluation failed.
|
||||
"""
|
||||
|
||||
metric_name: str
|
||||
comparison_operator: str
|
||||
expected_value: Any = None
|
||||
actual_value: Any = None
|
||||
passed: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class JudgmentResult(BaseModel):
|
||||
"""Overall result of evaluating all judgment conditions for one item.
|
||||
|
||||
Attributes:
|
||||
passed: Whether the overall judgment passed (based on logical_operator).
|
||||
logical_operator: The logical operator used to combine conditions.
|
||||
condition_results: Detailed result for each individual condition.
|
||||
"""
|
||||
|
||||
passed: bool = False
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)
|
||||
@@ -1,69 +0,0 @@
|
||||
import collections
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
|
||||
|
||||
def __getitem__(self, framework: str) -> dict[str, Any]:
|
||||
match framework:
|
||||
case EvaluationFrameworkEnum.RAGAS:
|
||||
from core.evaluation.entities.config_entity import RagasConfig
|
||||
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
|
||||
|
||||
return {
|
||||
"config_class": RagasConfig,
|
||||
"evaluator_class": RagasEvaluator,
|
||||
}
|
||||
case EvaluationFrameworkEnum.DEEPEVAL:
|
||||
raise NotImplementedError("DeepEval adapter is not yet implemented.")
|
||||
case EvaluationFrameworkEnum.CUSTOMIZED:
|
||||
from core.evaluation.entities.config_entity import CustomizedEvaluatorConfig
|
||||
from core.evaluation.frameworks.customized.customized_evaluator import CustomizedEvaluator
|
||||
|
||||
return {
|
||||
"config_class": CustomizedEvaluatorConfig,
|
||||
"evaluator_class": CustomizedEvaluator,
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unknown evaluation framework: {framework}")
|
||||
|
||||
|
||||
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
|
||||
|
||||
|
||||
class EvaluationManager:
|
||||
"""Factory for evaluation instances based on global configuration."""
|
||||
|
||||
@staticmethod
|
||||
def get_evaluation_instance() -> BaseEvaluationInstance | None:
|
||||
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
|
||||
framework = dify_config.EVALUATION_FRAMEWORK
|
||||
if not framework or framework == EvaluationFrameworkEnum.NONE:
|
||||
return None
|
||||
|
||||
try:
|
||||
config_map = evaluation_framework_config_map[framework]
|
||||
evaluator_class = config_map["evaluator_class"]
|
||||
config_class = config_map["config_class"]
|
||||
config = config_class()
|
||||
return evaluator_class(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to create evaluation instance for framework: %s", framework)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
|
||||
"""Return supported metrics for the current framework and given category."""
|
||||
instance = EvaluationManager.get_evaluation_instance()
|
||||
if instance is None:
|
||||
return []
|
||||
return instance.get_supported_metrics(category)
|
||||
@@ -1,267 +0,0 @@
|
||||
"""Customized workflow-based evaluator.
|
||||
|
||||
Uses a published workflow as the evaluation strategy. The target's actual output,
|
||||
expected output, original inputs, and context are passed as workflow inputs.
|
||||
The workflow's output variables are treated as evaluation metrics.
|
||||
|
||||
The evaluation workflow_id is provided per evaluation run via
|
||||
metrics_config["workflow_id"].
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import CustomizedEvaluatorConfig
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizedEvaluator(BaseEvaluationInstance):
|
||||
"""Evaluate using a published workflow."""
|
||||
|
||||
def __init__(self, config: CustomizedEvaluatorConfig):
|
||||
self.config = config
|
||||
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
|
||||
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
|
||||
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
|
||||
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
|
||||
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
"""Metrics are dynamic and defined by the evaluation workflow outputs.
|
||||
|
||||
Return an empty list since available metrics depend on the specific
|
||||
workflow chosen at runtime.
|
||||
"""
|
||||
return []
|
||||
|
||||
def _evaluate_with_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Run the evaluation workflow for each item and extract metric scores.
|
||||
|
||||
Args:
|
||||
items: Evaluation items with inputs, expected_output, and context
|
||||
(context typically contains the target's actual_output, merged
|
||||
by the Runner's evaluate_metrics method).
|
||||
metrics_config: Must contain "workflow_id" pointing to a published
|
||||
WORKFLOW-type App.
|
||||
tenant_id: Tenant scope for database and workflow execution.
|
||||
|
||||
Returns:
|
||||
List of EvaluationItemResult with metrics extracted from workflow outputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If workflow_id is missing from metrics_config or the
|
||||
workflow/app cannot be found.
|
||||
"""
|
||||
workflow_id = metrics_config.get("workflow_id")
|
||||
if not workflow_id:
|
||||
raise ValueError(
|
||||
"metrics_config must contain 'workflow_id' for customized evaluator"
|
||||
)
|
||||
|
||||
app, workflow, service_account = self._load_workflow_resources(workflow_id, tenant_id)
|
||||
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
try:
|
||||
result = self._evaluate_single_item(app, workflow, service_account, item)
|
||||
results.append(result)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Customized evaluator failed for item %d with workflow %s",
|
||||
item.index,
|
||||
workflow_id,
|
||||
)
|
||||
results.append(EvaluationItemResult(index=item.index))
|
||||
return results
|
||||
|
||||
def _evaluate_single_item(
|
||||
self,
|
||||
app: Any,
|
||||
workflow: Any,
|
||||
service_account: Any,
|
||||
item: EvaluationItemInput,
|
||||
) -> EvaluationItemResult:
|
||||
"""Run the evaluation workflow for a single item.
|
||||
|
||||
Builds workflow inputs from the item data and executes the workflow
|
||||
in non-streaming mode. Extracts metrics from the workflow's output
|
||||
variables.
|
||||
"""
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
workflow_inputs = self._build_workflow_inputs(item)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
response: Mapping[str, Any] = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=service_account,
|
||||
args={"inputs": workflow_inputs},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
metrics = self._extract_metrics(response)
|
||||
return EvaluationItemResult(
|
||||
index=item.index,
|
||||
metrics=metrics,
|
||||
metadata={"workflow_response": self._safe_serialize(response)},
|
||||
)
|
||||
|
||||
def _load_workflow_resources(
|
||||
self, workflow_id: str, tenant_id: str
|
||||
) -> tuple[Any, Any, Any]:
|
||||
"""Load the evaluation workflow App, its published workflow, and a service account.
|
||||
|
||||
Args:
|
||||
workflow_id: The App ID of the evaluation workflow.
|
||||
tenant_id: Tenant scope.
|
||||
|
||||
Returns:
|
||||
Tuple of (app, workflow, service_account).
|
||||
|
||||
Raises:
|
||||
ValueError: If the app or published workflow cannot be found.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.evaluation.runners import get_service_account_for_app
|
||||
from models.engine import db
|
||||
from models.model import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
|
||||
if not app:
|
||||
raise ValueError(
|
||||
f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
service_account = get_service_account_for_app(session, workflow_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
published_workflow = workflow_service.get_published_workflow(app_model=app)
|
||||
if not published_workflow:
|
||||
raise ValueError(
|
||||
f"No published workflow found for evaluation app {workflow_id}"
|
||||
)
|
||||
|
||||
return app, published_workflow, service_account
|
||||
|
||||
@staticmethod
|
||||
def _build_workflow_inputs(item: EvaluationItemInput) -> dict[str, Any]:
|
||||
"""Build workflow input dict from an evaluation item.
|
||||
|
||||
Maps evaluation data to conventional workflow input variable names:
|
||||
- actual_output: The target's actual output (from context[0] if available)
|
||||
- expected_output: The expected/reference output
|
||||
- inputs: The original evaluation inputs as JSON string
|
||||
- context: All context strings joined by newlines
|
||||
|
||||
"""
|
||||
workflow_inputs: dict[str, Any] = {}
|
||||
|
||||
# The actual_output is typically the first element in context
|
||||
# (merged by the Runner's evaluate_metrics method)
|
||||
if item.context:
|
||||
workflow_inputs["actual_output"] = item.context[0] if len(item.context) == 1 else "\n\n".join(item.context)
|
||||
|
||||
if item.expected_output:
|
||||
workflow_inputs["expected_output"] = item.expected_output
|
||||
|
||||
if item.inputs:
|
||||
workflow_inputs["inputs"] = json.dumps(item.inputs, ensure_ascii=False)
|
||||
|
||||
if item.context and len(item.context) > 1:
|
||||
workflow_inputs["context"] = "\n\n".join(item.context)
|
||||
|
||||
return workflow_inputs
|
||||
|
||||
@staticmethod
|
||||
def _extract_metrics(response: Mapping[str, Any]) -> list[EvaluationMetric]:
|
||||
"""Extract evaluation metrics from workflow output variables.
|
||||
|
||||
Each output variable is treated as a metric.
|
||||
"""
|
||||
metrics: list[EvaluationMetric] = []
|
||||
|
||||
data = response.get("data", {})
|
||||
if not isinstance(data, Mapping):
|
||||
logger.warning("Unexpected workflow response format: missing 'data' dict")
|
||||
return metrics
|
||||
|
||||
outputs = data.get("outputs", {})
|
||||
if not isinstance(outputs, Mapping):
|
||||
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
|
||||
return metrics
|
||||
|
||||
for key, value in outputs.items():
|
||||
try:
|
||||
score = float(value)
|
||||
metrics.append(EvaluationMetric(name=key, score=score))
|
||||
except (TypeError, ValueError):
|
||||
metrics.append(
|
||||
EvaluationMetric(name=key, score=0.0, details={"raw_value": value})
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _safe_serialize(response: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Safely serialize workflow response for metadata storage."""
|
||||
try:
|
||||
return dict(response)
|
||||
except Exception:
|
||||
return {"raw": str(response)}
|
||||
@@ -1,279 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import RagasConfig
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
)
|
||||
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Metric name mappings per category
|
||||
LLM_METRICS = ["faithfulness", "answer_relevancy", "answer_correctness", "answer_similarity"]
|
||||
RETRIEVAL_METRICS = ["context_precision", "context_recall", "context_relevancy"]
|
||||
AGENT_METRICS = ["tool_call_accuracy", "answer_correctness"]
|
||||
WORKFLOW_METRICS = ["faithfulness", "answer_correctness"]
|
||||
|
||||
|
||||
class RagasEvaluator(BaseEvaluationInstance):
|
||||
"""RAGAS framework adapter for evaluation."""
|
||||
|
||||
def __init__(self, config: RagasConfig):
|
||||
self.config = config
|
||||
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
return LLM_METRICS
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
return RETRIEVAL_METRICS
|
||||
case EvaluationCategory.AGENT:
|
||||
return AGENT_METRICS
|
||||
case EvaluationCategory.WORKFLOW:
|
||||
return WORKFLOW_METRICS
|
||||
case _:
|
||||
return []
|
||||
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
|
||||
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(
|
||||
items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL
|
||||
)
|
||||
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
|
||||
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(
|
||||
items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW
|
||||
)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Core evaluation logic using RAGAS.
|
||||
|
||||
Uses the Dify model wrapper as judge LLM. Falls back to simple
|
||||
string similarity if RAGAS import fails.
|
||||
"""
|
||||
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
|
||||
requested_metrics = metrics_config.get("metrics", self.get_supported_metrics(category))
|
||||
|
||||
try:
|
||||
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
|
||||
except ImportError:
|
||||
logger.warning("RAGAS not installed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
def _evaluate_with_ragas(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using RAGAS library."""
|
||||
from ragas import evaluate as ragas_evaluate
|
||||
from ragas.dataset_schema import EvaluationDataset, SingleTurnSample
|
||||
from ragas.llms import LangchainLLMWrapper
|
||||
from ragas.metrics import (
|
||||
Faithfulness,
|
||||
ResponseRelevancy,
|
||||
)
|
||||
|
||||
# Build RAGAS dataset
|
||||
samples = []
|
||||
for item in items:
|
||||
sample = SingleTurnSample(
|
||||
user_input=self._inputs_to_query(item.inputs),
|
||||
response=item.expected_output or "",
|
||||
retrieved_contexts=item.context or [],
|
||||
)
|
||||
if item.expected_output:
|
||||
sample.reference = item.expected_output
|
||||
samples.append(sample)
|
||||
|
||||
dataset = EvaluationDataset(samples=samples)
|
||||
|
||||
# Build metric instances
|
||||
ragas_metrics = self._build_ragas_metrics(requested_metrics)
|
||||
|
||||
if not ragas_metrics:
|
||||
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
|
||||
return [EvaluationItemResult(index=item.index) for item in items]
|
||||
|
||||
# Run RAGAS evaluation
|
||||
try:
|
||||
result = ragas_evaluate(
|
||||
dataset=dataset,
|
||||
metrics=ragas_metrics,
|
||||
)
|
||||
|
||||
# Convert RAGAS results to our format
|
||||
results = []
|
||||
result_df = result.to_pandas()
|
||||
for i, item in enumerate(items):
|
||||
metrics = []
|
||||
for metric_name in requested_metrics:
|
||||
if metric_name in result_df.columns:
|
||||
score = result_df.iloc[i][metric_name]
|
||||
if score is not None and not (isinstance(score, float) and score != score): # NaN check
|
||||
metrics.append(EvaluationMetric(name=metric_name, score=float(score)))
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
except Exception:
|
||||
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
def _evaluate_simple(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Simple LLM-as-judge fallback when RAGAS is not available."""
|
||||
results = []
|
||||
for item in items:
|
||||
metrics = []
|
||||
query = self._inputs_to_query(item.inputs)
|
||||
|
||||
for metric_name in requested_metrics:
|
||||
try:
|
||||
score = self._judge_with_llm(model_wrapper, metric_name, query, item)
|
||||
metrics.append(EvaluationMetric(name=metric_name, score=score))
|
||||
except Exception:
|
||||
logger.exception("Failed to compute metric %s for item %d", metric_name, item.index)
|
||||
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
|
||||
def _judge_with_llm(
|
||||
self,
|
||||
model_wrapper: DifyModelWrapper,
|
||||
metric_name: str,
|
||||
query: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> float:
|
||||
"""Use the LLM to judge a single metric for a single item."""
|
||||
prompt = self._build_judge_prompt(metric_name, query, item)
|
||||
response = model_wrapper.invoke(prompt)
|
||||
return self._parse_score(response)
|
||||
|
||||
def _build_judge_prompt(self, metric_name: str, query: str, item: EvaluationItemInput) -> str:
|
||||
"""Build a scoring prompt for the LLM judge."""
|
||||
parts = [
|
||||
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
|
||||
f"\nQuery: {query}",
|
||||
]
|
||||
if item.expected_output:
|
||||
parts.append(f"\nExpected Output: {item.expected_output}")
|
||||
if item.context:
|
||||
parts.append(f"\nContext: {'; '.join(item.context)}")
|
||||
parts.append(
|
||||
"\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else."
|
||||
)
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _parse_score(response: str) -> float:
|
||||
"""Parse a float score from LLM response."""
|
||||
cleaned = response.strip()
|
||||
try:
|
||||
score = float(cleaned)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
# Try to extract first number from response
|
||||
import re
|
||||
|
||||
match = re.search(r"(\d+\.?\d*)", cleaned)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _inputs_to_query(inputs: dict[str, Any]) -> str:
|
||||
"""Convert input dict to a query string."""
|
||||
if "query" in inputs:
|
||||
return str(inputs["query"])
|
||||
if "question" in inputs:
|
||||
return str(inputs["question"])
|
||||
# Fallback: concatenate all input values
|
||||
return " ".join(str(v) for v in inputs.values())
|
||||
|
||||
@staticmethod
|
||||
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
|
||||
"""Build RAGAS metric instances from metric names."""
|
||||
try:
|
||||
from ragas.metrics import (
|
||||
AnswerCorrectness,
|
||||
AnswerRelevancy,
|
||||
AnswerSimilarity,
|
||||
ContextPrecision,
|
||||
ContextRecall,
|
||||
ContextRelevancy,
|
||||
Faithfulness,
|
||||
)
|
||||
|
||||
metric_map: dict[str, Any] = {
|
||||
"faithfulness": Faithfulness,
|
||||
"answer_relevancy": AnswerRelevancy,
|
||||
"answer_correctness": AnswerCorrectness,
|
||||
"answer_similarity": AnswerSimilarity,
|
||||
"context_precision": ContextPrecision,
|
||||
"context_recall": ContextRecall,
|
||||
"context_relevancy": ContextRelevancy,
|
||||
}
|
||||
|
||||
metrics = []
|
||||
for name in requested_metrics:
|
||||
metric_class = metric_map.get(name)
|
||||
if metric_class:
|
||||
metrics.append(metric_class())
|
||||
else:
|
||||
logger.warning("Unknown RAGAS metric: %s", name)
|
||||
return metrics
|
||||
except ImportError:
|
||||
logger.warning("RAGAS metrics not available")
|
||||
return []
|
||||
@@ -1,48 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyModelWrapper:
|
||||
"""Wraps Dify's model invocation interface for use by RAGAS as an LLM judge.
|
||||
|
||||
RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.).
|
||||
This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use.
|
||||
"""
|
||||
|
||||
def __init__(self, model_provider: str, model_name: str, tenant_id: str):
|
||||
self.model_provider = model_provider
|
||||
self.model_name = model_name
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def _get_model_instance(self) -> Any:
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.model_provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=self.model_name,
|
||||
)
|
||||
return model_instance
|
||||
|
||||
def invoke(self, prompt: str) -> str:
|
||||
"""Invoke the model with a text prompt and return the text response."""
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
model_instance = self._get_model_instance()
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
|
||||
UserPromptMessage(content=prompt),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 2048},
|
||||
stream=False,
|
||||
)
|
||||
return result.message.content
|
||||
@@ -1,144 +0,0 @@
|
||||
"""Judgment condition processor for evaluation metrics.
|
||||
|
||||
Evaluates pass/fail judgment conditions against evaluation metric values.
|
||||
Reuses the core comparison engine from the workflow condition system
|
||||
(core.workflow.utils.condition.processor._evaluate_condition) to ensure
|
||||
consistent operator semantics across the platform.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.entities.judgment_entity import (
|
||||
JudgmentCondition,
|
||||
JudgmentConditionResult,
|
||||
JudgmentConfig,
|
||||
JudgmentResult,
|
||||
)
|
||||
from core.workflow.utils.condition.processor import _evaluate_condition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JudgmentProcessor:
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
metric_values: dict[str, Any],
|
||||
config: JudgmentConfig,
|
||||
) -> JudgmentResult:
|
||||
"""Evaluate all judgment conditions against the given metric values.
|
||||
|
||||
Args:
|
||||
metric_values: Mapping of metric name to its value
|
||||
(e.g. {"faithfulness": 0.85, "status": "success"}).
|
||||
config: The judgment configuration with logical_operator and conditions.
|
||||
|
||||
Returns:
|
||||
JudgmentResult with overall pass/fail and per-condition details.
|
||||
"""
|
||||
if not config.conditions:
|
||||
return JudgmentResult(
|
||||
passed=True,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=[],
|
||||
)
|
||||
|
||||
condition_results: list[JudgmentConditionResult] = []
|
||||
|
||||
for condition in config.conditions:
|
||||
result = JudgmentProcessor._evaluate_single_condition(
|
||||
metric_values, condition
|
||||
)
|
||||
condition_results.append(result)
|
||||
|
||||
if config.logical_operator == "and" and not result.passed:
|
||||
return JudgmentResult(
|
||||
passed=False,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
if config.logical_operator == "or" and result.passed:
|
||||
return JudgmentResult(
|
||||
passed=True,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
|
||||
if config.logical_operator == "and":
|
||||
final_passed = all(r.passed for r in condition_results)
|
||||
else:
|
||||
final_passed = any(r.passed for r in condition_results)
|
||||
|
||||
return JudgmentResult(
|
||||
passed=final_passed,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_single_condition(
|
||||
metric_values: dict[str, Any],
|
||||
condition: JudgmentCondition,
|
||||
) -> JudgmentConditionResult:
|
||||
"""Evaluate a single judgment condition against the metric values.
|
||||
|
||||
Looks up the metric by name, then delegates to the workflow condition
|
||||
engine for the actual comparison.
|
||||
|
||||
Args:
|
||||
metric_values: Mapping of metric name to its value.
|
||||
condition: The condition to evaluate.
|
||||
|
||||
Returns:
|
||||
JudgmentConditionResult with pass/fail and details.
|
||||
"""
|
||||
metric_name = condition.metric_name
|
||||
actual_value = metric_values.get(metric_name)
|
||||
|
||||
# Handle metric not found
|
||||
if actual_value is None and condition.comparison_operator not in (
|
||||
"null",
|
||||
"not null",
|
||||
"empty",
|
||||
"not empty",
|
||||
"exists",
|
||||
"not exists",
|
||||
):
|
||||
return JudgmentConditionResult(
|
||||
metric_name=metric_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=None,
|
||||
passed=False,
|
||||
error=f"Metric '{metric_name}' not found in evaluation results",
|
||||
)
|
||||
|
||||
try:
|
||||
passed = _evaluate_condition(
|
||||
operator=condition.comparison_operator,
|
||||
value=actual_value,
|
||||
expected=condition.value,
|
||||
)
|
||||
return JudgmentConditionResult(
|
||||
metric_name=metric_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=actual_value,
|
||||
passed=passed,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Judgment condition evaluation failed for metric '%s': %s",
|
||||
metric_name,
|
||||
str(e),
|
||||
)
|
||||
return JudgmentConditionResult(
|
||||
metric_name=metric_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=actual_value,
|
||||
passed=False,
|
||||
error=str(e),
|
||||
)
|
||||
@@ -1,32 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
|
||||
def get_service_account_for_app(session: Session, app_id: str) -> Account:
|
||||
"""Get the creator account for an app with tenant context set up.
|
||||
|
||||
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
|
||||
"""
|
||||
app = session.scalar(select(App).where(App.id == app_id))
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator")
|
||||
|
||||
account = session.scalar(select(Account).where(Account.id == app.created_by))
|
||||
if not account:
|
||||
raise ValueError(f"Creator account not found for app {app_id}")
|
||||
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin)
|
||||
.filter_by(account_id=account.id, current=True)
|
||||
.first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {account.id}")
|
||||
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
return account
|
||||
@@ -1,152 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Mapping, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from models.model import App, AppMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for agent evaluation: executes agent-type App, collects tool calls and final output."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
|
||||
super().__init__(evaluation_instance, session)
|
||||
|
||||
def execute_target(
|
||||
self,
|
||||
tenant_id: str,
|
||||
target_id: str,
|
||||
target_type: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> EvaluationItemResult:
|
||||
"""Execute agent app and collect response with tool call information."""
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.evaluation.runners import get_service_account_for_app
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
app = self.session.query(App).filter_by(id=target_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App {target_id} not found")
|
||||
|
||||
service_account = get_service_account_for_app(self.session, target_id)
|
||||
|
||||
query = self._extract_query(item.inputs)
|
||||
args: dict[str, Any] = {
|
||||
"inputs": item.inputs,
|
||||
"query": query,
|
||||
}
|
||||
|
||||
generator = AgentChatAppGenerator()
|
||||
# Agent chat requires streaming - collect full response
|
||||
response_generator = generator.generate(
|
||||
app_model=app,
|
||||
user=service_account,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Consume the stream to get the full response
|
||||
actual_output, tool_calls = self._consume_agent_stream(response_generator)
|
||||
|
||||
return EvaluationItemResult(
|
||||
index=item.index,
|
||||
actual_output=actual_output,
|
||||
metadata={"tool_calls": tool_calls},
|
||||
)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
results: list[EvaluationItemResult],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute agent evaluation metrics."""
|
||||
result_by_index = {r.index: r for r in results}
|
||||
merged_items = []
|
||||
for item in items:
|
||||
result = result_by_index.get(item.index)
|
||||
context = []
|
||||
if result and result.actual_output:
|
||||
context.append(result.actual_output)
|
||||
merged_items.append(
|
||||
EvaluationItemInput(
|
||||
index=item.index,
|
||||
inputs=item.inputs,
|
||||
expected_output=item.expected_output,
|
||||
context=context + (item.context or []),
|
||||
)
|
||||
)
|
||||
|
||||
evaluated = self.evaluation_instance.evaluate_agent(
|
||||
merged_items, metrics_config, model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
# Merge metrics back preserving metadata
|
||||
eval_by_index = {r.index: r for r in evaluated}
|
||||
final_results = []
|
||||
for result in results:
|
||||
if result.index in eval_by_index:
|
||||
eval_result = eval_by_index[result.index]
|
||||
final_results.append(
|
||||
EvaluationItemResult(
|
||||
index=result.index,
|
||||
actual_output=result.actual_output,
|
||||
metrics=eval_result.metrics,
|
||||
metadata=result.metadata,
|
||||
error=result.error,
|
||||
)
|
||||
)
|
||||
else:
|
||||
final_results.append(result)
|
||||
return final_results
|
||||
|
||||
@staticmethod
|
||||
def _extract_query(inputs: dict[str, Any]) -> str:
|
||||
for key in ("query", "question", "input", "text"):
|
||||
if key in inputs:
|
||||
return str(inputs[key])
|
||||
values = list(inputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
|
||||
@staticmethod
|
||||
def _consume_agent_stream(response_generator: Any) -> tuple[str, list[dict]]:
|
||||
"""Consume agent streaming response and extract final answer + tool calls."""
|
||||
answer_parts: list[str] = []
|
||||
tool_calls: list[dict] = []
|
||||
|
||||
try:
|
||||
for chunk in response_generator:
|
||||
if isinstance(chunk, Mapping):
|
||||
event = chunk.get("event")
|
||||
if event == "agent_thought":
|
||||
thought = chunk.get("thought", "")
|
||||
if thought:
|
||||
answer_parts.append(thought)
|
||||
tool = chunk.get("tool")
|
||||
if tool:
|
||||
tool_calls.append({
|
||||
"tool": tool,
|
||||
"tool_input": chunk.get("tool_input", ""),
|
||||
})
|
||||
elif event == "message":
|
||||
answer = chunk.get("answer", "")
|
||||
if answer:
|
||||
answer_parts.append(answer)
|
||||
elif isinstance(chunk, str):
|
||||
answer_parts.append(chunk)
|
||||
except Exception:
|
||||
logger.exception("Error consuming agent stream")
|
||||
|
||||
return "".join(answer_parts), tool_calls
|
||||
@@ -1,171 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.entities.judgment_entity import JudgmentConfig
|
||||
from core.evaluation.judgment.processor import JudgmentProcessor
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEvaluationRunner(ABC):
|
||||
"""Abstract base class for evaluation runners.
|
||||
|
||||
Runners are responsible for executing the target (App/Snippet/Retrieval)
|
||||
to collect actual outputs, then delegating to the evaluation instance
|
||||
for metric computation, and optionally applying judgment conditions.
|
||||
|
||||
Execution phases:
|
||||
1. execute_target — run the target and collect actual outputs
|
||||
2. evaluate_metrics — compute evaluation metrics via the framework
|
||||
3. apply_judgment — evaluate pass/fail judgment conditions on metrics
|
||||
4. persist — save results to the database
|
||||
"""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
|
||||
self.evaluation_instance = evaluation_instance
|
||||
self.session = session
|
||||
|
||||
@abstractmethod
|
||||
def execute_target(
|
||||
self,
|
||||
tenant_id: str,
|
||||
target_id: str,
|
||||
target_type: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> EvaluationItemResult:
|
||||
"""Execute the evaluation target for a single item and return the result with actual_output populated."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
results: list[EvaluationItemResult],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute evaluation metrics on the collected results."""
|
||||
...
|
||||
|
||||
def run(
|
||||
self,
|
||||
evaluation_run_id: str,
|
||||
tenant_id: str,
|
||||
target_id: str,
|
||||
target_type: str,
|
||||
items: list[EvaluationItemInput],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
judgment_config: JudgmentConfig | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Orchestrate target execution + metric evaluation + judgment for all items.
|
||||
|
||||
"""
|
||||
evaluation_run = self.session.query(EvaluationRun).filter_by(id=evaluation_run_id).first()
|
||||
if not evaluation_run:
|
||||
raise ValueError(f"EvaluationRun {evaluation_run_id} not found")
|
||||
|
||||
# Update status to running
|
||||
evaluation_run.status = EvaluationRunStatus.RUNNING
|
||||
evaluation_run.started_at = naive_utc_now()
|
||||
self.session.commit()
|
||||
|
||||
results: list[EvaluationItemResult] = []
|
||||
|
||||
# Phase 1: Execute target for each item
|
||||
for item in items:
|
||||
try:
|
||||
result = self.execute_target(tenant_id, target_id, target_type, item)
|
||||
results.append(result)
|
||||
evaluation_run.completed_items += 1
|
||||
except Exception as e:
|
||||
logger.exception("Failed to execute target for item %d", item.index)
|
||||
results.append(
|
||||
EvaluationItemResult(
|
||||
index=item.index,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
evaluation_run.failed_items += 1
|
||||
self.session.commit()
|
||||
|
||||
# Phase 2: Compute metrics on successful results
|
||||
successful_items = [item for item, result in zip(items, results) if result.error is None]
|
||||
successful_results = [r for r in results if r.error is None]
|
||||
|
||||
if successful_items and successful_results:
|
||||
try:
|
||||
evaluated_results = self.evaluate_metrics(
|
||||
successful_items, successful_results, metrics_config, model_provider, model_name, tenant_id
|
||||
)
|
||||
# Merge evaluated metrics back into results
|
||||
evaluated_by_index = {r.index: r for r in evaluated_results}
|
||||
for i, result in enumerate(results):
|
||||
if result.index in evaluated_by_index:
|
||||
results[i] = evaluated_by_index[result.index]
|
||||
except Exception:
|
||||
logger.exception("Failed to compute metrics for evaluation run %s", evaluation_run_id)
|
||||
|
||||
# Phase 3: Apply judgment conditions on metrics
|
||||
if judgment_config and judgment_config.conditions:
|
||||
results = self._apply_judgment(results, judgment_config)
|
||||
|
||||
# Phase 4: Persist individual items
|
||||
for result in results:
|
||||
item_input = next((item for item in items if item.index == result.index), None)
|
||||
run_item = EvaluationRunItem(
|
||||
evaluation_run_id=evaluation_run_id,
|
||||
item_index=result.index,
|
||||
inputs=json.dumps(item_input.inputs) if item_input else None,
|
||||
expected_output=item_input.expected_output if item_input else None,
|
||||
context=json.dumps(item_input.context) if item_input and item_input.context else None,
|
||||
actual_output=result.actual_output,
|
||||
metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None,
|
||||
judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None,
|
||||
metadata_json=json.dumps(result.metadata) if result.metadata else None,
|
||||
error=result.error,
|
||||
overall_score=result.overall_score,
|
||||
)
|
||||
self.session.add(run_item)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _apply_judgment(
|
||||
results: list[EvaluationItemResult],
|
||||
judgment_config: JudgmentConfig,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Apply judgment conditions to each result's metrics.
|
||||
|
||||
Builds a metric_name → score mapping from each result's metrics,
|
||||
then delegates to JudgmentProcessor for condition evaluation.
|
||||
Results with errors are skipped.
|
||||
"""
|
||||
judged_results: list[EvaluationItemResult] = []
|
||||
for result in results:
|
||||
if result.error is not None or not result.metrics:
|
||||
judged_results.append(result)
|
||||
continue
|
||||
|
||||
metric_values = {m.name: m.score for m in result.metrics}
|
||||
judgment_result = JudgmentProcessor.evaluate(metric_values, judgment_config)
|
||||
|
||||
judged_results.append(
|
||||
result.model_copy(update={"judgment": judgment_result})
|
||||
)
|
||||
return judged_results
|
||||
@@ -1,152 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Mapping, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from models.model import App, AppMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for LLM evaluation: executes App to get responses, then evaluates."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
|
||||
super().__init__(evaluation_instance, session)
|
||||
|
||||
def execute_target(
|
||||
self,
|
||||
tenant_id: str,
|
||||
target_id: str,
|
||||
target_type: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> EvaluationItemResult:
|
||||
"""Execute the App/Snippet with the given inputs and collect the response."""
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.evaluation.runners import get_service_account_for_app
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = self.session.query(App).filter_by(id=target_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App {target_id} not found")
|
||||
|
||||
# Get a service account for invocation
|
||||
service_account = get_service_account_for_app(self.session, target_id)
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
|
||||
# Build args from evaluation item inputs
|
||||
args: dict[str, Any] = {
|
||||
"inputs": item.inputs,
|
||||
}
|
||||
# For completion/chat modes, first text input becomes query
|
||||
if app_mode in (AppMode.COMPLETION, AppMode.CHAT):
|
||||
query = self._extract_query(item.inputs)
|
||||
args["query"] = query
|
||||
|
||||
if app_mode in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_published_workflow(app_model=app)
|
||||
if not workflow:
|
||||
raise ValueError(f"No published workflow found for app {target_id}")
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
response: Mapping[str, Any] = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=service_account,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
generator = CompletionAppGenerator()
|
||||
response = generator.generate(
|
||||
app_model=app,
|
||||
user=service_account,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported app mode for LLM evaluation: {app_mode}")
|
||||
|
||||
actual_output = self._extract_output(response)
|
||||
return EvaluationItemResult(
|
||||
index=item.index,
|
||||
actual_output=actual_output,
|
||||
)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
results: list[EvaluationItemResult],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Use the evaluation instance to compute LLM metrics."""
|
||||
# Merge actual_output into items for evaluation
|
||||
merged_items = self._merge_results_into_items(items, results)
|
||||
return self.evaluation_instance.evaluate_llm(
|
||||
merged_items, metrics_config, model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_query(inputs: dict[str, Any]) -> str:
|
||||
"""Extract query from inputs."""
|
||||
for key in ("query", "question", "input", "text"):
|
||||
if key in inputs:
|
||||
return str(inputs[key])
|
||||
values = list(inputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_output(response: Union[Mapping[str, Any], Any]) -> str:
|
||||
"""Extract text output from app response."""
|
||||
if isinstance(response, Mapping):
|
||||
# Workflow response
|
||||
if "data" in response and isinstance(response["data"], Mapping):
|
||||
outputs = response["data"].get("outputs", {})
|
||||
if isinstance(outputs, Mapping):
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
return str(outputs)
|
||||
# Completion response
|
||||
if "answer" in response:
|
||||
return str(response["answer"])
|
||||
if "text" in response:
|
||||
return str(response["text"])
|
||||
return str(response)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(
|
||||
items: list[EvaluationItemInput],
|
||||
results: list[EvaluationItemResult],
|
||||
) -> list[EvaluationItemInput]:
|
||||
"""Create new items with actual_output set as expected_output context for metrics."""
|
||||
result_by_index = {r.index: r for r in results}
|
||||
merged = []
|
||||
for item in items:
|
||||
result = result_by_index.get(item.index)
|
||||
if result and result.actual_output:
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=item.index,
|
||||
inputs=item.inputs,
|
||||
expected_output=item.expected_output,
|
||||
context=[result.actual_output] + (item.context or []),
|
||||
)
|
||||
)
|
||||
else:
|
||||
merged.append(item)
|
||||
return merged
|
||||
@@ -1,111 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
|
||||
super().__init__(evaluation_instance, session)
|
||||
|
||||
def execute_target(
|
||||
self,
|
||||
tenant_id: str,
|
||||
target_id: str,
|
||||
target_type: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> EvaluationItemResult:
|
||||
"""Execute retrieval using DatasetRetrieval and collect context documents."""
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
query = self._extract_query(item.inputs)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Use knowledge_retrieval for structured results
|
||||
try:
|
||||
from core.rag.retrieval.dataset_retrieval import KnowledgeRetrievalRequest
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
query=query,
|
||||
app_id=target_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
sources = dataset_retrieval.knowledge_retrieval(request)
|
||||
retrieved_contexts = [source.content for source in sources if source.content]
|
||||
except (ImportError, AttributeError):
|
||||
logger.warning("KnowledgeRetrievalRequest not available, using simple retrieval")
|
||||
retrieved_contexts = []
|
||||
|
||||
return EvaluationItemResult(
|
||||
index=item.index,
|
||||
actual_output="\n\n".join(retrieved_contexts) if retrieved_contexts else "",
|
||||
metadata={"retrieved_contexts": retrieved_contexts},
|
||||
)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
results: list[EvaluationItemResult],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute retrieval evaluation metrics."""
|
||||
# Merge retrieved contexts into items
|
||||
result_by_index = {r.index: r for r in results}
|
||||
merged_items = []
|
||||
for item in items:
|
||||
result = result_by_index.get(item.index)
|
||||
contexts = result.metadata.get("retrieved_contexts", []) if result else []
|
||||
merged_items.append(
|
||||
EvaluationItemInput(
|
||||
index=item.index,
|
||||
inputs=item.inputs,
|
||||
expected_output=item.expected_output,
|
||||
context=contexts,
|
||||
)
|
||||
)
|
||||
|
||||
evaluated = self.evaluation_instance.evaluate_retrieval(
|
||||
merged_items, metrics_config, model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
# Merge metrics back into original results (preserve actual_output and metadata)
|
||||
eval_by_index = {r.index: r for r in evaluated}
|
||||
final_results = []
|
||||
for result in results:
|
||||
if result.index in eval_by_index:
|
||||
eval_result = eval_by_index[result.index]
|
||||
final_results.append(
|
||||
EvaluationItemResult(
|
||||
index=result.index,
|
||||
actual_output=result.actual_output,
|
||||
metrics=eval_result.metrics,
|
||||
metadata=result.metadata,
|
||||
error=result.error,
|
||||
)
|
||||
)
|
||||
else:
|
||||
final_results.append(result)
|
||||
return final_results
|
||||
|
||||
@staticmethod
|
||||
def _extract_query(inputs: dict[str, Any]) -> str:
|
||||
for key in ("query", "question", "input", "text"):
|
||||
if key in inputs:
|
||||
return str(inputs[key])
|
||||
values = list(inputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@@ -1,133 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Mapping
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from models.model import App
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
|
||||
super().__init__(evaluation_instance, session)
|
||||
|
||||
def execute_target(
|
||||
self,
|
||||
tenant_id: str,
|
||||
target_id: str,
|
||||
target_type: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> EvaluationItemResult:
|
||||
"""Execute workflow and collect outputs."""
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.evaluation.runners import get_service_account_for_app
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = self.session.query(App).filter_by(id=target_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App {target_id} not found")
|
||||
|
||||
service_account = get_service_account_for_app(self.session, target_id)
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_published_workflow(app_model=app)
|
||||
if not workflow:
|
||||
raise ValueError(f"No published workflow found for app {target_id}")
|
||||
|
||||
args: dict[str, Any] = {"inputs": item.inputs}
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
response: Mapping[str, Any] = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=service_account,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
actual_output = self._extract_output(response)
|
||||
node_executions = self._extract_node_executions(response)
|
||||
|
||||
return EvaluationItemResult(
|
||||
index=item.index,
|
||||
actual_output=actual_output,
|
||||
metadata={"node_executions": node_executions},
|
||||
)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
results: list[EvaluationItemResult],
|
||||
metrics_config: dict,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute workflow evaluation metrics (end-to-end)."""
|
||||
result_by_index = {r.index: r for r in results}
|
||||
merged_items = []
|
||||
for item in items:
|
||||
result = result_by_index.get(item.index)
|
||||
context = []
|
||||
if result and result.actual_output:
|
||||
context.append(result.actual_output)
|
||||
merged_items.append(
|
||||
EvaluationItemInput(
|
||||
index=item.index,
|
||||
inputs=item.inputs,
|
||||
expected_output=item.expected_output,
|
||||
context=context + (item.context or []),
|
||||
)
|
||||
)
|
||||
|
||||
evaluated = self.evaluation_instance.evaluate_workflow(
|
||||
merged_items, metrics_config, model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
# Merge metrics back preserving metadata
|
||||
eval_by_index = {r.index: r for r in evaluated}
|
||||
final_results = []
|
||||
for result in results:
|
||||
if result.index in eval_by_index:
|
||||
eval_result = eval_by_index[result.index]
|
||||
final_results.append(
|
||||
EvaluationItemResult(
|
||||
index=result.index,
|
||||
actual_output=result.actual_output,
|
||||
metrics=eval_result.metrics,
|
||||
metadata=result.metadata,
|
||||
error=result.error,
|
||||
)
|
||||
)
|
||||
else:
|
||||
final_results.append(result)
|
||||
return final_results
|
||||
|
||||
@staticmethod
|
||||
def _extract_output(response: Mapping[str, Any]) -> str:
|
||||
"""Extract text output from workflow response."""
|
||||
if "data" in response and isinstance(response["data"], Mapping):
|
||||
outputs = response["data"].get("outputs", {})
|
||||
if isinstance(outputs, Mapping):
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
return str(outputs)
|
||||
return str(response)
|
||||
|
||||
@staticmethod
|
||||
def _extract_node_executions(response: Mapping[str, Any]) -> list[dict]:
|
||||
"""Extract node execution trace from workflow response."""
|
||||
data = response.get("data", {})
|
||||
if isinstance(data, Mapping):
|
||||
return data.get("node_executions", [])
|
||||
return []
|
||||
@@ -27,10 +27,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
@@ -4,10 +4,10 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types as mcp_types
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.workflow.file import file_manager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
|
||||
|
||||
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
|
||||
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
|
||||
|
||||
|
||||
class PromptMessageMemory(Protocol):
|
||||
"""Port for loading memory as prompt messages."""
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Return historical prompt messages constrained by token/message limits."""
|
||||
...
|
||||
@@ -14,9 +14,13 @@ from core.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
|
||||
ChatModelMessage,
|
||||
CompletionModelPromptTemplate,
|
||||
MemoryConfig,
|
||||
)
|
||||
from core.model_runtime.prompt.prompt_transform import PromptTransform
|
||||
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.file import file_manager
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.runtime import VariablePool
|
||||
@@ -10,7 +10,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.model_runtime.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
class AgentHistoryPromptTransform(PromptTransform):
|
||||
@@ -5,7 +5,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
|
||||
|
||||
class PromptTransform:
|
||||
@@ -15,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.model_runtime.prompt.prompt_transform import PromptTransform
|
||||
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.file import file_manager
|
||||
from models.model import AppMode
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
|
||||
|
||||
|
||||
class PromptMessageUtil:
|
||||
@@ -2,7 +2,6 @@ import tempfile
|
||||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
@@ -30,6 +29,7 @@ from core.plugin.entities.request import (
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
@@ -63,14 +63,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
@@ -124,14 +126,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -36,6 +35,7 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs import helper
|
||||
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
# Deduct quota for summary generation (same as workflow nodes)
|
||||
try:
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
except Exception as e:
|
||||
# Log but don't fail summary generation if quota deduction fails
|
||||
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
|
||||
|
||||
@@ -29,12 +29,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
|
||||
@@ -2,14 +2,14 @@ from collections.abc import Generator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ identity:
|
||||
zh_Hans: 网页抓取
|
||||
pt_BR: WebScraper
|
||||
description:
|
||||
en_US: Web Scraper tool kit is used to scrape web
|
||||
en_US: Web Scrapper tool kit is used to scrape web
|
||||
zh_Hans: 一个用于抓取网页的工具。
|
||||
pt_BR: Web Scraper tool kit is used to scrape web
|
||||
pt_BR: Web Scrapper tool kit is used to scrape web
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Mapping
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@@ -22,7 +23,6 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
@@ -588,7 +588,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@@ -643,6 +642,5 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@@ -4,7 +4,11 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
|
||||
ChatModelMessage,
|
||||
CompletionModelPromptTemplate,
|
||||
MemoryConfig,
|
||||
)
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import PromptMessageRole
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
||||
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .exc import InvalidVariableTypeError
|
||||
|
||||
@@ -41,51 +48,88 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
|
||||
def convert_history_messages_to_text(
|
||||
*,
|
||||
history_messages: Sequence[PromptMessage],
|
||||
human_prefix: str,
|
||||
ai_prefix: str,
|
||||
) -> str:
|
||||
string_messages: list[str] = []
|
||||
for message in history_messages:
|
||||
if message.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif message.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
else:
|
||||
continue
|
||||
used_quota = 1
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content_parts = []
|
||||
for content in message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
content_parts.append(content.data)
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
content_parts.append("[image]")
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
inner_msg = "\n".join(content_parts)
|
||||
string_messages.append(f"{role}: {inner_msg}")
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
string_messages.append(f"{role}: {message.content}")
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
def fetch_memory_text(
|
||||
*,
|
||||
memory: PromptMessageMemory,
|
||||
max_token_limit: int,
|
||||
message_limit: int | None = None,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
) -> str:
|
||||
history_messages = memory.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
return convert_history_messages_to_text(
|
||||
history_messages=history_messages,
|
||||
human_prefix=human_prefix,
|
||||
ai_prefix=ai_prefix,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@@ -37,10 +37,9 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
@@ -63,7 +62,7 @@ from core.workflow.node_events import (
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -279,6 +278,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
else None
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
@@ -1233,10 +1234,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def retry(self) -> bool:
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
@@ -1339,16 +1336,48 @@ def _handle_memory_completion_mode(
|
||||
)
|
||||
if not memory_config.role_prefix:
|
||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||
memory_text = llm_utils.fetch_memory_text(
|
||||
memory=memory,
|
||||
memory_messages = memory.get_history_prompt_messages(
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
)
|
||||
memory_text = _convert_history_messages_to_text(
|
||||
history_messages=memory_messages,
|
||||
human_prefix=memory_config.role_prefix.user,
|
||||
ai_prefix=memory_config.role_prefix.assistant,
|
||||
)
|
||||
return memory_text
|
||||
|
||||
|
||||
def _convert_history_messages_to_text(
|
||||
*,
|
||||
history_messages: Sequence[PromptMessage],
|
||||
human_prefix: str,
|
||||
ai_prefix: str,
|
||||
) -> str:
|
||||
string_messages: list[str] = []
|
||||
for message in history_messages:
|
||||
if message.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif message.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content_parts = []
|
||||
for content in message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
content_parts.append(content.data)
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
content_parts.append("[image]")
|
||||
|
||||
inner_msg = "\n".join(content_parts)
|
||||
string_messages.append(f"{role}: {inner_msg}")
|
||||
else:
|
||||
string_messages.append(f"{role}: {message.content}")
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
|
||||
class CredentialsProvider(Protocol):
|
||||
@@ -19,3 +21,13 @@ class ModelFactory(Protocol):
|
||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||
"""Create a model instance that is ready for schema lookup and invocation."""
|
||||
...
|
||||
|
||||
|
||||
class PromptMessageMemory(Protocol):
|
||||
"""Port for loading memory as prompt messages for LLM nodes."""
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Return historical prompt messages constrained by token/message limits."""
|
||||
...
|
||||
|
||||
@@ -413,7 +413,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@@ -455,6 +454,5 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import (
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@@ -17,18 +18,13 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
|
||||
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.file import File
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@@ -101,7 +97,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
_model_instance: ModelInstance
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
_memory: PromptMessageMemory | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -113,7 +108,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -124,7 +118,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
self._memory = memory
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@@ -170,7 +163,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
except ValueError as exc:
|
||||
raise ModelSchemaNotFoundError("Model schema not found") from exc
|
||||
memory = self._memory
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
if (
|
||||
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
|
||||
@@ -309,6 +308,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@@ -317,7 +319,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
@@ -405,7 +407,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -443,7 +445,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -468,8 +470,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
# AdvancedPromptTransform is still typed against TokenBufferMemory.
|
||||
memory=cast(Any, memory),
|
||||
memory=memory,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
@@ -482,7 +483,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -714,7 +715,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@@ -723,8 +724,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
@@ -741,7 +742,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@@ -750,8 +751,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
@@ -827,10 +828,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return rest_tokens
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
|
||||
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
@@ -56,7 +56,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
_model_instance: ModelInstance
|
||||
_memory: PromptMessageMemory | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -68,7 +67,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -83,7 +81,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
self._memory = memory
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@@ -106,7 +103,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
variables = {"query": query}
|
||||
# fetch model instance
|
||||
model_instance = self._model_instance
|
||||
memory = self._memory
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
# fetch instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
|
||||
@@ -237,10 +240,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@@ -324,7 +323,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@@ -337,8 +336,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
if memory:
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory,
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class StartNodeData(BaseNodeData):
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
|
||||
|
||||
class StartNode(Node[StartNodeData]):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .input_entities import VariableEntity, VariableEntityType
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
@@ -65,6 +64,4 @@ __all__ = [
|
||||
"StringVariable",
|
||||
"Variable",
|
||||
"VariableBase",
|
||||
"VariableEntity",
|
||||
"VariableEntityType",
|
||||
]
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.workflow.file import FileTransferMethod, FileType
|
||||
|
||||
|
||||
class VariableEntityType(StrEnum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Shared variable entity used by workflow runtime and app configuration.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, value: Any) -> str:
|
||||
return value or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, value: Any) -> Sequence[str]:
|
||||
return value or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as error:
|
||||
raise ValueError(f"Invalid JSON schema: {error.message}")
|
||||
return schema
|
||||
@@ -6,7 +6,6 @@ from typing import Any, cast
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
@@ -107,7 +106,6 @@ class WorkflowEntry:
|
||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
self.graph_engine.layer(limits_layer)
|
||||
self.graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
# Add observability layer when OTel is enabled
|
||||
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from flask_restx import fields
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
# Snippet list item fields (lightweight for list display)
|
||||
snippet_list_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"version": fields.Integer,
|
||||
"use_count": fields.Integer,
|
||||
"is_published": fields.Boolean,
|
||||
"icon_info": fields.Raw,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
# Full snippet fields (includes creator info and graph data)
|
||||
snippet_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"version": fields.Integer,
|
||||
"use_count": fields.Integer,
|
||||
"is_published": fields.Boolean,
|
||||
"icon_info": fields.Raw,
|
||||
"graph": fields.Raw(attribute="graph_dict"),
|
||||
"input_fields": fields.Raw(attribute="input_fields_list"),
|
||||
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
# Pagination response fields
|
||||
snippet_pagination_fields = {
|
||||
"data": fields.List(fields.Nested(snippet_list_fields)),
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer,
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
"""Helpers for producing concise pyrefly diagnostics for CI diff output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
_DIAGNOSTIC_PREFIXES = ("ERROR ", "WARNING ")
|
||||
_LOCATION_PREFIX = "-->"
|
||||
|
||||
|
||||
def extract_diagnostics(raw_output: str) -> str:
|
||||
"""Extract stable diagnostic lines from pyrefly output.
|
||||
|
||||
The full pyrefly output includes code excerpts and carets, which create noisy
|
||||
diffs. This helper keeps only:
|
||||
- diagnostic headline lines (``ERROR ...`` / ``WARNING ...``)
|
||||
- the following location line (``--> path:line:column``), when present
|
||||
"""
|
||||
|
||||
lines = raw_output.splitlines()
|
||||
diagnostics: list[str] = []
|
||||
|
||||
for index, line in enumerate(lines):
|
||||
if line.startswith(_DIAGNOSTIC_PREFIXES):
|
||||
diagnostics.append(line.rstrip())
|
||||
|
||||
next_index = index + 1
|
||||
if next_index < len(lines):
|
||||
next_line = lines[next_index]
|
||||
if next_line.lstrip().startswith(_LOCATION_PREFIX):
|
||||
diagnostics.append(next_line.rstrip())
|
||||
|
||||
if not diagnostics:
|
||||
return ""
|
||||
|
||||
return "\n".join(diagnostics) + "\n"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Read pyrefly output from stdin and print normalized diagnostics."""
|
||||
|
||||
raw_output = sys.stdin.read()
|
||||
sys.stdout.write(extract_diagnostics(raw_output))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,83 +0,0 @@
|
||||
"""add_customized_snippets_table
|
||||
|
||||
Revision ID: 1c05e80d2380
|
||||
Revises: 788d3099ae3a
|
||||
Create Date: 2026-01-29 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models as models
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1c05e80d2380"
|
||||
down_revision = "788d3099ae3a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table(
|
||||
"customized_snippets",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
|
||||
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
|
||||
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("graph", sa.Text(), nullable=True),
|
||||
sa.Column("input_fields", sa.Text(), nullable=True),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||
)
|
||||
else:
|
||||
op.create_table(
|
||||
"customized_snippets",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", models.types.LongText(), nullable=True),
|
||||
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
|
||||
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
|
||||
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("graph", models.types.LongText(), nullable=True),
|
||||
sa.Column("input_fields", models.types.LongText(), nullable=True),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||
)
|
||||
|
||||
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
|
||||
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
|
||||
batch_op.drop_index("customized_snippet_tenant_idx")
|
||||
|
||||
op.drop_table("customized_snippets")
|
||||
@@ -1,113 +0,0 @@
|
||||
"""add_evaluation_tables
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 1c05e80d2380
|
||||
Create Date: 2026-03-03 00:01:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f6"
|
||||
down_revision = "1c05e80d2380"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# evaluation_configurations
|
||||
op.create_table(
|
||||
"evaluation_configurations",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("target_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("target_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True),
|
||||
sa.Column("evaluation_model", sa.String(length=255), nullable=True),
|
||||
sa.Column("metrics_config", models.types.LongText(), nullable=True),
|
||||
sa.Column("judgement_conditions", models.types.LongText(), nullable=True),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("updated_by", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
|
||||
)
|
||||
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
|
||||
)
|
||||
|
||||
# evaluation_runs
|
||||
op.create_table(
|
||||
"evaluation_runs",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("target_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("target_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")),
|
||||
sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("result_file_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("metrics_summary", models.types.LongText(), nullable=True),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("celery_task_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("started_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("completed_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
|
||||
)
|
||||
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
|
||||
)
|
||||
batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False)
|
||||
|
||||
# evaluation_run_items
|
||||
op.create_table(
|
||||
"evaluation_run_items",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("item_index", sa.Integer(), nullable=False),
|
||||
sa.Column("inputs", models.types.LongText(), nullable=True),
|
||||
sa.Column("expected_output", models.types.LongText(), nullable=True),
|
||||
sa.Column("context", models.types.LongText(), nullable=True),
|
||||
sa.Column("actual_output", models.types.LongText(), nullable=True),
|
||||
sa.Column("metrics", models.types.LongText(), nullable=True),
|
||||
sa.Column("metadata_json", models.types.LongText(), nullable=True),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("overall_score", sa.Float(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
|
||||
)
|
||||
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
|
||||
batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False)
|
||||
batch_op.create_index(
|
||||
"evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
|
||||
batch_op.drop_index("evaluation_run_item_index_idx")
|
||||
batch_op.drop_index("evaluation_run_item_run_idx")
|
||||
op.drop_table("evaluation_run_items")
|
||||
|
||||
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
|
||||
batch_op.drop_index("evaluation_run_status_idx")
|
||||
batch_op.drop_index("evaluation_run_target_idx")
|
||||
op.drop_table("evaluation_runs")
|
||||
|
||||
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
|
||||
batch_op.drop_index("evaluation_configuration_target_idx")
|
||||
op.drop_table("evaluation_configurations")
|
||||
@@ -26,13 +26,6 @@ from .dataset import (
|
||||
TidbAuthBinding,
|
||||
Whitelist,
|
||||
)
|
||||
from .evaluation import (
|
||||
EvaluationConfiguration,
|
||||
EvaluationRun,
|
||||
EvaluationRunItem,
|
||||
EvaluationRunStatus,
|
||||
EvaluationTargetType,
|
||||
)
|
||||
from .enums import (
|
||||
AppTriggerStatus,
|
||||
AppTriggerType,
|
||||
@@ -88,7 +81,6 @@ from .provider import (
|
||||
TenantDefaultModel,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from .snippet import CustomizedSnippet, SnippetType
|
||||
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from .task import CeleryTask, CeleryTaskSet
|
||||
from .tools import (
|
||||
@@ -148,7 +140,6 @@ __all__ = [
|
||||
"Conversation",
|
||||
"ConversationVariable",
|
||||
"CreatorUserRole",
|
||||
"CustomizedSnippet",
|
||||
"DataSourceApiKeyAuthBinding",
|
||||
"DataSourceOauthBinding",
|
||||
"Dataset",
|
||||
@@ -165,11 +156,6 @@ __all__ = [
|
||||
"Document",
|
||||
"DocumentSegment",
|
||||
"Embedding",
|
||||
"EvaluationConfiguration",
|
||||
"EvaluationRun",
|
||||
"EvaluationRunItem",
|
||||
"EvaluationRunStatus",
|
||||
"EvaluationTargetType",
|
||||
"EndUser",
|
||||
"ExecutionExtraContent",
|
||||
"ExporleBanner",
|
||||
@@ -198,7 +184,6 @@ __all__ = [
|
||||
"RecommendedApp",
|
||||
"SavedMessage",
|
||||
"Site",
|
||||
"SnippetType",
|
||||
"Tag",
|
||||
"TagBinding",
|
||||
"Tenant",
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, Float, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import Base
|
||||
from .types import LongText, StringUUID
|
||||
|
||||
|
||||
class EvaluationRunStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class EvaluationTargetType(StrEnum):
|
||||
APP = "app"
|
||||
SNIPPETS = "snippets"
|
||||
|
||||
|
||||
class EvaluationConfiguration(Base):
|
||||
"""Stores evaluation configuration for each target (App or Snippet)."""
|
||||
|
||||
__tablename__ = "evaluation_configurations"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
|
||||
sa.Index("evaluation_configuration_target_idx", "tenant_id", "target_type", "target_id"),
|
||||
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
evaluation_model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
evaluation_model: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
metrics_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
judgement_conditions: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
updated_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def metrics_config_dict(self) -> dict[str, Any]:
|
||||
if self.metrics_config:
|
||||
return json.loads(self.metrics_config)
|
||||
return {}
|
||||
|
||||
@metrics_config_dict.setter
|
||||
def metrics_config_dict(self, value: dict[str, Any]) -> None:
|
||||
self.metrics_config = json.dumps(value)
|
||||
|
||||
@property
|
||||
def judgement_conditions_dict(self) -> dict[str, Any]:
|
||||
if self.judgement_conditions:
|
||||
return json.loads(self.judgement_conditions)
|
||||
return {}
|
||||
|
||||
@judgement_conditions_dict.setter
|
||||
def judgement_conditions_dict(self, value: dict[str, Any]) -> None:
|
||||
self.judgement_conditions = json.dumps(value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EvaluationConfiguration(id={self.id}, target={self.target_type}:{self.target_id})>"
|
||||
|
||||
|
||||
class EvaluationRun(Base):
|
||||
"""Stores each evaluation run record."""
|
||||
|
||||
__tablename__ = "evaluation_runs"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
|
||||
sa.Index("evaluation_run_target_idx", "tenant_id", "target_type", "target_id"),
|
||||
sa.Index("evaluation_run_status_idx", "tenant_id", "status"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
evaluation_config_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default=EvaluationRunStatus.PENDING
|
||||
)
|
||||
dataset_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
result_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
total_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
completed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
failed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
metrics_summary: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def metrics_summary_dict(self) -> dict[str, Any]:
|
||||
if self.metrics_summary:
|
||||
return json.loads(self.metrics_summary)
|
||||
return {}
|
||||
|
||||
@metrics_summary_dict.setter
|
||||
def metrics_summary_dict(self, value: dict[str, Any]) -> None:
|
||||
self.metrics_summary = json.dumps(value)
|
||||
|
||||
@property
|
||||
def progress(self) -> float:
|
||||
if self.total_items == 0:
|
||||
return 0.0
|
||||
return (self.completed_items + self.failed_items) / self.total_items
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EvaluationRun(id={self.id}, status={self.status})>"
|
||||
|
||||
|
||||
class EvaluationRunItem(Base):
|
||||
"""Stores per-row evaluation results."""
|
||||
|
||||
__tablename__ = "evaluation_run_items"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
|
||||
sa.Index("evaluation_run_item_run_idx", "evaluation_run_id"),
|
||||
sa.Index("evaluation_run_item_index_idx", "evaluation_run_id", "item_index"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
evaluation_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
item_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
inputs: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
expected_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
context: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
actual_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
|
||||
metrics: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
judgment: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
metadata_json: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs_dict(self) -> dict[str, Any]:
|
||||
if self.inputs:
|
||||
return json.loads(self.inputs)
|
||||
return {}
|
||||
|
||||
@property
|
||||
def metrics_list(self) -> list[dict[str, Any]]:
|
||||
if self.metrics:
|
||||
return json.loads(self.metrics)
|
||||
return []
|
||||
|
||||
@property
|
||||
def metadata_dict(self) -> dict[str, Any]:
|
||||
if self.metadata_json:
|
||||
return json.loads(self.metadata_json)
|
||||
return {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EvaluationRunItem(id={self.id}, run={self.evaluation_run_id}, index={self.item_index})>"
|
||||
@@ -1,101 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import AdjustedJSON, LongText, StringUUID
|
||||
|
||||
|
||||
class SnippetType(StrEnum):
|
||||
"""Snippet Type Enum"""
|
||||
|
||||
NODE = "node"
|
||||
GROUP = "group"
|
||||
|
||||
|
||||
class CustomizedSnippet(Base):
|
||||
"""
|
||||
Customized Snippet Model
|
||||
|
||||
Stores reusable workflow components (nodes or node groups) that can be
|
||||
shared across applications within a workspace.
|
||||
"""
|
||||
|
||||
__tablename__ = "customized_snippets"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))
|
||||
|
||||
# Workflow reference for published version
|
||||
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
# State flags
|
||||
is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
|
||||
use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
|
||||
# Visual customization
|
||||
icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)
|
||||
|
||||
# Snippet configuration (stored as JSON text)
|
||||
input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
|
||||
# Audit fields
|
||||
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def graph_dict(self) -> dict[str, Any]:
|
||||
"""Get graph from associated workflow."""
|
||||
if self.workflow_id:
|
||||
from .workflow import Workflow
|
||||
|
||||
workflow = db.session.get(Workflow, self.workflow_id)
|
||||
if workflow:
|
||||
return json.loads(workflow.graph) if workflow.graph else {}
|
||||
return {}
|
||||
|
||||
@property
|
||||
def input_fields_list(self) -> list[dict[str, Any]]:
|
||||
"""Parse input_fields JSON to list."""
|
||||
return json.loads(self.input_fields) if self.input_fields else []
|
||||
|
||||
@property
|
||||
def created_by_account(self) -> Account | None:
|
||||
"""Get the account that created this snippet."""
|
||||
if self.created_by:
|
||||
return db.session.get(Account, self.created_by)
|
||||
return None
|
||||
|
||||
@property
|
||||
def updated_by_account(self) -> Account | None:
|
||||
"""Get the account that last updated this snippet."""
|
||||
if self.updated_by:
|
||||
return db.session.get(Account, self.updated_by)
|
||||
return None
|
||||
|
||||
@property
|
||||
def version_str(self) -> str:
|
||||
"""Get version as string for API response."""
|
||||
return str(self.version)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user