Compare commits

..

1 Commits

Author SHA1 Message Date
jyong
cf08295dde add pipeline template endpoint 2025-10-13 17:33:40 +08:00
3 changed files with 438 additions and 160 deletions

View File

@@ -1,4 +1,6 @@
import json
import logging
from typing import Any
from flask import request
from flask_restx import Resource, reqparse
@@ -16,6 +18,8 @@ from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import (
PipelineBuiltInTemplateEntity,
PipelineBuiltInTemplateInstallEntity,
PipelineBuiltInTemplateUpdateEntity,
PipelineTemplateInfoEntity,
)
from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -23,12 +27,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@@ -156,82 +154,8 @@ class PipelineTemplateInstallApi(Resource):
Returns:
Success response or error with appropriate HTTP status
"""
try:
# Extract and validate Bearer token
auth_token = self._extract_bearer_token()
# Parse and validate request parameters
template_args = self._parse_template_args()
# Process uploaded template file
file_content = self._process_template_file()
# Create template entity
pipeline_built_in_template_entity = PipelineBuiltInTemplateEntity(**template_args)
# Install the template
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.install_built_in_pipeline_template(
pipeline_built_in_template_entity, file_content, auth_token
)
return {"result": "success", "message": "Template installed successfully"}, 200
except ValueError as e:
logger.exception("Validation error in template installation")
return {"error": str(e)}, 400
except Exception as e:
logger.exception("Unexpected error in template installation")
return {"error": "An unexpected error occurred during template installation"}, 500
def _extract_bearer_token(self) -> str:
    """
    Read the Bearer token from the current request's Authorization header

    Returns:
        The token with surrounding whitespace removed

    Raises:
        ValueError: If the header is absent, does not start with "Bearer ",
            or carries no token
    """
    header_value = request.headers.get("Authorization", "").strip()
    if header_value == "":
        raise ValueError("Authorization header is required")
    if not header_value.startswith("Bearer "):
        raise ValueError("Authorization header must start with 'Bearer '")

    # Split once on the first space: scheme on the left, credentials on the right.
    _scheme, separator, credentials = header_value.partition(" ")
    if not separator:
        raise ValueError("Invalid Authorization header format")

    token = credentials.strip()
    if token == "":
        raise ValueError("Bearer token cannot be empty")
    return token
def _parse_template_args(self) -> dict:
"""
Parse and validate template arguments from form data
Args:
template_id: The template ID from URL
Returns:
Dictionary of validated template arguments
"""
# Use reqparse for consistent parameter parsing
parser = reqparse.RequestParser()
parser.add_argument(
"template_id",
type=str,
location="form",
required=False,
help="Template ID for updating existing template"
)
parser.add_argument(
"language",
type=str,
@@ -257,70 +181,260 @@ class PipelineTemplateInstallApi(Resource):
default="",
help="Template description (max 1000 characters)"
)
parser.add_argument(
"position",
type=int,
location="form",
required=False,
default=1,
help="Template position"
)
parser.add_argument(
"icon",
type=str,
location="form",
required=False,
help="Template icon"
)
args = parser.parse_args()
template_args = parser.parse_args()
# Additional validation
if args.get("name"):
args["name"] = self._validate_name(args["name"])
if template_args.get("name"):
template_args["name"] = _validate_name(template_args["name"])
if template_args.get("icon"):
template_args["icon"] = json.loads(template_args["icon"])
if args.get("description") and len(args["description"]) > 1000:
if template_args.get("description") and len(template_args["description"]) > 1000:
raise ValueError("Description must not exceed 1000 characters")
# Filter out None values
return {k: v for k, v in args.items() if v is not None}
def _validate_name(self, name: str) -> str:
try:
# Extract and validate Bearer token
auth_token = _extract_bearer_token()
# Process uploaded template file
file_content = _process_template_file()
template_args["yaml_content"] = file_content
# Create template entity
pipeline_built_in_template_entity = PipelineBuiltInTemplateInstallEntity(**template_args)
# Install the template
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.insert_built_in_pipeline_template(
pipeline_built_in_template_entity, auth_token
)
return {"result": "success", "message": "Template installed successfully"}, 200
except ValueError as e:
logger.exception("Validation error in template installation")
return {"error": str(e)}, 400
except Exception as e:
logger.exception("Unexpected error in template installation")
return {"error": "An unexpected error occurred during template installation"}, 500
class PipelineTemplateUnpdateApi(Resource):
"""API endpoint for updating built-in pipeline templates"""
def patch(self, template_id: str):
"""
Validate template name
Update a built-in pipeline template
Args:
name: Template name to validate
Returns:
Validated and trimmed name
Raises:
ValueError: If name is invalid
template_id: The template ID to update (from URL parameter)
"""
name = name.strip()
if not name or len(name) < 1 or len(name) > 200:
raise ValueError("Template name must be between 1 and 200 characters")
return name
def _process_template_file(self) -> str:
"""
Process and validate uploaded template file
parser = reqparse.RequestParser()
Returns:
File content as string
Raises:
ValueError: If file is missing or invalid
"""
if "file" not in request.files:
raise ValueError("Template file is required")
file = request.files["file"]
# Validate file
if not file or not file.filename:
raise ValueError("No file selected")
filename = file.filename.strip()
if not filename:
raise ValueError("File name cannot be empty")
# Check file extension
if not filename.lower().endswith(".pipeline"):
raise ValueError("Template file must be a pipeline file (.pipeline)")
parser.add_argument(
"language",
type=str,
location="form",
required=True,
default="en-US",
choices=["en-US", "zh-CN", "ja-JP"],
help="Template language code"
)
parser.add_argument(
"name",
type=str,
location="form",
required=True,
default="New Pipeline Template",
help="Template name (1-200 characters)"
)
parser.add_argument(
"description",
type=str,
location="form",
required=False,
default="",
help="Template description (max 1000 characters)"
)
parser.add_argument(
"position",
type=int,
location="form",
required=False,
default=1,
help="Template position"
)
parser.add_argument(
"icon",
type=str,
location="form",
required=False,
help="Template icon"
)
template_args = parser.parse_args()
try:
file_content = file.read().decode("utf-8")
except UnicodeDecodeError:
raise ValueError("Template file must be valid UTF-8 text")
# Extract and validate Bearer token
auth_token = _extract_bearer_token()
if template_args.get("icon"):
template_args["icon"] = json.loads(template_args["icon"])
if "file" in request.files:
file_content = request.files["file"].read().decode("utf-8")
template_args["yaml_content"] = file_content
else:
template_args["yaml_content"] = None
# Validate template_id
if not template_id or not template_id.strip():
raise ValueError("Template ID is required")
template_args["template_id"] = template_id
pipeline_built_in_template_entity = PipelineBuiltInTemplateUpdateEntity(**template_args)
# Update the template
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.update_built_in_pipeline_template(pipeline_built_in_template_entity, auth_token)
return {"result": "success", "message": "Template updated successfully"}, 200
except ValueError as e:
logger.exception("Validation error in template update")
return {"error": str(e)}, 400
except Exception as e:
logger.exception("Unexpected error in template update")
return {"error": "An unexpected error occurred during template update"}, 500
def delete(self, template_id: str):
"""
Uninstall a built-in pipeline template
return file_content
Args:
template_id: The template ID to uninstall (from URL parameter)
Returns:
Success response or error with appropriate HTTP status
"""
try:
# Extract and validate Bearer token
auth_token = _extract_bearer_token()
# Validate template_id
if not template_id or not template_id.strip():
raise ValueError("Template ID is required")
# Uninstall the template
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.uninstall_built_in_pipeline_template(
template_id.strip(), auth_token
)
return {"result": "success", "message": "Template uninstalled successfully"}, 200
except ValueError as e:
logger.exception("Validation error in template uninstallation")
return {"error": str(e)}, 400
except Exception as e:
logger.exception("Unexpected error in template uninstallation")
return {"error": "An unexpected error occurred during template uninstallation"}, 500
def _extract_bearer_token() -> str:
    """
    Extract and validate Bearer token from Authorization header

    The scheme name is compared case-insensitively ("Bearer", "bearer", ...)
    because HTTP authentication scheme names are case-insensitive. Every
    header accepted before is still accepted and yields the same token.

    Returns:
        The extracted token string, stripped of surrounding whitespace

    Raises:
        ValueError: If token is missing or invalid
    """
    auth_header = request.headers.get("Authorization", "").strip()
    if not auth_header:
        raise ValueError("Authorization header is required")

    # "<scheme> <credentials>": split once on the first space.
    scheme, separator, credentials = auth_header.partition(" ")
    if not separator or scheme.lower() != "bearer":
        raise ValueError("Authorization header must start with 'Bearer '")

    auth_token = credentials.strip()
    if not auth_token:
        # Defensive guard; the header was stripped above, so credentials
        # should already contain at least one non-space character.
        raise ValueError("Bearer token cannot be empty")
    return auth_token
def _validate_name(name: str) -> str:
"""
Validate template name
Args:
name: Template name to validate
Returns:
Validated and trimmed name
Raises:
ValueError: If name is invalid
"""
name = name.strip()
if not name or len(name) < 1 or len(name) > 200:
raise ValueError("Template name must be between 1 and 200 characters")
return name
def _process_template_file() -> str:
    """
    Process and validate uploaded template file

    Reads the multipart field named "file" from the current request.

    Returns:
        File content as string, decoded as UTF-8

    Raises:
        ValueError: If file is missing or invalid
    """
    if "file" not in request.files:
        raise ValueError("Template file is required")

    uploaded_file = request.files["file"]

    # Validate file
    if not uploaded_file or not uploaded_file.filename:
        raise ValueError("No file selected")

    filename = uploaded_file.filename.strip()
    if not filename:
        raise ValueError("File name cannot be empty")

    # Check file extension
    if not filename.lower().endswith(".pipeline"):
        raise ValueError("Template file must be a pipeline file (.pipeline)")

    try:
        return uploaded_file.read().decode("utf-8")
    except UnicodeDecodeError as exc:
        # Chain the decode error so the offending byte offset stays visible
        # in the logged traceback.
        raise ValueError("Template file must be valid UTF-8 text") from exc
api.add_resource(
@@ -342,4 +456,8 @@ api.add_resource(
# Install endpoint for built-in pipeline templates.
api.add_resource(
    PipelineTemplateInstallApi,
    "/rag/pipeline/built-in/templates/install",
)
# Per-template endpoint. The resource class that implements the template
# ``patch`` (update) and ``delete`` (uninstall) handlers in this module is
# named PipelineTemplateUnpdateApi; registering an undefined name here would
# raise NameError when the module is imported.
# NOTE(review): the class name looks like a typo and the route suffix only
# mentions uninstall - confirm the intended names before renaming either.
api.add_resource(
    PipelineTemplateUnpdateApi,
    "/rag/pipeline/built-in/templates/<string:template_id>/uninstall",
)

View File

@@ -130,8 +130,20 @@ class KnowledgeConfiguration(BaseModel):
return v
class PipelineBuiltInTemplateEntity(BaseModel):
template_id: str | None = None
# Request payload for installing a new built-in pipeline template.
class PipelineBuiltInTemplateInstallEntity(BaseModel):
# Optional icon; when unset, the install service falls back to icon metadata
# extracted from the parsed template content.
icon: IconInfo | None = None
# Display name of the template.
name: str
# Description of the template.
description: str
# Language code the template belongs to; positions are tracked per language.
language: str
# Raw template content; required on install (parsed by the install service).
yaml_content: str
# Requested position. NOTE(review): the visible create path recomputes the
# stored position from the current maximum for the language - confirm whether
# this value is meant to be honoured.
position: int = 1
# Request payload for updating an existing built-in pipeline template.
class PipelineBuiltInTemplateUpdateEntity(BaseModel):
# ID of the template row to update.
template_id: str
# Icon is required on update; the update service serialises it with model_dump().
icon: IconInfo
# Display name of the template.
name: str
# Description of the template.
description: str
# Language code the template belongs to.
language: str
# Optional replacement content; when empty or None the stored content and
# chunk structure are left untouched by the update service.
yaml_content: str | None = None
# Target position of the template within its language.
position: int = 1

View File

@@ -10,11 +10,12 @@ from uuid import uuid4
import yaml
from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy import func, or_, select, update
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from controllers.console.datasets.rag_pipeline.rag_pipeline import PipelineBuiltInTemplateInstallEntity, PipelineBuiltInTemplateUpdateEntity
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
@@ -1458,8 +1459,8 @@ class RagPipelineService:
raise ValueError("Pipeline not found")
return pipeline
def install_built_in_pipeline_template(
self, args: PipelineBuiltInTemplateEntity, file_content: str, auth_token: str
def insert_built_in_pipeline_template(
self, args: PipelineBuiltInTemplateInstallEntity, auth_token: str
) -> None:
"""
Install built-in pipeline template
@@ -1476,10 +1477,13 @@ class RagPipelineService:
self._validate_auth_token(auth_token)
# Parse and validate template content
pipeline_template_dsl = self._parse_template_content(file_content)
pipeline_template_dsl = self._parse_template_content(args.yaml_content)
# Extract template metadata
icon = self._extract_icon_metadata(pipeline_template_dsl)
if not args.icon:
icon = self._extract_icon_metadata(pipeline_template_dsl)
else:
icon = args.icon.model_dump()
chunk_structure = self._extract_chunk_structure(pipeline_template_dsl)
# Prepare template data
@@ -1489,19 +1493,39 @@ class RagPipelineService:
"chunk_structure": chunk_structure,
"icon": icon,
"language": args.language,
"yaml_content": file_content,
"yaml_content": args.yaml_content,
"position": args.position,
}
# Use transaction for database operations
try:
if args.template_id:
self._update_existing_template(args.template_id, template_data)
else:
self._create_new_template(template_data)
self._create_new_template(template_data)
db.session.commit()
except Exception as e:
db.session.rollback()
raise ValueError(f"Failed to install pipeline template: {str(e)}")
def update_built_in_pipeline_template(self, args: PipelineBuiltInTemplateUpdateEntity, auth_token: str) -> None:
    """
    Update built-in pipeline template

    Args:
        args: Update payload; ``yaml_content`` is optional and, when given,
            also refreshes the stored chunk structure
        auth_token: Authentication token for authorization

    Raises:
        ValueError: If the template does not exist (other failures raised by
            the token check, content parsing or the database propagate unchanged)
    """
    self._validate_auth_token(auth_token)

    template_data = {
        "name": args.name,
        "description": args.description,
        "icon": args.icon.model_dump(),
        "language": args.language,
        "position": args.position,
    }
    if args.yaml_content:
        # New content was uploaded: store it and re-derive the chunk structure.
        template_data["yaml_content"] = args.yaml_content
        pipeline_template_dsl = self._parse_template_content(args.yaml_content)
        template_data["chunk_structure"] = self._extract_chunk_structure(pipeline_template_dsl)

    try:
        self._update_existing_template(args.template_id, template_data)
        db.session.commit()
    except Exception:
        # Roll back so a failed update does not leave the session dirty or
        # keep the row lock taken by _update_existing_template; the original
        # exception type is preserved for callers.
        db.session.rollback()
        raise
    logger.info("Updated template %s with new data", args.template_id)
