refactor: port reqparse to BaseModel (#28993)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-12-08 15:31:19 +09:00
committed by GitHub
parent 2f96374837
commit 05fe92a541
44 changed files with 1531 additions and 1894 deletions

View File

@@ -1,10 +1,12 @@
from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
external_knowledge_api_id: str | None = None
provider: str = "vendor"
external_knowledge_id: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
# Define parsers for dataset operations
dataset_create_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
required=False,
nullable=False,
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=40)
description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
dataset_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
)
tag_create_parser = reqparse.RequestParser().add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
class TagNamePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=50)
tag_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
)
tag_delete_parser = reqparse.RequestParser().add_argument(
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str
)
class TagCreatePayload(TagNamePayload):
pass
tag_binding_parser = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
)
tag_unbinding_parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
class TagUpdatePayload(TagNamePayload):
tag_id: str
class TagDeletePayload(BaseModel):
tag_id: str
class TagBindingPayload(BaseModel):
tag_ids: list[str]
target_id: str
@field_validator("tag_ids")
@classmethod
def validate_tag_ids(cls, value: list[str]) -> list[str]:
if not value:
raise ValueError("Tag IDs is required.")
return value
class TagUnbindingPayload(BaseModel):
tag_id: str
target_id: str
register_schema_models(
service_api_ns,
DatasetCreatePayload,
DatasetUpdatePayload,
TagCreatePayload,
TagUpdatePayload,
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
)
@@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
@service_api_ns.expect(dataset_create_parser)
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@service_api_ns.doc(description="Create a new dataset")
@service_api_ns.doc(
@@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
args = dataset_create_parser.parse_args()
payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
try:
assert isinstance(current_user, Account)
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
name=payload.name,
description=payload.description,
indexing_technique=payload.indexing_technique,
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
permission=str(payload.permission) if payload.permission else None,
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=payload.embedding_model,
retrieval_model=payload.retrieval_model,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
return data, 200
@service_api_ns.expect(dataset_update_parser)
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@service_api_ns.doc(description="Update an existing dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
args = dataset_update_parser.parse_args()
data = request.get_json()
payload_dict = service_api_ns.payload or {}
payload = DatasetUpdatePayload.model_validate(payload_dict)
update_data = payload.model_dump(exclude_unset=True)
if payload.permission is not None:
update_data["permission"] = str(payload.permission)
if payload.retrieval_model is not None:
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
# check embedding model setting
embedding_model_provider = data.get("embedding_model_provider")
embedding_model = data.get("embedding_model")
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if payload.indexing_technique == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)
retrieval_model = data.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
current_user,
dataset,
str(payload.permission) if payload.permission else None,
payload.partial_member_list,
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
return tags, 200
@service_api_ns.expect(tag_create_parser)
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@service_api_ns.doc(description="Add a knowledge type tag")
@service_api_ns.doc(
@@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_create_parser.parse_args()
args["type"] = "knowledge"
tag = TagService.save_tags(args)
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@service_api_ns.expect(tag_update_parser)
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_tag")
@service_api_ns.doc(description="Update a knowledge type tag")
@service_api_ns.doc(
@@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_update_parser.parse_args()
args["type"] = "knowledge"
tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id
tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
return response, 200
@service_api_ns.expect(tag_delete_parser)
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@service_api_ns.doc("delete_dataset_tag")
@service_api_ns.doc(description="Delete a knowledge type tag")
@service_api_ns.doc(
@@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
@edit_permission_required
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"])
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
return 204
@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.expect(tag_binding_parser)
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
@service_api_ns.doc("bind_dataset_tags")
@service_api_ns.doc(description="Bind tags to a dataset")
@service_api_ns.doc(
@@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_binding_parser.parse_args()
args["type"] = "knowledge"
TagService.save_tag_binding(args)
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return 204
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(tag_unbinding_parser)
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc(
@@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_unbinding_parser.parse_args()
args["type"] = "knowledge"
TagService.delete_tag_binding(args)
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return 204

View File

@@ -3,8 +3,8 @@ from typing import Self
from uuid import UUID
from flask import request
from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from flask_restx import marshal
from pydantic import BaseModel, Field, model_validator
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
# Define parsers for document operations
document_text_create_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("text", type=str, required=True, nullable=False, location="json")
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
class DocumentTextCreatePayload(BaseModel):
name: str
text: str
process_rule: ProcessRule | None = None
original_document_id: str | None = None
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
indexing_technique: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@service_api_ns.expect(document_text_create_parser)
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text")
@service_api_ns.doc(description="Create a new document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by text."""
args = document_text_create_parser.parse_args()
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
@@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update

View File

@@ -1,9 +1,11 @@
from typing import Literal
from flask_login import current_user
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
@@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
# Define parsers for metadata APIs
metadata_create_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
)
metadata_update_parser = reqparse.RequestParser().add_argument(
"name", type=str, required=True, nullable=False, location="json", help="New metadata name"
)
class MetadataUpdatePayload(BaseModel):
name: str
document_metadata_parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
)
register_schema_model(service_api_ns, MetadataUpdatePayload)
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_create_parser)
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
@service_api_ns.doc("create_dataset_metadata")
@service_api_ns.doc(description="Create metadata for a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_update_parser)
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_metadata")
@service_api_ns.doc(description="Update metadata name")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
@@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name."""
args = metadata_update_parser.parse_args()
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata")
@@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.expect(document_metadata_parser)
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
@service_api_ns.doc("update_documents_metadata")
@service_api_ns.doc(description="Update metadata for multiple documents")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@@ -4,12 +4,12 @@ from collections.abc import Generator
from typing import Any
from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
@@ -22,11 +22,25 @@ from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.entity.pipeline_service_api_entities import (
DatasourceNodeRunApiEntity,
PipelineRunApiEntity,
)
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
.add_argument("is_published", type=bool, required=True, location="json")
)
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
"pipeline_id": str(pipeline.id),
"node_id": node_id,
}
)
return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node(
@@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_published", type=bool, required=True, default=True, location="json")
.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
)
args: ParseResult = parser.parse_args()
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):
raise Forbidden()
@@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming",
args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)
return helper.compact_generate_response(response)

View File

@@ -1,8 +1,12 @@
from typing import Any
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
# Define parsers for segment operations
segment_create_parser = reqparse.RequestParser().add_argument(
"segments", type=list, required=False, nullable=True, location="json"
)
segment_list_parser = (
reqparse.RequestParser()
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("keyword", type=str, default=None, location="args")
)
class SegmentCreatePayload(BaseModel):
segments: list[dict[str, Any]] | None = None
segment_update_parser = reqparse.RequestParser().add_argument(
"segment", type=dict, required=False, nullable=True, location="json"
)
child_chunk_create_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
class SegmentListQuery(BaseModel):
status: list[str] = Field(default_factory=list)
keyword: str | None = None
child_chunk_list_parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
child_chunk_update_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
class SegmentUpdatePayload(BaseModel):
segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
ChildChunkUpdatePayload,
)
@@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@service_api_ns.expect(segment_create_parser)
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@service_api_ns.expect(segment_list_parser)
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
args = segment_list_parser.parse_args()
args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
status_list=args.status,
keyword=args.keyword,
page=page,
limit=limit,
)
@@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset)
return 204
@service_api_ns.expect(segment_update_parser)
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc(
@@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
# validate args
args = segment_update_parser.parse_args()
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
@service_api_ns.expect(child_chunk_create_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc(
@@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description)
# validate args
args = child_chunk_create_parser.parse_args()
payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(child_chunk_list_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(
@@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
args = child_chunk_list_parser.parse_args()
args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
@@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204
@service_api_ns.expect(child_chunk_update_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc(
@@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate args
args = child_chunk_update_parser.parse_args()
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))