Compare commits


1 Commit

Author    SHA1          Message                                     Date
-LAN-     7d5572027c    Move prompt module under model runtime      2026-03-01 19:57:00 +08:00
537 changed files with 1286 additions and 7537 deletions

View File

@@ -29,26 +29,20 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Prepare diagnostics extractor
run: |
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
- name: Run pyrefly on PR branch
run: |
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly on base branch
run: |
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
- name: Compute diff
run: |
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Save PR number
run: |
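
The change above drops the custom diagnostics extractor: pyrefly's raw output is now captured directly on each branch and compared with plain `diff`. A rough Python equivalent of the "Compute diff" step, assuming each file holds one diagnostic per line (difflib stands in for the `diff` binary, which the workflow actually uses):

    # Compare base-branch diagnostics against PR-branch diagnostics and keep
    # only the changed lines; an empty result means no new type errors.
    import difflib
    from pathlib import Path

    base = Path("/tmp/pyrefly_base.txt").read_text().splitlines()
    pr = Path("/tmp/pyrefly_pr.txt").read_text().splitlines()

    diff_lines = difflib.unified_diff(base, pr, fromfile="base", tofile="pr", lineterm="")
    Path("pyrefly_diff.txt").write_text("\n".join(diff_lines))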

View File

@@ -29,8 +29,6 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@@ -54,6 +52,7 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
# TODO(QuantumGhost): use DI to avoid depending on global DB.
@@ -92,7 +91,7 @@ forbidden_modules =
core.moderation
core.ops
core.plugin
core.prompt
core.model_runtime.prompt
core.provider_manager
core.rag
core.repositories
@@ -108,11 +107,14 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@@ -125,14 +127,16 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
@@ -144,15 +148,16 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.agent.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.question_classifier.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
@@ -167,6 +172,7 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@@ -174,7 +180,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.llm_utils -> models.provider_ids
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services
@@ -184,7 +190,12 @@ ignore_imports =
name = Model Runtime Internal Imports
type = forbidden
source_modules =
core.model_runtime
core.model_runtime.callbacks
core.model_runtime.entities
core.model_runtime.errors
core.model_runtime.model_providers
core.model_runtime.schema_validators
core.model_runtime.utils
forbidden_modules =
configs
controllers
@@ -214,7 +225,7 @@ forbidden_modules =
core.moderation
core.ops
core.plugin
core.prompt
core.model_runtime.prompt
core.provider_manager
core.rag
core.repositories
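
For context, the sections above are import-linter contracts: each `ignore_imports` entry whitelists a single `source -> target` edge inside an otherwise forbidden relationship, which is why relocating `core.prompt` to `core.model_runtime.prompt` touches so many lines. Note also that the "Model Runtime Internal Imports" contract now enumerates `core.model_runtime` submodules as sources, presumably so the relocated prompt package can appear in `forbidden_modules` without the contract forbidding its own source. A minimal sketch of the contract shape in the same INI format (contract name and modules illustrative):

    [importlinter:contract:example]
    name = Example Forbidden Contract
    type = forbidden
    source_modules =
        core.workflow
    forbidden_modules =
        core.model_runtime.prompt
    ignore_imports =
        core.workflow.nodes.llm.node -> core.model_runtime.prompt.utils.prompt_message_util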

View File

@@ -1366,32 +1366,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
)
class EvaluationConfig(BaseSettings):
"""
Configuration for evaluation runtime
"""
EVALUATION_FRAMEWORK: str = Field(
description="Evaluation framework to use (ragas/deepeval/none)",
default="none",
)
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
description="Maximum number of concurrent evaluation runs per tenant",
default=3,
)
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
description="Maximum number of rows allowed in an evaluation dataset",
default=1000,
)
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for a single evaluation task",
default=3600,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1404,7 +1378,6 @@ class FeatureConfig(
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
EvaluationConfig,
FileAccessConfig,
FileUploadConfig,
HttpConfig,
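
The removed EvaluationConfig follows the file's pydantic-settings pattern: each Field resolves from an environment variable of the same name, falling back to the declared default, and the class is mixed into FeatureConfig. A minimal standalone sketch of that pattern (illustrative, not the deleted code):

    from pydantic import Field, PositiveInt
    from pydantic_settings import BaseSettings

    class EvaluationConfig(BaseSettings):
        # Mirrors two of the removed fields; values come from the
        # EVALUATION_* environment variables when set.
        EVALUATION_FRAMEWORK: str = Field(default="none")
        EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(default=3)

    config = EvaluationConfig()
    print(config.EVALUATION_FRAMEWORK)  # "none" unless overridden in the env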

View File

@@ -116,12 +116,6 @@ from .explore import (
trial,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import snippet controllers
from .snippets import snippet_workflow
# Import tag controllers
from .tag import tags
@@ -135,7 +129,6 @@ from .workspace import (
model_providers,
models,
plugin,
snippets,
tool_providers,
trigger_providers,
workspace,
@@ -173,7 +166,6 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@@ -207,8 +199,6 @@ __all__ = [
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippets",
"spec",
"statistic",
"tags",

View File

@@ -1 +0,0 @@
# Evaluation controller module

View File

@@ -1,605 +0,0 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, fields
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.workflow.file import helpers as file_helpers
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
if TYPE_CHECKING:
from models.evaluation import EvaluationRun, EvaluationRunItem
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
customized_matrix_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
condition_fields = {
"name": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgement_conditions_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"customized_matrix": fields.Nested(
console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
allow_null=True,
),
"judgement_conditions": fields.Nested(
console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
def get_evaluation_target(view_func: Callable[P, R]):
"""
Decorator to resolve polymorphic evaluation target (app or snippet).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet] | None = None
if target_type == "app":
target = (
db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
)
elif target_type == "snippets":
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(400, "Invalid target type or excluded app mode")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates an XLSX template based on the target's input parameters
and streams it directly as a file attachment.
"""
try:
xlsx_content, filename = EvaluationService.generate_dataset_template(
target=target,
target_type=target_type,
)
except ValueError as e:
return {"message": str(e)}, 400
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation configuration for the target.
Returns evaluation configuration including model settings,
metrics config, and judgement conditions.
"""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(
session, current_tenant_id, target_type, str(target.id)
)
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"metrics_config": None,
"judgement_conditions": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"metrics_config": config.metrics_config_dict,
"judgement_conditions": config.judgement_conditions_dict,
}
@console_ns.doc("save_evaluation_detail")
@console_ns.response(200, "Evaluation configuration saved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Save evaluation configuration for the target.
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"metrics_config": config.metrics_config_dict,
"judgement_conditions": config.judgement_conditions_dict,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation run history for the target.
Returns a paginated list of evaluation runs.
"""
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
page=page,
page_size=page_size,
)
return {
"data": [_serialize_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
class EvaluationRunApi(Resource):
@console_ns.doc("start_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Start an evaluation run.
Expects JSON body with:
- file_id: uploaded dataset file ID
- evaluation_model: evaluation model name
- evaluation_model_provider: evaluation model provider
- default_metrics: list of default metric objects
- customized_metrics: customized metrics object (optional)
- judgment_config: judgment conditions config (optional)
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
# Validate and parse request body
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
# Load dataset file
upload_file = (
db.session.query(UploadFile)
.filter_by(id=run_request.file_id, tenant_id=current_tenant_id)
.first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route(
"/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>"
)
class EvaluationRunDetailApi(Resource):
@console_ns.doc("get_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""
Get evaluation run detail including items.
"""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id,
page=page,
page_size=page_size,
)
return {
"run": _serialize_evaluation_run(run),
"items": {
"data": [_serialize_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route(
"/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel"
)
class EvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""Cancel a running evaluation."""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
return _serialize_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
class EvaluationMetricsApi(Resource):
@console_ns.doc("get_evaluation_metrics")
@console_ns.response(200, "Available metrics retrieved")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get available evaluation metrics for the current framework.
"""
result = {}
for category in EvaluationCategory:
result[category.value] = EvaluationService.get_supported_metrics(category)
return {"metrics": result}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Looks up the specified file, verifies it belongs to the same tenant,
and returns file info and download URL.
"""
file_id = str(file_id)
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
graph = {}
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}
# ---- Serialization Helpers ----
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": run.metrics_summary_dict,
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}
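
For reference, a hedged sketch of how a client might have invoked the run endpoint defined above. The URL prefix, token, and IDs are placeholders (the console API mount point is not shown in this diff); the body fields come from the endpoint's docstring:

    import requests

    BASE = "https://example.com/console/api"  # placeholder prefix
    body = {
        "file_id": "<uploaded dataset file ID>",
        "evaluation_model": "<model name>",
        "evaluation_model_provider": "<provider>",
        "default_metrics": [],
        # "customized_metrics" and "judgment_config" are optional
    }
    resp = requests.post(
        f"{BASE}/app/<evaluate_target_id>/evaluation/run",  # target type: "app" or "snippets"
        json=body,
        headers={"Authorization": "Bearer <console token>"},
    )
    # 200 returns the serialized run; 400/404/429 carry {"message": ...}
    print(resp.status_code, resp.json())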

View File

@@ -1,102 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel, Field
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
input_variables: list[dict[str, Any]] | None = None
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None
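
A short sketch of how these payload models behave at the boundary, assuming the CreateSnippetPayload class defined above (pydantic v2 validation):

    from pydantic import ValidationError

    payload = CreateSnippetPayload.model_validate(
        {"name": "My snippet", "type": "node", "graph": {"nodes": [], "edges": []}}
    )
    assert payload.input_fields == []  # default_factory=list kicks in

    try:
        CreateSnippetPayload.model_validate({"name": ""})  # violates min_length=1
    except ValidationError as exc:
        print(exc.errors()[0]["type"])  # "string_too_short"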

View File

@@ -1,540 +0,0 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetDraftSyncPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
PublishWorkflowPayload,
)
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet(view_func: Callable[P, R]):
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
return workflow
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
input_variables=payload.input_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
return workflow
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_pagination_model)
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return result
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_detail_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return workflow_run
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_list_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return {"data": node_executions}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
@console_ns.doc("run_snippet_draft_node")
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
# Get draft workflow for file parsing
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
workflow_node_execution = SnippetGenerateService.run_draft_node(
snippet=snippet,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=payload.query,
files=files,
)
return workflow_node_execution
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
@console_ns.doc("get_snippet_draft_node_last_run")
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
def get(self, snippet: CustomizedSnippet, node_id: str):
"""
Get the last run result for a specific node in snippet draft workflow.
Returns the most recent execution record for the given node,
including status, inputs, outputs, and timing information.
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
node_exec = snippet_service.get_snippet_node_last_run(
snippet=snippet,
workflow=draft_workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("Node last run not found")
return node_exec
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_snippet_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_iteration(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_snippet_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = SnippetGenerateService.generate_single_loop(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
@console_ns.doc("run_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate(
snippet=snippet,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
@console_ns.doc("stop_snippet_workflow_task")
@console_ns.response(200, "Task stopped successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, task_id: str):
"""
Stop a running snippet workflow task.
Uses both the legacy stop flag mechanism and the graph engine
command channel for backward compatibility.
"""
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@@ -1,202 +0,0 @@
import logging
from flask import request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
graph=payload.graph,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204

View File

@@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from core.workflow.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -32,7 +32,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,

View File

@@ -17,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine

View File

@@ -21,7 +21,7 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.file import file_manager

View File

@@ -5,7 +5,7 @@ from core.app.app_config.entities import (
PromptTemplateEntity,
)
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from models.model import AppMode

View File

@@ -1,8 +1,7 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[

View File

@@ -2,12 +2,12 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from pydantic import BaseModel, Field
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.file import FileUploadConfig
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
from models.model import AppMode
@@ -90,7 +90,61 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class RagPipelineVariableEntity(WorkflowVariableEntity):
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, v: Any) -> str:
return v or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
@@ -260,7 +314,7 @@ class AppConfig(BaseModel):
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures | None = None
variables: list[WorkflowVariableEntity] = []
variables: list[VariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
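
The new `json_schema` validator rejects malformed schemas at model construction time. A minimal sketch of that behavior (field values illustrative, class names from the hunk above):

```python
from pydantic import ValidationError

# Illustrative only: exercises the Draft7Validator.check_schema() hook above.
ok = VariableEntity(
    variable="payload",
    label="Payload",
    type=VariableEntityType.JSON_OBJECT,
    json_schema={"type": "object", "properties": {"name": {"type": "string"}}},
)

try:
    VariableEntity(
        variable="payload",
        label="Payload",
        type=VariableEntityType.JSON_OBJECT,
        json_schema={"type": 42},  # not a valid Draft 7 schema: "type" must be a string or list
    )
except ValidationError as exc:
    print(exc)  # surfaces "Invalid JSON schema: ..." from the validator
```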

View File

@@ -1,7 +1,6 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity
from core.workflow.variables.input_entities import VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow

View File

@@ -32,8 +32,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.model_runtime.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import NodeType
from core.workflow.file import File, FileUploadConfig
@@ -11,14 +12,13 @@ from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.workflow.variables.input_entities import VariableEntityType
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.workflow.variables.input_entities import VariableEntity
from core.app.app_config.entities import VariableEntity
class BaseAppGenerator:

View File

@@ -33,10 +33,14 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file.enums import FileTransferMethod, FileType
from extensions.ext_database import db

View File

@@ -27,7 +27,7 @@ from core.app.entities.task_entities import (
CompletionAppStreamResponse,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic

View File

@@ -1,5 +1 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]

View File

@@ -1,93 +0,0 @@
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()
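
For reference, the deduction amount in the removed `deduct_llm_quota` depended only on the quota unit; a condensed, self-contained restatement of that branch (the `QuotaUnit` values here are assumptions mirroring `core.entities.provider_entities`):

```python
from enum import StrEnum

class QuotaUnit(StrEnum):  # assumed values, mirroring core.entities.provider_entities
    TOKENS = "tokens"
    CREDITS = "credits"
    TIMES = "times"

def resolve_used_quota(quota_unit: QuotaUnit, total_tokens: int, model_credits: int) -> int:
    """Condensed restatement of the branch removed above."""
    if quota_unit == QuotaUnit.TOKENS:
        return total_tokens   # token-metered system providers
    if quota_unit == QuotaUnit.CREDITS:
        return model_credits  # credit-metered models (dify_config.get_model_credits)
    return 1                  # everything else is billed per call

assert resolve_used_quota(QuotaUnit.TOKENS, total_tokens=1536, model_credits=2) == 1536
assert resolve_used_quota(QuotaUnit.TIMES, total_tokens=1536, model_credits=2) == 1
```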

View File

@@ -52,10 +52,10 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from core.workflow.file import helpers as file_helpers
from core.workflow.file.enums import FileTransferMethod
@@ -157,7 +157,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=self._task_state.llm_result.message.get_text_content(),
answer=cast(str, self._task_state.llm_result.message.content),
created_at=self._message_created_at,
**extras,
),
@@ -170,7 +170,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=self._task_state.llm_result.message.get_text_content(),
answer=cast(str, self._task_state.llm_result.message.content),
created_at=self._message_created_at,
**extras,
),
@@ -283,7 +283,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.get_text_content()
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -397,7 +397,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)

View File

@@ -1,11 +1,9 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [
"LLMQuotaLayer",
"ObservabilityLayer",
"PersistenceWorkflowInfo",
"WorkflowPersistenceLayer",

View File

@@ -1,128 +0,0 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
def __init__(self) -> None:
super().__init__()
self._abort_sent = False
@override
def on_graph_start(self) -> None:
self._abort_sent = False
@override
def on_event(self, event: GraphEngineEvent) -> None:
_ = event
@override
def on_graph_end(self, error: Exception | None) -> None:
_ = error
@override
def on_node_run_start(self, node: Node) -> None:
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
ensure_llm_quota_available(model_instance=model_instance)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
deduct_llm_quota(
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
except Exception:
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
@staticmethod
def _set_stop_event(node: Node) -> None:
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
if stop_event is not None:
stop_event.set()
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
try:
self.command_channel.send_command(
AbortCommand(
command_type=CommandType.ABORT,
reason=reason,
)
)
self._abort_sent = True
except Exception:
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None
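
The removed class is one instance of the general `GraphEngineLayer` hook pattern; a minimal no-op sketch of that shape (hook names and import paths from the code above, the `AuditLayer` name and prints are illustrative):

```python
from typing import final
from typing_extensions import override

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.nodes.base.node import Node

@final
class AuditLayer(GraphEngineLayer):
    """Illustrative layer: logs node lifecycle instead of deducting quota."""

    @override
    def on_graph_start(self) -> None: ...

    @override
    def on_event(self, event: GraphEngineEvent) -> None: ...

    @override
    def on_node_run_start(self, node: Node) -> None:
        print(f"node {node.id} starting")

    @override
    def on_node_run_end(
        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        print(f"node {node.id} finished, error={error}")

    @override
    def on_graph_end(self, error: Exception | None) -> None: ...
```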

View File

@@ -1,8 +1,6 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
@@ -13,16 +11,14 @@ from core.helper.code_executor.code_executor import (
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.enums import NodeType
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
@@ -33,9 +29,11 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -43,34 +41,12 @@ from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
*,
conversation_id: str | None,
app_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
@@ -245,7 +221,6 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@@ -254,12 +229,10 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@@ -268,7 +241,6 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
@@ -323,14 +295,8 @@ class DifyNodeFactory(NodeFactory):
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,

View File

@@ -1,64 +0,0 @@
from abc import ABC, abstractmethod
from core.evaluation.entities.evaluation_entity import (
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
)
class BaseEvaluationInstance(ABC):
"""Abstract base class for evaluation framework adapters."""
@abstractmethod
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate LLM outputs using the configured framework."""
...
@abstractmethod
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate retrieval quality using the configured framework."""
...
@abstractmethod
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate agent outputs using the configured framework."""
...
@abstractmethod
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate workflow outputs using the configured framework."""
...
@abstractmethod
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
"""Return the list of supported metric names for a given evaluation category."""
...
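
Every adapter had to implement all five hooks before it could be instantiated; a stub sketch (the `EchoEvaluator` name and constant score are illustrative):

```python
from core.evaluation.entities.evaluation_entity import EvaluationMetric

class EchoEvaluator(BaseEvaluationInstance):
    """Illustrative stub: scores every item 1.0 on a single 'echo' metric."""

    def _score(self, items: list[EvaluationItemInput]) -> list[EvaluationItemResult]:
        return [
            EvaluationItemResult(index=i.index, metrics=[EvaluationMetric(name="echo", score=1.0)])
            for i in items
        ]

    def evaluate_llm(self, items, metrics_config, model_provider, model_name, tenant_id):
        return self._score(items)

    def evaluate_retrieval(self, items, metrics_config, model_provider, model_name, tenant_id):
        return self._score(items)

    def evaluate_agent(self, items, metrics_config, model_provider, model_name, tenant_id):
        return self._score(items)

    def evaluate_workflow(self, items, metrics_config, model_provider, model_name, tenant_id):
        return self._score(items)

    def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
        return ["echo"]
```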

View File

@@ -1,25 +0,0 @@
from enum import StrEnum
from pydantic import BaseModel
class EvaluationFrameworkEnum(StrEnum):
RAGAS = "ragas"
DEEPEVAL = "deepeval"
CUSTOMIZED = "customized"
NONE = "none"
class BaseEvaluationConfig(BaseModel):
"""Base configuration for evaluation frameworks."""
pass
class RagasConfig(BaseEvaluationConfig):
"""RAGAS-specific configuration."""
pass
class CustomizedEvaluatorConfig(BaseEvaluationConfig):
"""Configuration for the customized workflow-based evaluator."""
pass

View File

@@ -1,94 +0,0 @@
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
class EvaluationCategory(StrEnum):
LLM = "llm"
RETRIEVAL = "knowledge_retrieval"
AGENT = "agent"
WORKFLOW = "workflow"
RETRIEVAL_TEST = "retrieval_test"
class EvaluationMetric(BaseModel):
name: str
score: float
details: dict[str, Any] = Field(default_factory=dict)
class EvaluationItemInput(BaseModel):
index: int
inputs: dict[str, Any]
expected_output: str | None = None
context: list[str] | None = None
class EvaluationItemResult(BaseModel):
index: int
actual_output: str | None = None
metrics: list[EvaluationMetric] = Field(default_factory=list)
judgment: JudgmentResult | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
error: str | None = None
@property
def overall_score(self) -> float | None:
if not self.metrics:
return None
scores = [m.score for m in self.metrics]
return sum(scores) / len(scores)
class NodeInfo(BaseModel):
node_id: str
type: str
title: str
class DefaultMetric(BaseModel):
metric: str
node_info_list: list[NodeInfo]
class CustomizedMetricOutputField(BaseModel):
variable: str
value_type: str
class CustomizedMetrics(BaseModel):
evaluation_workflow_id: str
input_fields: dict[str, str]
output_fields: list[CustomizedMetricOutputField]
class EvaluationConfigData(BaseModel):
"""Structured data for saving evaluation configuration."""
evaluation_model: str = ""
evaluation_model_provider: str = ""
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
class EvaluationRunRequest(EvaluationConfigData):
"""Request body for starting an evaluation run."""
file_id: str
class EvaluationRunData(BaseModel):
"""Serializable data for Celery task."""
evaluation_run_id: str
tenant_id: str
target_type: str
target_id: str
evaluation_category: EvaluationCategory
evaluation_model_provider: str
evaluation_model: str
default_metrics: list[dict[str, Any]] = Field(default_factory=list)
customized_metrics: dict[str, Any] | None = None
judgment_config: JudgmentConfig | None = None
items: list[EvaluationItemInput]
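
`overall_score` is the unweighted mean of the per-metric scores; for example:

```python
r = EvaluationItemResult(
    index=0,
    metrics=[
        EvaluationMetric(name="faithfulness", score=0.8),
        EvaluationMetric(name="answer_relevancy", score=0.6),
    ],
)
assert r.overall_score == 0.7  # (0.8 + 0.6) / 2
assert EvaluationItemResult(index=1).overall_score is None  # no metrics yet
```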

View File

@@ -1,84 +0,0 @@
"""Judgment condition entities for evaluation metric assessment.
Typical usage:
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(metric_name="faithfulness", comparison_operator=">", value="0.8"),
JudgmentCondition(metric_name="answer_relevancy", comparison_operator="", value="0.7"),
],
)
"""
from collections.abc import Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field
from core.workflow.utils.condition.entities import SupportedComparisonOperator
class JudgmentCondition(BaseModel):
"""A single judgment condition that checks one metric value.
Attributes:
metric_name: The name of the evaluation metric to check
(must match an EvaluationMetric.name in the results).
comparison_operator: The comparison operator to apply
(reuses the same operator set as workflow condition branches).
value: The expected/threshold value to compare against.
For numeric operators (>, <, =, etc.), this should be a numeric string.
For string operators (contains, is, etc.), this should be a string.
For unary operators (empty, null, etc.), this can be None.
"""
metric_name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
class JudgmentConfig(BaseModel):
"""A group of judgment conditions combined with a logical operator.
Attributes:
logical_operator: How to combine condition results — "and" requires
all conditions to pass, "or" requires at least one.
conditions: The list of individual conditions to evaluate.
"""
logical_operator: Literal["and", "or"] = "and"
conditions: list[JudgmentCondition] = Field(default_factory=list)
class JudgmentConditionResult(BaseModel):
"""Result of evaluating a single judgment condition.
Attributes:
metric_name: Which metric was checked.
comparison_operator: The operator that was applied.
expected_value: The threshold/expected value from the condition config.
actual_value: The actual metric value that was evaluated.
passed: Whether this individual condition passed.
error: Error message if the condition evaluation failed.
"""
metric_name: str
comparison_operator: str
expected_value: Any = None
actual_value: Any = None
passed: bool = False
error: str | None = None
class JudgmentResult(BaseModel):
"""Overall result of evaluating all judgment conditions for one item.
Attributes:
passed: Whether the overall judgment passed (based on logical_operator).
logical_operator: The logical operator used to combine conditions.
condition_results: Detailed result for each individual condition.
"""
passed: bool = False
logical_operator: Literal["and", "or"] = "and"
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)

View File

@@ -1,69 +0,0 @@
import collections
import logging
from typing import Any
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
from core.evaluation.entities.evaluation_entity import EvaluationCategory
logger = logging.getLogger(__name__)
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
def __getitem__(self, framework: str) -> dict[str, Any]:
match framework:
case EvaluationFrameworkEnum.RAGAS:
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
return {
"config_class": RagasConfig,
"evaluator_class": RagasEvaluator,
}
case EvaluationFrameworkEnum.DEEPEVAL:
raise NotImplementedError("DeepEval adapter is not yet implemented.")
case EvaluationFrameworkEnum.CUSTOMIZED:
from core.evaluation.entities.config_entity import CustomizedEvaluatorConfig
from core.evaluation.frameworks.customized.customized_evaluator import CustomizedEvaluator
return {
"config_class": CustomizedEvaluatorConfig,
"evaluator_class": CustomizedEvaluator,
}
case _:
raise ValueError(f"Unknown evaluation framework: {framework}")
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
class EvaluationManager:
"""Factory for evaluation instances based on global configuration."""
@staticmethod
def get_evaluation_instance() -> BaseEvaluationInstance | None:
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
framework = dify_config.EVALUATION_FRAMEWORK
if not framework or framework == EvaluationFrameworkEnum.NONE:
return None
try:
config_map = evaluation_framework_config_map[framework]
evaluator_class = config_map["evaluator_class"]
config_class = config_map["config_class"]
config = config_class()
return evaluator_class(config)
except Exception:
logger.exception("Failed to create evaluation instance for framework: %s", framework)
return None
@staticmethod
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
"""Return supported metrics for the current framework and given category."""
instance = EvaluationManager.get_evaluation_instance()
if instance is None:
return []
return instance.get_supported_metrics(category)
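
Callers treat `None` as "evaluation disabled"; a minimal usage sketch:

```python
instance = EvaluationManager.get_evaluation_instance()
if instance is None:
    # EVALUATION_FRAMEWORK is unset, "none", or instantiation failed
    print("evaluation disabled")
else:
    print(instance.get_supported_metrics(EvaluationCategory.LLM))
```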

View File

@@ -1,267 +0,0 @@
"""Customized workflow-based evaluator.
Uses a published workflow as the evaluation strategy. The target's actual output,
expected output, original inputs, and context are passed as workflow inputs.
The workflow's output variables are treated as evaluation metrics.
The evaluation workflow_id is provided per evaluation run via
metrics_config["workflow_id"].
"""
import json
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import CustomizedEvaluatorConfig
from core.evaluation.entities.evaluation_entity import (
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
)
logger = logging.getLogger(__name__)
class CustomizedEvaluator(BaseEvaluationInstance):
"""Evaluate using a published workflow."""
def __init__(self, config: CustomizedEvaluatorConfig):
self.config = config
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate_with_workflow(items, metrics_config, tenant_id)
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
"""Metrics are dynamic and defined by the evaluation workflow outputs.
Return an empty list since available metrics depend on the specific
workflow chosen at runtime.
"""
return []
def _evaluate_with_workflow(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Run the evaluation workflow for each item and extract metric scores.
Args:
items: Evaluation items with inputs, expected_output, and context
(context typically contains the target's actual_output, merged
by the Runner's evaluate_metrics method).
metrics_config: Must contain "workflow_id" pointing to a published
WORKFLOW-type App.
tenant_id: Tenant scope for database and workflow execution.
Returns:
List of EvaluationItemResult with metrics extracted from workflow outputs.
Raises:
ValueError: If workflow_id is missing from metrics_config or the
workflow/app cannot be found.
"""
workflow_id = metrics_config.get("workflow_id")
if not workflow_id:
raise ValueError(
"metrics_config must contain 'workflow_id' for customized evaluator"
)
app, workflow, service_account = self._load_workflow_resources(workflow_id, tenant_id)
results: list[EvaluationItemResult] = []
for item in items:
try:
result = self._evaluate_single_item(app, workflow, service_account, item)
results.append(result)
except Exception:
logger.exception(
"Customized evaluator failed for item %d with workflow %s",
item.index,
workflow_id,
)
results.append(EvaluationItemResult(index=item.index))
return results
def _evaluate_single_item(
self,
app: Any,
workflow: Any,
service_account: Any,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Run the evaluation workflow for a single item.
Builds workflow inputs from the item data and executes the workflow
in non-streaming mode. Extracts metrics from the workflow's output
variables.
"""
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
workflow_inputs = self._build_workflow_inputs(item)
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=workflow,
user=service_account,
args={"inputs": workflow_inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
metrics = self._extract_metrics(response)
return EvaluationItemResult(
index=item.index,
metrics=metrics,
metadata={"workflow_response": self._safe_serialize(response)},
)
def _load_workflow_resources(
self, workflow_id: str, tenant_id: str
) -> tuple[Any, Any, Any]:
"""Load the evaluation workflow App, its published workflow, and a service account.
Args:
workflow_id: The App ID of the evaluation workflow.
tenant_id: Tenant scope.
Returns:
Tuple of (app, workflow, service_account).
Raises:
ValueError: If the app or published workflow cannot be found.
"""
from sqlalchemy.orm import Session
from core.evaluation.runners import get_service_account_for_app
from models.engine import db
from models.model import App
from services.workflow_service import WorkflowService
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
if not app:
raise ValueError(
f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}"
)
service_account = get_service_account_for_app(session, workflow_id)
workflow_service = WorkflowService()
published_workflow = workflow_service.get_published_workflow(app_model=app)
if not published_workflow:
raise ValueError(
f"No published workflow found for evaluation app {workflow_id}"
)
return app, published_workflow, service_account
@staticmethod
def _build_workflow_inputs(item: EvaluationItemInput) -> dict[str, Any]:
"""Build workflow input dict from an evaluation item.
Maps evaluation data to conventional workflow input variable names:
- actual_output: The target's actual output (from context[0] if available)
- expected_output: The expected/reference output
- inputs: The original evaluation inputs as JSON string
- context: All context strings joined by newlines
"""
workflow_inputs: dict[str, Any] = {}
# The actual_output is typically the first element in context
# (merged by the Runner's evaluate_metrics method)
if item.context:
workflow_inputs["actual_output"] = item.context[0] if len(item.context) == 1 else "\n\n".join(item.context)
if item.expected_output:
workflow_inputs["expected_output"] = item.expected_output
if item.inputs:
workflow_inputs["inputs"] = json.dumps(item.inputs, ensure_ascii=False)
if item.context and len(item.context) > 1:
workflow_inputs["context"] = "\n\n".join(item.context)
return workflow_inputs
@staticmethod
def _extract_metrics(response: Mapping[str, Any]) -> list[EvaluationMetric]:
"""Extract evaluation metrics from workflow output variables.
Each output variable is treated as a metric.
"""
metrics: list[EvaluationMetric] = []
data = response.get("data", {})
if not isinstance(data, Mapping):
logger.warning("Unexpected workflow response format: missing 'data' dict")
return metrics
outputs = data.get("outputs", {})
if not isinstance(outputs, Mapping):
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
return metrics
for key, value in outputs.items():
try:
score = float(value)
metrics.append(EvaluationMetric(name=key, score=score))
except (TypeError, ValueError):
metrics.append(
EvaluationMetric(name=key, score=0.0, details={"raw_value": value})
)
return metrics
@staticmethod
def _safe_serialize(response: Mapping[str, Any]) -> dict[str, Any]:
"""Safely serialize workflow response for metadata storage."""
try:
return dict(response)
except Exception:
return {"raw": str(response)}

View File

@@ -1,279 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import (
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Metric name mappings per category
LLM_METRICS = ["faithfulness", "answer_relevancy", "answer_correctness", "answer_similarity"]
RETRIEVAL_METRICS = ["context_precision", "context_recall", "context_relevancy"]
AGENT_METRICS = ["tool_call_accuracy", "answer_correctness"]
WORKFLOW_METRICS = ["faithfulness", "answer_correctness"]
class RagasEvaluator(BaseEvaluationInstance):
"""RAGAS framework adapter for evaluation."""
def __init__(self, config: RagasConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
return LLM_METRICS
case EvaluationCategory.RETRIEVAL:
return RETRIEVAL_METRICS
case EvaluationCategory.AGENT:
return AGENT_METRICS
case EvaluationCategory.WORKFLOW:
return WORKFLOW_METRICS
case _:
return []
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(
items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL
)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(
items, metrics_config, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW
)
def _evaluate(
self,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using RAGAS.
Uses the Dify model wrapper as the judge LLM. Falls back to a simple
LLM-as-judge pass if the RAGAS import fails.
"""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metrics_config.get("metrics", self.get_supported_metrics(category))
try:
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
except ImportError:
logger.warning("RAGAS not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_ragas(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using RAGAS library."""
from ragas import evaluate as ragas_evaluate
from ragas.dataset_schema import EvaluationDataset, SingleTurnSample
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import (
Faithfulness,
ResponseRelevancy,
)
# Build RAGAS dataset
samples = []
for item in items:
sample = SingleTurnSample(
user_input=self._inputs_to_query(item.inputs),
response=item.expected_output or "",
retrieved_contexts=item.context or [],
)
if item.expected_output:
sample.reference = item.expected_output
samples.append(sample)
dataset = EvaluationDataset(samples=samples)
# Build metric instances
ragas_metrics = self._build_ragas_metrics(requested_metrics)
if not ragas_metrics:
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
# Run RAGAS evaluation
try:
result = ragas_evaluate(
dataset=dataset,
metrics=ragas_metrics,
)
# Convert RAGAS results to our format
results = []
result_df = result.to_pandas()
for i, item in enumerate(items):
metrics = []
for metric_name in requested_metrics:
if metric_name in result_df.columns:
score = result_df.iloc[i][metric_name]
if score is not None and not (isinstance(score, float) and score != score): # NaN check
metrics.append(EvaluationMetric(name=metric_name, score=float(score)))
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
except Exception:
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when RAGAS is not available."""
results = []
for item in items:
metrics = []
query = self._inputs_to_query(item.inputs)
for metric_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, metric_name, query, item)
metrics.append(EvaluationMetric(name=metric_name, score=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", metric_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
query: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, query, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
def _build_judge_prompt(self, metric_name: str, query: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nQuery: {query}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append(
"\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else."
)
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
# Try to extract first number from response
import re
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
@staticmethod
def _inputs_to_query(inputs: dict[str, Any]) -> str:
"""Convert input dict to a query string."""
if "query" in inputs:
return str(inputs["query"])
if "question" in inputs:
return str(inputs["question"])
# Fallback: concatenate all input values
return " ".join(str(v) for v in inputs.values())
@staticmethod
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
"""Build RAGAS metric instances from metric names."""
try:
from ragas.metrics import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
ContextPrecision,
ContextRecall,
ContextRelevancy,
Faithfulness,
)
metric_map: dict[str, Any] = {
"faithfulness": Faithfulness,
"answer_relevancy": AnswerRelevancy,
"answer_correctness": AnswerCorrectness,
"answer_similarity": AnswerSimilarity,
"context_precision": ContextPrecision,
"context_recall": ContextRecall,
"context_relevancy": ContextRelevancy,
}
metrics = []
for name in requested_metrics:
metric_class = metric_map.get(name)
if metric_class:
metrics.append(metric_class())
else:
logger.warning("Unknown RAGAS metric: %s", name)
return metrics
except ImportError:
logger.warning("RAGAS metrics not available")
return []
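
The score parser is deliberately forgiving; a few worked cases:

```python
assert RagasEvaluator._parse_score("0.85") == 0.85
assert RagasEvaluator._parse_score("Score: 0.85 (high)") == 0.85  # first number extracted
assert RagasEvaluator._parse_score("1.7") == 1.0                  # clamped into [0.0, 1.0]
assert RagasEvaluator._parse_score("no score here") == 0.0        # fallback
```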

View File

@@ -1,48 +0,0 @@
import logging
from typing import Any
logger = logging.getLogger(__name__)
class DifyModelWrapper:
"""Wraps Dify's model invocation interface for use by RAGAS as an LLM judge.
RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.).
This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use.
"""
def __init__(self, model_provider: str, model_name: str, tenant_id: str):
self.model_provider = model_provider
self.model_name = model_name
self.tenant_id = tenant_id
def _get_model_instance(self) -> Any:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.model_provider,
model_type=ModelType.LLM,
model=self.model_name,
)
return model_instance
def invoke(self, prompt: str) -> str:
"""Invoke the model with a text prompt and return the text response."""
from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
model_instance = self._get_model_instance()
result = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
UserPromptMessage(content=prompt),
],
model_parameters={"temperature": 0.0, "max_tokens": 2048},
stream=False,
)
return result.message.content
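
A usage sketch (the provider and model names are placeholders, and `tenant_id` is assumed to be in scope):

```python
wrapper = DifyModelWrapper(model_provider="openai", model_name="gpt-4o", tenant_id=tenant_id)
raw = wrapper.invoke("Rate the faithfulness of the answer on a 0.0-1.0 scale. Respond with only the number.")
score = float(raw.strip())  # the judge prompt asks for a bare number
```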

View File

@@ -1,144 +0,0 @@
"""Judgment condition processor for evaluation metrics.
Evaluates pass/fail judgment conditions against evaluation metric values.
Reuses the core comparison engine from the workflow condition system
(core.workflow.utils.condition.processor._evaluate_condition) to ensure
consistent operator semantics across the platform.
"""
import logging
from typing import Any
from core.evaluation.entities.judgment_entity import (
JudgmentCondition,
JudgmentConditionResult,
JudgmentConfig,
JudgmentResult,
)
from core.workflow.utils.condition.processor import _evaluate_condition
logger = logging.getLogger(__name__)
class JudgmentProcessor:
@staticmethod
def evaluate(
metric_values: dict[str, Any],
config: JudgmentConfig,
) -> JudgmentResult:
"""Evaluate all judgment conditions against the given metric values.
Args:
metric_values: Mapping of metric name to its value
(e.g. {"faithfulness": 0.85, "status": "success"}).
config: The judgment configuration with logical_operator and conditions.
Returns:
JudgmentResult with overall pass/fail and per-condition details.
"""
if not config.conditions:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=[],
)
condition_results: list[JudgmentConditionResult] = []
for condition in config.conditions:
result = JudgmentProcessor._evaluate_single_condition(
metric_values, condition
)
condition_results.append(result)
if config.logical_operator == "and" and not result.passed:
return JudgmentResult(
passed=False,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "or" and result.passed:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "and":
final_passed = all(r.passed for r in condition_results)
else:
final_passed = any(r.passed for r in condition_results)
return JudgmentResult(
passed=final_passed,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
@staticmethod
def _evaluate_single_condition(
metric_values: dict[str, Any],
condition: JudgmentCondition,
) -> JudgmentConditionResult:
"""Evaluate a single judgment condition against the metric values.
Looks up the metric by name, then delegates to the workflow condition
engine for the actual comparison.
Args:
metric_values: Mapping of metric name to its value.
condition: The condition to evaluate.
Returns:
JudgmentConditionResult with pass/fail and details.
"""
metric_name = condition.metric_name
actual_value = metric_values.get(metric_name)
# Handle metric not found
if actual_value is None and condition.comparison_operator not in (
"null",
"not null",
"empty",
"not empty",
"exists",
"not exists",
):
return JudgmentConditionResult(
metric_name=metric_name,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"Metric '{metric_name}' not found in evaluation results",
)
try:
passed = _evaluate_condition(
operator=condition.comparison_operator,
value=actual_value,
expected=condition.value,
)
return JudgmentConditionResult(
metric_name=metric_name,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=actual_value,
passed=passed,
)
except Exception as e:
logger.warning(
"Judgment condition evaluation failed for metric '%s': %s",
metric_name,
str(e),
)
return JudgmentConditionResult(
metric_name=metric_name,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=actual_value,
passed=False,
error=str(e),
)
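
Tying the pieces together, a worked example mirroring the usage shown in the judgment entities' docstring (operator strings assumed from that docstring):

```python
config = JudgmentConfig(
    logical_operator="and",
    conditions=[
        JudgmentCondition(metric_name="faithfulness", comparison_operator=">", value="0.8"),
        JudgmentCondition(metric_name="answer_relevancy", comparison_operator="≥", value="0.7"),
    ],
)
result = JudgmentProcessor.evaluate(
    {"faithfulness": 0.85, "answer_relevancy": 0.72},
    config,
)
assert result.passed
assert all(c.passed for c in result.condition_results)
```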

View File

@@ -1,32 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, App, TenantAccountJoin
def get_service_account_for_app(session: Session, app_id: str) -> Account:
"""Get the creator account for an app with tenant context set up.
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
"""
app = session.scalar(select(App).where(App.id == app_id))
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account = session.scalar(select(Account).where(Account.id == app.created_by))
if not account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin)
.filter_by(account_id=account.id, current=True)
.first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account

View File

@@ -1,152 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any, Union
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from models.model import App, AppMode
logger = logging.getLogger(__name__)
class AgentEvaluationRunner(BaseEvaluationRunner):
"""Runner for agent evaluation: executes agent-type App, collects tool calls and final output."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute agent app and collect response with tool call information."""
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app
app = self.session.query(App).filter_by(id=target_id).first()
if not app:
raise ValueError(f"App {target_id} not found")
service_account = get_service_account_for_app(self.session, target_id)
query = self._extract_query(item.inputs)
args: dict[str, Any] = {
"inputs": item.inputs,
"query": query,
}
generator = AgentChatAppGenerator()
# Agent chat requires streaming - collect full response
response_generator = generator.generate(
app_model=app,
user=service_account,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=True,
)
# Consume the stream to get the full response
actual_output, tool_calls = self._consume_agent_stream(response_generator)
return EvaluationItemResult(
index=item.index,
actual_output=actual_output,
metadata={"tool_calls": tool_calls},
)
def evaluate_metrics(
self,
items: list[EvaluationItemInput],
results: list[EvaluationItemResult],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute agent evaluation metrics."""
result_by_index = {r.index: r for r in results}
merged_items = []
for item in items:
result = result_by_index.get(item.index)
context = []
if result and result.actual_output:
context.append(result.actual_output)
merged_items.append(
EvaluationItemInput(
index=item.index,
inputs=item.inputs,
expected_output=item.expected_output,
context=context + (item.context or []),
)
)
evaluated = self.evaluation_instance.evaluate_agent(
merged_items, metrics_config, model_provider, model_name, tenant_id
)
# Merge metrics back preserving metadata
eval_by_index = {r.index: r for r in evaluated}
final_results = []
for result in results:
if result.index in eval_by_index:
eval_result = eval_by_index[result.index]
final_results.append(
EvaluationItemResult(
index=result.index,
actual_output=result.actual_output,
metrics=eval_result.metrics,
metadata=result.metadata,
error=result.error,
)
)
else:
final_results.append(result)
return final_results
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
@staticmethod
def _consume_agent_stream(response_generator: Any) -> tuple[str, list[dict]]:
"""Consume agent streaming response and extract final answer + tool calls."""
answer_parts: list[str] = []
tool_calls: list[dict] = []
try:
for chunk in response_generator:
if isinstance(chunk, Mapping):
event = chunk.get("event")
if event == "agent_thought":
thought = chunk.get("thought", "")
if thought:
answer_parts.append(thought)
tool = chunk.get("tool")
if tool:
tool_calls.append({
"tool": tool,
"tool_input": chunk.get("tool_input", ""),
})
elif event == "message":
answer = chunk.get("answer", "")
if answer:
answer_parts.append(answer)
elif isinstance(chunk, str):
answer_parts.append(chunk)
except Exception:
logger.exception("Error consuming agent stream")
return "".join(answer_parts), tool_calls

View File

@@ -1,171 +0,0 @@
import json
import logging
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.entities.judgment_entity import JudgmentConfig
from core.evaluation.judgment.processor import JudgmentProcessor
from libs.datetime_utils import naive_utc_now
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
logger = logging.getLogger(__name__)
class BaseEvaluationRunner(ABC):
"""Abstract base class for evaluation runners.
Runners are responsible for executing the target (App/Snippet/Retrieval)
to collect actual outputs, then delegating to the evaluation instance
for metric computation, and optionally applying judgment conditions.
Execution phases:
1. execute_target — run the target and collect actual outputs
2. evaluate_metrics — compute evaluation metrics via the framework
3. apply_judgment — evaluate pass/fail judgment conditions on metrics
4. persist — save results to the database
"""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
self.evaluation_instance = evaluation_instance
self.session = session
@abstractmethod
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute the evaluation target for a single item and return the result with actual_output populated."""
...
@abstractmethod
def evaluate_metrics(
self,
items: list[EvaluationItemInput],
results: list[EvaluationItemResult],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics on the collected results."""
...
def run(
self,
evaluation_run_id: str,
tenant_id: str,
target_id: str,
target_type: str,
items: list[EvaluationItemInput],
metrics_config: dict,
model_provider: str,
model_name: str,
judgment_config: JudgmentConfig | None = None,
) -> list[EvaluationItemResult]:
"""Orchestrate target execution + metric evaluation + judgment for all items.
"""
evaluation_run = self.session.query(EvaluationRun).filter_by(id=evaluation_run_id).first()
if not evaluation_run:
raise ValueError(f"EvaluationRun {evaluation_run_id} not found")
# Update status to running
evaluation_run.status = EvaluationRunStatus.RUNNING
evaluation_run.started_at = naive_utc_now()
self.session.commit()
results: list[EvaluationItemResult] = []
# Phase 1: Execute target for each item
for item in items:
try:
result = self.execute_target(tenant_id, target_id, target_type, item)
results.append(result)
evaluation_run.completed_items += 1
except Exception as e:
logger.exception("Failed to execute target for item %d", item.index)
results.append(
EvaluationItemResult(
index=item.index,
error=str(e),
)
)
evaluation_run.failed_items += 1
self.session.commit()
# Phase 2: Compute metrics on successful results
successful_items = [item for item, result in zip(items, results) if result.error is None]
successful_results = [r for r in results if r.error is None]
if successful_items and successful_results:
try:
evaluated_results = self.evaluate_metrics(
successful_items, successful_results, metrics_config, model_provider, model_name, tenant_id
)
# Merge evaluated metrics back into results
evaluated_by_index = {r.index: r for r in evaluated_results}
for i, result in enumerate(results):
if result.index in evaluated_by_index:
results[i] = evaluated_by_index[result.index]
except Exception:
logger.exception("Failed to compute metrics for evaluation run %s", evaluation_run_id)
# Phase 3: Apply judgment conditions on metrics
if judgment_config and judgment_config.conditions:
results = self._apply_judgment(results, judgment_config)
# Phase 4: Persist individual items
for result in results:
item_input = next((item for item in items if item.index == result.index), None)
run_item = EvaluationRunItem(
evaluation_run_id=evaluation_run_id,
item_index=result.index,
inputs=json.dumps(item_input.inputs) if item_input else None,
expected_output=item_input.expected_output if item_input else None,
context=json.dumps(item_input.context) if item_input and item_input.context else None,
actual_output=result.actual_output,
metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None,
judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None,
metadata_json=json.dumps(result.metadata) if result.metadata else None,
error=result.error,
overall_score=result.overall_score,
)
self.session.add(run_item)
self.session.commit()
return results
@staticmethod
def _apply_judgment(
results: list[EvaluationItemResult],
judgment_config: JudgmentConfig,
) -> list[EvaluationItemResult]:
"""Apply judgment conditions to each result's metrics.
Builds a metric_name → score mapping from each result's metrics,
then delegates to JudgmentProcessor for condition evaluation.
Results with errors are skipped.
"""
judged_results: list[EvaluationItemResult] = []
for result in results:
if result.error is not None or not result.metrics:
judged_results.append(result)
continue
metric_values = {m.name: m.score for m in result.metrics}
judgment_result = JudgmentProcessor.evaluate(metric_values, judgment_config)
judged_results.append(
result.model_copy(update={"judgment": judgment_result})
)
return judged_results
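A minimal concrete runner, to make the contract above tangible. The subclass and its echo behavior are illustrative only, not part of the codebase:

    class EchoEvaluationRunner(BaseEvaluationRunner):
        """Toy runner: echoes the first input value back as the actual output."""

        def execute_target(self, tenant_id, target_id, target_type, item):
            values = list(item.inputs.values())
            return EvaluationItemResult(
                index=item.index,
                actual_output=str(values[0]) if values else "",
            )

        def evaluate_metrics(self, items, results, metrics_config, model_provider, model_name, tenant_id):
            # Real runners merge actual outputs into the items first; see the LLM runner below.
            return self.evaluation_instance.evaluate_llm(
                items, metrics_config, model_provider, model_name, tenant_id
            )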

View File

@@ -1,152 +0,0 @@
import logging
from typing import Any, Mapping, Union
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from models.model import App, AppMode
logger = logging.getLogger(__name__)
class LLMEvaluationRunner(BaseEvaluationRunner):
"""Runner for LLM evaluation: executes App to get responses, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute the App/Snippet with the given inputs and collect the response."""
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.evaluation.runners import get_service_account_for_app
from core.app.entities.app_invoke_entities import InvokeFrom
from services.workflow_service import WorkflowService
app = self.session.query(App).filter_by(id=target_id).first()
if not app:
raise ValueError(f"App {target_id} not found")
# Get a service account for invocation
service_account = get_service_account_for_app(self.session, target_id)
app_mode = AppMode.value_of(app.mode)
# Build args from evaluation item inputs
args: dict[str, Any] = {
"inputs": item.inputs,
}
# For completion/chat modes, first text input becomes query
if app_mode in (AppMode.COMPLETION, AppMode.CHAT):
query = self._extract_query(item.inputs)
args["query"] = query
if app_mode in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
raise ValueError(f"No published workflow found for app {target_id}")
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=workflow,
user=service_account,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
elif app_mode == AppMode.COMPLETION:
generator = CompletionAppGenerator()
response = generator.generate(
app_model=app,
user=service_account,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
else:
raise ValueError(f"Unsupported app mode for LLM evaluation: {app_mode}")
actual_output = self._extract_output(response)
return EvaluationItemResult(
index=item.index,
actual_output=actual_output,
)
def evaluate_metrics(
self,
items: list[EvaluationItemInput],
results: list[EvaluationItemResult],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Use the evaluation instance to compute LLM metrics."""
# Merge actual_output into items for evaluation
merged_items = self._merge_results_into_items(items, results)
return self.evaluation_instance.evaluate_llm(
merged_items, metrics_config, model_provider, model_name, tenant_id
)
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
"""Extract query from inputs."""
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
@staticmethod
def _extract_output(response: Union[Mapping[str, Any], Any]) -> str:
"""Extract text output from app response."""
if isinstance(response, Mapping):
# Workflow response
if "data" in response and isinstance(response["data"], Mapping):
outputs = response["data"].get("outputs", {})
if isinstance(outputs, Mapping):
values = list(outputs.values())
return str(values[0]) if values else ""
return str(outputs)
# Completion response
if "answer" in response:
return str(response["answer"])
if "text" in response:
return str(response["text"])
return str(response)
@staticmethod
def _merge_results_into_items(
items: list[EvaluationItemInput],
results: list[EvaluationItemResult],
) -> list[EvaluationItemInput]:
"""Create new items with actual_output set as expected_output context for metrics."""
result_by_index = {r.index: r for r in results}
merged = []
for item in items:
result = result_by_index.get(item.index)
if result and result.actual_output:
merged.append(
EvaluationItemInput(
index=item.index,
inputs=item.inputs,
expected_output=item.expected_output,
context=[result.actual_output] + (item.context or []),
)
)
else:
merged.append(item)
return merged
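Two hypothetical responses showing the extraction fallbacks above:

    workflow_resp = {"data": {"outputs": {"text": "hi"}}}
    completion_resp = {"answer": "hello"}
    assert LLMEvaluationRunner._extract_output(workflow_resp) == "hi"
    assert LLMEvaluationRunner._extract_output(completion_resp) == "hello"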

View File

@@ -1,111 +0,0 @@
import logging
from typing import Any
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
logger = logging.getLogger(__name__)
class RetrievalEvaluationRunner(BaseEvaluationRunner):
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute retrieval using DatasetRetrieval and collect context documents."""
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
query = self._extract_query(item.inputs)
dataset_retrieval = DatasetRetrieval()
# Use knowledge_retrieval for structured results
try:
from core.rag.retrieval.dataset_retrieval import KnowledgeRetrievalRequest
request = KnowledgeRetrievalRequest(
query=query,
app_id=target_id,
tenant_id=tenant_id,
)
sources = dataset_retrieval.knowledge_retrieval(request)
retrieved_contexts = [source.content for source in sources if source.content]
except (ImportError, AttributeError):
logger.warning("KnowledgeRetrievalRequest not available, using simple retrieval")
retrieved_contexts = []
return EvaluationItemResult(
index=item.index,
actual_output="\n\n".join(retrieved_contexts) if retrieved_contexts else "",
metadata={"retrieved_contexts": retrieved_contexts},
)
def evaluate_metrics(
self,
items: list[EvaluationItemInput],
results: list[EvaluationItemResult],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute retrieval evaluation metrics."""
# Merge retrieved contexts into items
result_by_index = {r.index: r for r in results}
merged_items = []
for item in items:
result = result_by_index.get(item.index)
contexts = result.metadata.get("retrieved_contexts", []) if result else []
merged_items.append(
EvaluationItemInput(
index=item.index,
inputs=item.inputs,
expected_output=item.expected_output,
context=contexts,
)
)
evaluated = self.evaluation_instance.evaluate_retrieval(
merged_items, metrics_config, model_provider, model_name, tenant_id
)
# Merge metrics back into original results (preserve actual_output and metadata)
eval_by_index = {r.index: r for r in evaluated}
final_results = []
for result in results:
if result.index in eval_by_index:
eval_result = eval_by_index[result.index]
final_results.append(
EvaluationItemResult(
index=result.index,
actual_output=result.actual_output,
metrics=eval_result.metrics,
metadata=result.metadata,
error=result.error,
)
)
else:
final_results.append(result)
return final_results
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
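The shape execute_target produces, with illustrative values. evaluate_metrics reads metadata["retrieved_contexts"] back out of each result to build the context passed to evaluate_retrieval:

    result = EvaluationItemResult(
        index=0,
        actual_output="chunk A\n\nchunk B",
        metadata={"retrieved_contexts": ["chunk A", "chunk B"]},
    )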

View File

@@ -1,133 +0,0 @@
import logging
from typing import Any, Mapping
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from models.model import App
logger = logging.getLogger(__name__)
class WorkflowEvaluationRunner(BaseEvaluationRunner):
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute workflow and collect outputs."""
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.evaluation.runners import get_service_account_for_app
from core.app.entities.app_invoke_entities import InvokeFrom
from services.workflow_service import WorkflowService
app = self.session.query(App).filter_by(id=target_id).first()
if not app:
raise ValueError(f"App {target_id} not found")
service_account = get_service_account_for_app(self.session, target_id)
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
raise ValueError(f"No published workflow found for app {target_id}")
args: dict[str, Any] = {"inputs": item.inputs}
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=workflow,
user=service_account,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
actual_output = self._extract_output(response)
node_executions = self._extract_node_executions(response)
return EvaluationItemResult(
index=item.index,
actual_output=actual_output,
metadata={"node_executions": node_executions},
)
def evaluate_metrics(
self,
items: list[EvaluationItemInput],
results: list[EvaluationItemResult],
metrics_config: dict,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute workflow evaluation metrics (end-to-end)."""
result_by_index = {r.index: r for r in results}
merged_items = []
for item in items:
result = result_by_index.get(item.index)
context = []
if result and result.actual_output:
context.append(result.actual_output)
merged_items.append(
EvaluationItemInput(
index=item.index,
inputs=item.inputs,
expected_output=item.expected_output,
context=context + (item.context or []),
)
)
evaluated = self.evaluation_instance.evaluate_workflow(
merged_items, metrics_config, model_provider, model_name, tenant_id
)
# Merge metrics back preserving metadata
eval_by_index = {r.index: r for r in evaluated}
final_results = []
for result in results:
if result.index in eval_by_index:
eval_result = eval_by_index[result.index]
final_results.append(
EvaluationItemResult(
index=result.index,
actual_output=result.actual_output,
metrics=eval_result.metrics,
metadata=result.metadata,
error=result.error,
)
)
else:
final_results.append(result)
return final_results
@staticmethod
def _extract_output(response: Mapping[str, Any]) -> str:
"""Extract text output from workflow response."""
if "data" in response and isinstance(response["data"], Mapping):
outputs = response["data"].get("outputs", {})
if isinstance(outputs, Mapping):
values = list(outputs.values())
return str(values[0]) if values else ""
return str(outputs)
return str(response)
@staticmethod
def _extract_node_executions(response: Mapping[str, Any]) -> list[dict]:
"""Extract node execution trace from workflow response."""
data = response.get("data", {})
if isinstance(data, Mapping):
return data.get("node_executions", [])
return []

View File

@@ -27,10 +27,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage

View File

@@ -4,10 +4,10 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

View File

@@ -14,7 +14,7 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.workflow.file import file_manager
from extensions.ext_database import db
from factories import file_factory

View File

@@ -1,3 +0,0 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

View File

@@ -1,18 +0,0 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.model_runtime.entities import PromptMessage
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages."""
def get_history_prompt_messages(
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...

View File

@@ -14,9 +14,13 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool

View File

@@ -10,7 +10,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.prompt_transform import PromptTransform
class AgentHistoryPromptTransform(PromptTransform):

View File

@@ -5,7 +5,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:

View File

@@ -15,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from models.model import AppMode

View File

@@ -1,6 +1,6 @@
from sqlalchemy import select
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message

View File

@@ -10,7 +10,7 @@ from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
from core.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil:

View File

@@ -2,7 +2,6 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@@ -30,6 +29,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@@ -63,14 +63,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@@ -124,14 +126,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,

View File

@@ -8,7 +8,6 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@@ -36,6 +35,7 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Deduct quota for summary generation (same as workflow nodes)
try:
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@@ -29,12 +29,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService

View File

@@ -2,14 +2,14 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@@ -6,9 +6,9 @@ identity:
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
en_US: Web Scraper tool kit is used to scrape web
en_US: Web Scrapper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scraper tool kit is used to scrape web
pt_BR: Web Scrapper tool kit is used to scrape web
icon: icon.svg
tags:
- productivity

View File

@@ -1,11 +1,11 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.variables.input_entities import VariableEntity
class WorkflowToolConfigurationUtils:

View File

@@ -5,6 +5,7 @@ from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
@@ -22,7 +23,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData

View File

@@ -588,7 +588,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -643,6 +642,5 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -4,7 +4,11 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector

View File

@@ -1,19 +1,26 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessageRole
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError
@@ -41,51 +48,88 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def convert_history_messages_to_text(
    *,
    history_messages: Sequence[PromptMessage],
    human_prefix: str,
    ai_prefix: str,
) -> str:
    string_messages: list[str] = []
    for message in history_messages:
        if message.role == PromptMessageRole.USER:
            role = human_prefix
        elif message.role == PromptMessageRole.ASSISTANT:
            role = ai_prefix
        else:
            continue
        if isinstance(message.content, list):
            content_parts = []
            for content in message.content:
                if isinstance(content, TextPromptMessageContent):
                    content_parts.append(content.data)
                elif isinstance(content, ImagePromptMessageContent):
                    content_parts.append("[image]")
            inner_msg = "\n".join(content_parts)
            string_messages.append(f"{role}: {inner_msg}")
        else:
            string_messages.append(f"{role}: {message.content}")
    return "\n".join(string_messages)

def fetch_memory_text(
    *,
    memory: PromptMessageMemory,
    max_token_limit: int,
    message_limit: int | None = None,
    human_prefix: str = "Human",
    ai_prefix: str = "Assistant",
) -> str:
    history_messages = memory.get_history_prompt_messages(
        max_token_limit=max_token_limit,
        message_limit=message_limit,
    )
    return convert_history_messages_to_text(
        history_messages=history_messages,
        human_prefix=human_prefix,
        ai_prefix=ai_prefix,
    )

def fetch_memory(
    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
    if not node_data_memory:
        return None
    # get conversation id
    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_variable, StringSegment):
        return None
    conversation_id = conversation_id_variable.value
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
        conversation = session.scalar(stmt)
    if not conversation:
        return None
    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
    return memory

def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return
    system_configuration = provider_configuration.system_configuration
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                return
            break
    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_instance.model_name)
        else:
            used_quota = 1
    if used_quota is not None and system_configuration.current_quota_type is not None:
        if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
            from services.credit_pool_service import CreditPoolService

            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
            )
        elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
            from services.credit_pool_service import CreditPoolService

            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
                pool_type="paid",
            )
        else:
            with Session(db.engine) as session:
                stmt = (
                    update(Provider)
                    .where(
                        Provider.tenant_id == tenant_id,
                        # TODO: Use provider name with prefix after the data migration.
                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
                        Provider.provider_type == ProviderType.SYSTEM.value,
                        Provider.quota_type == system_configuration.current_quota_type.value,
                        Provider.quota_limit > Provider.quota_used,
                    )
                    .values(
                        quota_used=Provider.quota_used + used_quota,
                        last_used=naive_utc_now(),
                    )
                )
                session.execute(stmt)
                session.commit()
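The unit-to-amount rule inside deduct_llm_quota, pulled out as a standalone sketch. The helper name is illustrative, and treating the remaining case as QuotaUnit.TIMES is an assumption:

    def used_quota_for(quota_unit: QuotaUnit, usage: LLMUsage, model_name: str) -> int:
        if quota_unit == QuotaUnit.TOKENS:
            return usage.total_tokens
        if quota_unit == QuotaUnit.CREDITS:
            return dify_config.get_model_credits(model_name)
        # Any other unit (assumed to be QuotaUnit.TIMES) counts one call as one unit.
        return 1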

View File

@@ -37,10 +37,9 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
@@ -63,7 +62,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
@@ -279,6 +278,8 @@ class LLMNode(Node[LLMNodeData]):
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
@@ -1233,10 +1234,6 @@ class LLMNode(Node[LLMNodeData]):
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
@@ -1339,16 +1336,48 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = llm_utils.fetch_memory_text(
memory=memory,
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage
class CredentialsProvider(Protocol):
@@ -19,3 +21,13 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages for LLM nodes."""
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...
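Since this is a structural Protocol, any object with a matching method satisfies it. A minimal in-memory implementation might look like the following (illustrative only, not part of the change):

    class StaticPromptMessageMemory:
        """Serves a fixed message list; real implementations trim to the token budget."""

        def __init__(self, messages: Sequence[PromptMessage]) -> None:
            self._messages = list(messages)

        def get_history_prompt_messages(
            self, max_token_limit: int = 2000, message_limit: int | None = None
        ) -> Sequence[PromptMessage]:
            # max_token_limit is ignored here; only the message window is applied.
            return self._messages if message_limit is None else self._messages[-message_limit:]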

View File

@@ -413,7 +413,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -455,6 +454,5 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -7,7 +7,7 @@ from pydantic import (
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType

View File

@@ -5,6 +5,7 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -17,18 +18,13 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@@ -101,7 +97,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_memory: PromptMessageMemory | None
def __init__(
self,
@@ -113,7 +108,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
) -> None:
super().__init__(
id=id,
@@ -124,7 +118,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -170,7 +163,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
memory = self._memory
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@@ -309,6 +308,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, tool_call
def _generate_function_call_prompt(
@@ -317,7 +319,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -405,7 +407,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -443,7 +445,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -468,8 +470,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
context="",
memory_config=node_data.memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
memory=memory,
model_instance=model_instance,
image_detail_config=vision_detail,
)
@@ -482,7 +483,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -714,7 +715,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@@ -723,8 +724,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@@ -741,7 +742,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@@ -750,8 +751,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@@ -827,10 +828,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig

View File

@@ -3,12 +3,12 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
NodeExecutionType,
@@ -56,7 +56,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@@ -68,7 +67,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -83,7 +81,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -106,7 +103,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
memory = self._memory
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
@@ -237,10 +240,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_usage=usage,
)
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@@ -324,7 +323,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@@ -337,8 +336,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
input_text = query
memory_str = ""
if memory:
memory_str = llm_utils.fetch_memory_text(
memory=memory,
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)

View File

@@ -2,8 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.input_entities import VariableEntity
class StartNodeData(BaseNodeData):

View File

@@ -2,12 +2,12 @@ from typing import Any
from jsonschema import Draft7Validator, ValidationError
from core.app.app_config.entities import VariableEntityType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.variables.input_entities import VariableEntityType
class StartNode(Node[StartNodeData]):

View File

@@ -1,4 +1,3 @@
from .input_entities import VariableEntity, VariableEntityType
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
@@ -65,6 +64,4 @@ __all__ = [
"StringVariable",
"Variable",
"VariableBase",
"VariableEntity",
"VariableEntityType",
]

View File

@@ -1,62 +0,0 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Any
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.workflow.file import FileTransferMethod, FileType
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Shared variable entity used by workflow runtime and app configuration.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, value: Any) -> str:
return value or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, value: Any) -> Sequence[str]:
return value or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as error:
raise ValueError(f"Invalid JSON schema: {error.message}")
return schema
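For context, how such an entity is constructed; this file is deleted here, and per the import changes above the class continues to live as core.app.app_config.entities.VariableEntity:

    entity = VariableEntity(
        variable="user_name",
        label="User Name",
        type=VariableEntityType.TEXT_INPUT,
        required=True,
        max_length=48,
    )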

View File

@@ -6,7 +6,6 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
@@ -107,7 +106,6 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@@ -1,45 +0,0 @@
from flask_restx import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"created_at": TimestampField,
"updated_at": TimestampField,
}
# Full snippet fields (includes creator info and graph data)
snippet_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"graph": fields.Raw(attribute="graph_dict"),
"input_fields": fields.Raw(attribute="input_fields_list"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
}
# Pagination response fields
snippet_pagination_fields = {
"data": fields.List(fields.Nested(snippet_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
}
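
Field dicts like the deleted ones above are consumed through flask_restx marshaling; a self-contained sketch of that usage with a made-up record (only two fields kept for brevity):

from flask_restx import fields, marshal

list_fields = {"id": fields.String, "name": fields.String}
record = {"id": "123", "name": "My Snippet", "internal_flag": True}
print(marshal(record, list_fields))  # keeps only "id" and "name"; extra keys are dropped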

View File

@@ -1,48 +0,0 @@
"""Helpers for producing concise pyrefly diagnostics for CI diff output."""
from __future__ import annotations
import sys
_DIAGNOSTIC_PREFIXES = ("ERROR ", "WARNING ")
_LOCATION_PREFIX = "-->"
def extract_diagnostics(raw_output: str) -> str:
"""Extract stable diagnostic lines from pyrefly output.
The full pyrefly output includes code excerpts and carets, which create noisy
diffs. This helper keeps only:
- diagnostic headline lines (``ERROR ...`` / ``WARNING ...``)
- the following location line (``--> path:line:column``), when present
"""
lines = raw_output.splitlines()
diagnostics: list[str] = []
for index, line in enumerate(lines):
if line.startswith(_DIAGNOSTIC_PREFIXES):
diagnostics.append(line.rstrip())
next_index = index + 1
if next_index < len(lines):
next_line = lines[next_index]
if next_line.lstrip().startswith(_LOCATION_PREFIX):
diagnostics.append(next_line.rstrip())
if not diagnostics:
return ""
return "\n".join(diagnostics) + "\n"
def main() -> int:
"""Read pyrefly output from stdin and print normalized diagnostics."""
raw_output = sys.stdin.read()
sys.stdout.write(extract_diagnostics(raw_output))
return 0
if __name__ == "__main__":
sys.exit(main())
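
Behavior of the deleted extractor on fabricated pyrefly-style output (real pyrefly formatting may differ slightly):

sample = (
    "ERROR Unexpected keyword argument [bad-keyword-argument]\n"
    " --> api/services/foo.py:12:5\n"
    "   12 | do_thing(bar=1)\n"
    "INFO something else\n"
)
print(extract_diagnostics(sample), end="")
# ERROR Unexpected keyword argument [bad-keyword-argument]
#  --> api/services/foo.py:12:5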

View File

@@ -1,83 +0,0 @@
"""add_customized_snippets_table
Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
if _is_pg(conn):
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("graph", sa.Text(), nullable=True),
sa.Column("input_fields", sa.Text(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
else:
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", models.types.LongText(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
sa.Column("graph", models.types.LongText(), nullable=True),
sa.Column("input_fields", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
def downgrade():
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.drop_index("customized_snippet_tenant_idx")
op.drop_table("customized_snippets")
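
The _is_pg branch above is the usual way to keep one migration compatible with both PostgreSQL and other backends; a trimmed, illustrative sketch of the same pattern (table and column names are invented):

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

def upgrade():
    is_pg = op.get_bind().dialect.name == "postgresql"
    json_type = postgresql.JSONB() if is_pg else sa.JSON()
    op.create_table(
        "example_table",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("payload", json_type, nullable=True),
    )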

View File

@@ -1,113 +0,0 @@
"""add_evaluation_tables
Revision ID: a1b2c3d4e5f6
Revises: 1c05e80d2380
Create Date: 2026-03-03 00:01:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "1c05e80d2380"
branch_labels = None
depends_on = None
def upgrade():
# evaluation_configurations
op.create_table(
"evaluation_configurations",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True),
sa.Column("evaluation_model", sa.String(length=255), nullable=True),
sa.Column("metrics_config", models.types.LongText(), nullable=True),
sa.Column("judgement_conditions", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.create_index(
"evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
# evaluation_runs
op.create_table(
"evaluation_runs",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False),
sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")),
sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True),
sa.Column("result_file_id", models.types.StringUUID(), nullable=True),
sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("metrics_summary", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("celery_task_id", sa.String(length=255), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("started_at", sa.DateTime(), nullable=True),
sa.Column("completed_at", sa.DateTime(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
)
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.create_index(
"evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False)
# evaluation_run_items
op.create_table(
"evaluation_run_items",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False),
sa.Column("item_index", sa.Integer(), nullable=False),
sa.Column("inputs", models.types.LongText(), nullable=True),
sa.Column("expected_output", models.types.LongText(), nullable=True),
sa.Column("context", models.types.LongText(), nullable=True),
sa.Column("actual_output", models.types.LongText(), nullable=True),
sa.Column("metrics", models.types.LongText(), nullable=True),
sa.Column("metadata_json", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("overall_score", sa.Float(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
)
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False)
batch_op.create_index(
"evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False
)
def downgrade():
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_item_index_idx")
batch_op.drop_index("evaluation_run_item_run_idx")
op.drop_table("evaluation_run_items")
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_status_idx")
batch_op.drop_index("evaluation_run_target_idx")
op.drop_table("evaluation_runs")
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.drop_index("evaluation_configuration_target_idx")
op.drop_table("evaluation_configurations")
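
The composite (tenant_id, status) index above matches the run-listing query shape; roughly, using the EvaluationRun model that appears later in this diff (session and tenant_id are assumptions):

from sqlalchemy import select

stmt = (
    select(EvaluationRun)
    .where(EvaluationRun.tenant_id == tenant_id, EvaluationRun.status == "running")
    .order_by(EvaluationRun.created_at.desc())
)
runs = session.scalars(stmt).all()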

View File

@@ -26,13 +26,6 @@ from .dataset import (
TidbAuthBinding,
Whitelist,
)
from .evaluation import (
EvaluationConfiguration,
EvaluationRun,
EvaluationRunItem,
EvaluationRunStatus,
EvaluationTargetType,
)
from .enums import (
AppTriggerStatus,
AppTriggerType,
@@ -88,7 +81,6 @@ from .provider import (
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .snippet import CustomizedSnippet, SnippetType
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
@@ -148,7 +140,6 @@ __all__ = [
"Conversation",
"ConversationVariable",
"CreatorUserRole",
"CustomizedSnippet",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",
@@ -165,11 +156,6 @@ __all__ = [
"Document",
"DocumentSegment",
"Embedding",
"EvaluationConfiguration",
"EvaluationRun",
"EvaluationRunItem",
"EvaluationRunStatus",
"EvaluationTargetType",
"EndUser",
"ExecutionExtraContent",
"ExporleBanner",
@@ -198,7 +184,6 @@ __all__ = [
"RecommendedApp",
"SavedMessage",
"Site",
"SnippetType",
"Tag",
"TagBinding",
"Tenant",

View File

@@ -1,194 +0,0 @@
from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, Float, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import Base
from .types import LongText, StringUUID
class EvaluationRunStatus(StrEnum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class EvaluationTargetType(StrEnum):
APP = "app"
SNIPPETS = "snippets"
class EvaluationConfiguration(Base):
"""Stores evaluation configuration for each target (App or Snippet)."""
__tablename__ = "evaluation_configurations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.Index("evaluation_configuration_target_idx", "tenant_id", "target_type", "target_id"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
evaluation_model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
evaluation_model: Mapped[str | None] = mapped_column(String(255), nullable=True)
metrics_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
judgement_conditions: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
updated_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def metrics_config_dict(self) -> dict[str, Any]:
if self.metrics_config:
return json.loads(self.metrics_config)
return {}
@metrics_config_dict.setter
def metrics_config_dict(self, value: dict[str, Any]) -> None:
self.metrics_config = json.dumps(value)
@property
def judgement_conditions_dict(self) -> dict[str, Any]:
if self.judgement_conditions:
return json.loads(self.judgement_conditions)
return {}
@judgement_conditions_dict.setter
def judgement_conditions_dict(self, value: dict[str, Any]) -> None:
self.judgement_conditions = json.dumps(value)
def __repr__(self) -> str:
return f"<EvaluationConfiguration(id={self.id}, target={self.target_type}:{self.target_id})>"
class EvaluationRun(Base):
"""Stores each evaluation run record."""
__tablename__ = "evaluation_runs"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
sa.Index("evaluation_run_target_idx", "tenant_id", "target_type", "target_id"),
sa.Index("evaluation_run_status_idx", "tenant_id", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
evaluation_config_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
status: Mapped[str] = mapped_column(
String(20), nullable=False, default=EvaluationRunStatus.PENDING
)
dataset_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
result_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
total_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
completed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
failed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
metrics_summary: Mapped[str | None] = mapped_column(LongText, nullable=True)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def metrics_summary_dict(self) -> dict[str, Any]:
if self.metrics_summary:
return json.loads(self.metrics_summary)
return {}
@metrics_summary_dict.setter
def metrics_summary_dict(self, value: dict[str, Any]) -> None:
self.metrics_summary = json.dumps(value)
@property
def progress(self) -> float:
if self.total_items == 0:
return 0.0
return (self.completed_items + self.failed_items) / self.total_items
def __repr__(self) -> str:
return f"<EvaluationRun(id={self.id}, status={self.status})>"
class EvaluationRunItem(Base):
"""Stores per-row evaluation results."""
__tablename__ = "evaluation_run_items"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
sa.Index("evaluation_run_item_run_idx", "evaluation_run_id"),
sa.Index("evaluation_run_item_index_idx", "evaluation_run_id", "item_index"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
evaluation_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
item_index: Mapped[int] = mapped_column(Integer, nullable=False)
inputs: Mapped[str | None] = mapped_column(LongText, nullable=True)
expected_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
context: Mapped[str | None] = mapped_column(LongText, nullable=True)
actual_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
metrics: Mapped[str | None] = mapped_column(LongText, nullable=True)
judgment: Mapped[str | None] = mapped_column(LongText, nullable=True)
metadata_json: Mapped[str | None] = mapped_column(LongText, nullable=True)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp()
)
@property
def inputs_dict(self) -> dict[str, Any]:
if self.inputs:
return json.loads(self.inputs)
return {}
@property
def metrics_list(self) -> list[dict[str, Any]]:
if self.metrics:
return json.loads(self.metrics)
return []
@property
def metadata_dict(self) -> dict[str, Any]:
if self.metadata_json:
return json.loads(self.metadata_json)
return {}
def __repr__(self) -> str:
return f"<EvaluationRunItem(id={self.id}, run={self.evaluation_run_id}, index={self.item_index})>"

View File

@@ -1,101 +0,0 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .account import Account
from .base import Base
from .engine import db
from .types import AdjustedJSON, LongText, StringUUID
class SnippetType(StrEnum):
"""Snippet Type Enum"""
NODE = "node"
GROUP = "group"
class CustomizedSnippet(Base):
"""
Customized Snippet Model
Stores reusable workflow components (nodes or node groups) that can be
shared across applications within a workspace.
"""
__tablename__ = "customized_snippets"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(LongText, nullable=True)
type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))
# Workflow reference for published version
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
# State flags
is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
# Visual customization
icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)
# Snippet configuration (stored as JSON text)
input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)
# Audit fields
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def graph_dict(self) -> dict[str, Any]:
"""Get graph from associated workflow."""
if self.workflow_id:
from .workflow import Workflow
workflow = db.session.get(Workflow, self.workflow_id)
if workflow:
return json.loads(workflow.graph) if workflow.graph else {}
return {}
@property
def input_fields_list(self) -> list[dict[str, Any]]:
"""Parse input_fields JSON to list."""
return json.loads(self.input_fields) if self.input_fields else []
@property
def created_by_account(self) -> Account | None:
"""Get the account that created this snippet."""
if self.created_by:
return db.session.get(Account, self.created_by)
return None
@property
def updated_by_account(self) -> Account | None:
"""Get the account that last updated this snippet."""
if self.updated_by:
return db.session.get(Account, self.updated_by)
return None
@property
def version_str(self) -> str:
"""Get version as string for API response."""
return str(self.version)
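
An illustrative read path for the deleted model's convenience properties (db session and snippet_id are assumptions):

snippet = db.session.get(CustomizedSnippet, snippet_id)
graph = snippet.graph_dict                # resolved via the published workflow, if any
input_fields = snippet.input_fields_list  # parsed from the JSON text column
author = snippet.created_by_account       # lazy Account lookup; may be None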

Some files were not shown because too many files have changed in this diff.