def _validate_auth_token(self, auth_token: str) -> None:
"""Validate the authentication token"""
@@ -1554,43 +1578,167 @@ class RagPipelineService:
return chunk_structure
def _update_existing_template(self, template_id: str, template_data: dict) -> None:
    """
    Update an existing pipeline template under a row-level lock

    The template row is selected with ``with_for_update()`` (a pessimistic
    row lock, not optimistic locking) so concurrent updates of the same
    template are serialised. When the template moves - its position and/or
    its language changes - the positions of its neighbours are shifted:
    the gap it leaves is closed in its current language and a slot is opened
    at the target position in the target language.

    Changes are flushed but not committed; the caller owns the transaction.

    Args:
        template_id: ID of the template to update
        template_data: Dictionary containing updated template fields

    Raises:
        ValueError: If template not found
    """
    # Use with_for_update() for row-level locking to prevent concurrent updates
    pipeline_built_in_template = (
        db.session.query(PipelineBuiltInTemplate)
        .filter(PipelineBuiltInTemplate.id == template_id)
        .with_for_update()
        .first()
    )

    if not pipeline_built_in_template:
        raise ValueError(f"Pipeline built-in template not found: {template_id}")

    current_position = pipeline_built_in_template.position
    current_language = pipeline_built_in_template.language

    # A missing (or None) position means "keep the current one"; without this
    # guard the gap-closing shift below would run although nothing moved.
    target_position = template_data.get("position")
    if target_position is None:
        target_position = current_position
    target_language = template_data.get("language") or current_language

    if target_position != current_position or target_language != current_language:
        # Close the gap the template leaves behind in its current language.
        db.session.execute(
            update(PipelineBuiltInTemplate)
            .where(
                PipelineBuiltInTemplate.language == current_language,
                PipelineBuiltInTemplate.position > current_position,
            )
            .values(position=PipelineBuiltInTemplate.position - 1)
        )
        db.session.flush()
        # Open a slot at the target position in the language the template
        # ends up in (the same language unless the update changes it).
        db.session.execute(
            update(PipelineBuiltInTemplate)
            .where(
                PipelineBuiltInTemplate.language == target_language,
                PipelineBuiltInTemplate.position >= target_position,
            )
            .values(position=PipelineBuiltInTemplate.position + 1)
        )
        db.session.flush()

    # Update template fields
    for key, value in template_data.items():
        setattr(pipeline_built_in_template, key, value)

    # Update timestamp if exists
    if hasattr(pipeline_built_in_template, 'updated_at'):
        pipeline_built_in_template.updated_at = datetime.now(UTC)

    db.session.add(pipeline_built_in_template)
    db.session.flush()
