Compare commits

...

1 Commits

Author SHA1 Message Date
Asuka Minato
7828508b30 refactor: remove all reqparser (#29289)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-02-01 13:43:14 +09:00
10 changed files with 434 additions and 412 deletions

View File

@@ -53,6 +53,7 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
]
@@ -88,6 +89,7 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
]
[lint.per-file-ignores]
@@ -109,10 +111,20 @@ ignore = [
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -1,10 +1,9 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
@@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
@@ -135,6 +138,7 @@ register_schema_models(
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
)
@@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
return recommended_plugins

File diff suppressed because it is too large Load Diff

View File

@@ -30,6 +30,7 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
@@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

View File

@@ -1,5 +1,4 @@
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations"
@@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255

View File

@@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")

View File

@@ -1,7 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
register_schema_model(service_api_ns, HitTestingPayload)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
404: "Dataset not found",
}
)
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset.

View File

@@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""

View File

@@ -1,8 +1,6 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
@@ -10,8 +8,8 @@ from sqlalchemy.orm import Session
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@@ -38,12 +36,10 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@@ -75,7 +71,7 @@ class WorkflowToolManageService:
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
privacy_policy=privacy_policy,
version=workflow.version,
)
@@ -104,7 +100,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
@@ -122,8 +118,6 @@ class WorkflowToolManageService:
:param labels: labels
:return: the updated tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@@ -162,7 +156,7 @@ class WorkflowToolManageService:
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()

View File

@@ -3,7 +3,9 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import ValidationError
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
@@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
def _create_test_workflow_tool_parameters(self):
"""Helper method to create valid workflow tool parameters."""
return [
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
},
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
},
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
}
),
WorkflowToolParameterConfiguration.model_validate(
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
}
),
]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
assert created_tool_provider.label == tool_label
assert created_tool_provider.icon == json.dumps(tool_icon)
assert created_tool_provider.description == tool_description
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
assert created_tool_provider.privacy_policy == tool_privacy_policy
assert created_tool_provider.version == workflow.version
assert created_tool_provider.user_id == account.id
@@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Setup invalid workflow tool parameters (missing required fields)
invalid_parameters = [
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
]
# Attempt to create workflow tool with invalid parameters
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValidationError) as exc_info:
# Setup invalid workflow tool parameters (missing required fields)
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
@@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=invalid_parameters,
parameters=[
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
)
],
)
# Verify error message contains validation error
@@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
# Verify database state was updated
db.session.refresh(created_tool)
assert created_tool is not None
assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label
assert created_tool.icon == json.dumps(updated_tool_icon)
assert created_tool.description == updated_tool_description
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
assert created_tool.privacy_policy == updated_tool_privacy_policy
assert created_tool.version == workflow.version
assert created_tool.updated_at is not None
@@ -750,13 +757,15 @@ class TestWorkflowToolManageService:
# Setup workflow tool parameters with FILE type
file_parameters = [
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
WorkflowToolParameterConfiguration.model_validate(
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
)
]
# Execute the method under test
@@ -823,13 +832,15 @@ class TestWorkflowToolManageService:
# Setup workflow tool parameters with FILES type
files_parameters = [
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
WorkflowToolParameterConfiguration.model_validate(
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
)
]
# Execute the method under test