def _create_new_template(self, template_data: dict) -> None:
"""Create a new pipeline template"""
# Get the next available position
position = self._get_next_position(template_data["language"])
"""
Create a new pipeline template with atomic position assignment
# Add additional fields for new template
template_data.update({
"position": position,
"install_count": 0,
"copyright": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT,
"privacy_policy": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY,
})
new_template = PipelineBuiltInTemplate(**template_data)
db.session.add(new_template)
Args:
template_data: Dictionary containing template fields
"""
# Use a single query with locking to get and increment position atomically
with db.session.begin_nested():
# Lock all templates of the same language to prevent position conflicts
db.session.query(PipelineBuiltInTemplate).filter(
PipelineBuiltInTemplate.language == template_data["language"]
).with_for_update().all()
# Get the next available position
position = self._get_next_position(template_data["language"])
# Add additional fields for new template
template_data.update({
"position": position,
"install_count": 0,
"copyright": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT or "",
"privacy_policy": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY or "",
})
# Add timestamp if model supports it
if hasattr(PipelineBuiltInTemplate, 'created_at'):
template_data['created_at'] = datetime.now(UTC)
new_template = PipelineBuiltInTemplate(**template_data)
db.session.add(new_template)
logger.info(
"Created new template '%s' at position %d for language %s",
template_data.get("name"), position, template_data["language"]
)
def _get_next_position(self, language: str) -> int:
"""Get the next available position for a template in the specified language"""
"""
Get the next available position for a template in the specified language
Args:
language: Language code for the template
Returns:
Next available position number (1-based)
"""
# Use COALESCE for database compatibility
max_position = (
db.session.query(func.max(PipelineBuiltInTemplate.position))
db.session.query(func.coalesce(func.max(PipelineBuiltInTemplate.position), 0))
.filter(PipelineBuiltInTemplate.language == language)
.scalar()
)
return (max_position or 0) + 1
return max_position + 1
def uninstall_built_in_pipeline_template(self, template_id: str, auth_token: str) -> None:
    """
    Uninstall a built-in pipeline template and reorder remaining templates

    The template row is deleted and every template of the same language that
    was positioned after it is moved up by one, all in a single transaction.

    Args:
        template_id: ID of the template to uninstall
        auth_token: Authentication token for authorization

    Raises:
        ValueError: If template not found or authentication fails; any other
            failure inside the transaction is rolled back and re-raised as
            ValueError with the original exception chained as its cause
    """
    # Validate authentication
    self._validate_auth_token(auth_token)

    # Use transaction for atomic operations
    try:
        # Get the template to delete with lock to prevent concurrent modifications
        pipeline_built_in_template = (
            db.session.query(PipelineBuiltInTemplate)
            .filter(PipelineBuiltInTemplate.id == template_id)
            .with_for_update()
            .first()
        )

        if not pipeline_built_in_template:
            raise ValueError(f"Pipeline built-in template not found: {template_id}")

        # Store position and language for reordering
        deleted_position = pipeline_built_in_template.position
        template_language = pipeline_built_in_template.language

        # Delete the template first
        db.session.delete(pipeline_built_in_template)
        db.session.flush()  # Execute delete but don't commit yet

        # Batch update positions for all templates after the deleted one
        db.session.execute(
            update(PipelineBuiltInTemplate)
            .where(
                PipelineBuiltInTemplate.language == template_language,
                PipelineBuiltInTemplate.position > deleted_position,
            )
            .values(position=PipelineBuiltInTemplate.position - 1)
        )

        # Commit all changes together
        db.session.commit()

        logger.info(
            "Successfully uninstalled template %s at position %d for language %s",
            template_id, deleted_position, template_language
        )

    except Exception as e:
        db.session.rollback()
        logger.exception("Failed to uninstall template %s", template_id)
        # Chain the original error so its type and traceback are kept as the cause.
        raise ValueError(f"Failed to uninstall pipeline template: {str(e)}") from e