diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 8c5f91cb7f..471ecbf070 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,5 +1,6 @@ import logging +import yaml from flask import request from flask_restful import Resource, reqparse from sqlalchemy.orm import Session @@ -12,10 +13,9 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from libs.login import login_required -from models.dataset import Pipeline, PipelineCustomizedTemplate +from models.dataset import PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService -from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService logger = logging.getLogger(__name__) @@ -84,8 +84,8 @@ class CustomizedPipelineTemplateApi(Resource): ) args = parser.parse_args() pipeline_template_info = PipelineTemplateInfoEntity(**args) - pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) - return pipeline_template, 200 + RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) + return 200 @setup_required @login_required @@ -106,13 +106,41 @@ class CustomizedPipelineTemplateApi(Resource): ) if not template: raise ValueError("Customized pipeline template not found.") - pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() - if not pipeline: - raise ValueError("Pipeline not found.") - dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) + dsl = yaml.safe_load(template.yaml_content) return {"data": dsl}, 200 +class CustomizedPipelineTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, pipeline_id: str): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "icon_info", + type=dict, + location="json", + nullable=True, + ) + args = parser.parse_args() + rag_pipeline_service = RagPipelineService() + RagPipelineService.publish_customized_pipeline_template(pipeline_id, args) + return 200 api.add_resource( PipelineTemplateListApi, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index e4c96775c8..b7e20cfd10 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -20,11 +20,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager from core.app.apps.pipeline.pipeline_runner import PipelineRunner -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline -from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository @@ -32,6 +32,7 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db +from fields.document_fields import dataset_and_document_fields from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline from models.enums import WorkflowRunTriggeredFrom @@ -54,7 +55,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Generator[Mapping | str, None, None] | None: ... + ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... @overload def generate( @@ -101,23 +102,18 @@ class PipelineGenerator(BaseAppGenerator): pipeline=pipeline, workflow=workflow, ) - + # Add null check for dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) - - for datasource_info in datasource_info_list: - workflow_run_id = str(uuid.uuid4()) - document_id = None - - # Add null check for dataset - dataset = pipeline.dataset - if not dataset: - raise ValueError("Pipeline dataset is required") - - if invoke_from == InvokeFrom.PUBLISHED: + documents = [] + if invoke_from == InvokeFrom.PUBLISHED: + for datasource_info in datasource_info_list: position = DocumentService.get_documents_position(dataset.id) document = self._build_document( tenant_id=pipeline.tenant_id, @@ -132,9 +128,15 @@ class PipelineGenerator(BaseAppGenerator): document_form=dataset.chunk_structure, ) db.session.add(document) - db.session.commit() - document_id = document.id - # init application generate entity + documents.append(document) + db.session.commit() + + # run in child thread + for i, datasource_info in enumerate(datasource_info_list): + workflow_run_id = str(uuid.uuid4()) + document_id = None + if invoke_from == InvokeFrom.PUBLISHED: + document_id = documents[i].id application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), app_config=pipeline_config, @@ -159,7 +161,6 @@ class PipelineGenerator(BaseAppGenerator): workflow_run_id=workflow_run_id, ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) if invoke_from == InvokeFrom.DEBUGGER: @@ -183,6 +184,7 @@ class PipelineGenerator(BaseAppGenerator): ) if invoke_from == InvokeFrom.DEBUGGER: return self._generate( + flask_app=current_app._get_current_object(),# type: ignore pipeline=pipeline, workflow=workflow, user=user, @@ -194,21 +196,47 @@ class PipelineGenerator(BaseAppGenerator): workflow_thread_pool_id=workflow_thread_pool_id, ) else: - self._generate( - pipeline=pipeline, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - workflow_thread_pool_id=workflow_thread_pool_id, + # run in child thread + thread = threading.Thread( + target=self._generate, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "pipeline": pipeline, + "workflow": workflow, + "user": user, + "application_generate_entity": application_generate_entity, + "invoke_from": invoke_from, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "streaming": streaming, + "workflow_thread_pool_id": workflow_thread_pool_id, + }, ) - + thread.start() + # return batch, dataset, documents + return { + "batch": batch, + "dataset": PipelineDataset( + id=dataset.id, + name=dataset.name, + description=dataset.description, + chunk_structure=dataset.chunk_structure, + ).model_dump(), + "documents": [PipelineDocument( + id=document.id, + position=document.position, + data_source_info=document.data_source_info, + name=document.name, + indexing_status=document.indexing_status, + error=document.error, + enabled=document.enabled, + ).model_dump() for document in documents + ] + } def _generate( self, *, + flask_app: Flask, pipeline: Pipeline, workflow: Workflow, user: Union[Account, EndUser], @@ -232,40 +260,42 @@ class PipelineGenerator(BaseAppGenerator): :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id """ - # init queue manager - queue_manager = PipelineQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=AppMode.RAG_PIPELINE, - ) + print(user.id) + with flask_app.app_context(): + # init queue manager + queue_manager = PipelineQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=AppMode.RAG_PIPELINE, + ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": contextvars.copy_context(), - "workflow_thread_pool_id": workflow_thread_pool_id, - }, - ) + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - stream=streaming, - ) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + stream=streaming, + ) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, @@ -317,7 +347,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth=0, workflow_run_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) # Create workflow node execution repository @@ -338,6 +367,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( + flask_app=current_app._get_current_object(),# type: ignore pipeline=pipeline, workflow=workflow, user=user, @@ -399,7 +429,6 @@ class PipelineGenerator(BaseAppGenerator): single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), workflow_run_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -421,6 +450,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( + flask_app=current_app._get_current_object(),# type: ignore pipeline=pipeline, workflow=workflow, user=user, diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 90c9879733..f876c06b06 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -17,3 +17,26 @@ class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] qa_preview: Optional[list[QAPreviewDetail]] = None + + +class PipelineDataset(BaseModel): + id: str + name: str + description: str + chunk_structure: str + +class PipelineDocument(BaseModel): + id: str + position: int + data_source_info: dict + name: str + indexing_status: str + error: str + enabled: bool + + + +class PipelineGenerateResponse(BaseModel): + batch: str + dataset: PipelineDataset + documents: list[PipelineDocument] diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8d916a19db..cbc55474c6 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -253,6 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, + triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all WorkflowNodeExecution database models for a specific workflow run. @@ -274,7 +275,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) stmt = select(WorkflowNodeExecution).where( WorkflowNodeExecution.workflow_run_id == workflow_run_id, WorkflowNodeExecution.tenant_id == self._tenant_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + WorkflowNodeExecution.triggered_from == triggered_from, ) if self._app_id: @@ -308,6 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, + triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[NodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. @@ -325,7 +327,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) A list of NodeExecution instances """ # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) + db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) # Convert database models to domain models domain_models = [] diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 8d675e56fa..2871b3ec16 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -87,6 +87,7 @@ dataset_detail_fields = { "runtime_mode": fields.String, "chunk_structure": fields.String, "icon_info": fields.Nested(icon_info_fields), + "is_published": fields.Boolean, } dataset_query_detail_fields = { diff --git a/api/models/dataset.py b/api/models/dataset.py index 22703771d5..86216ffe98 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -152,6 +152,8 @@ class Dataset(Base): @property def doc_form(self): + if self.chunk_structure: + return self.chunk_structure document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form @@ -206,6 +208,13 @@ class Dataset(Base): "external_knowledge_api_name": external_knowledge_api.name, "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), } + @property + def is_published(self): + if self.pipeline_id: + pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + if pipeline: + return pipeline.is_published + return False @property def doc_metadata(self): @@ -1154,10 +1163,11 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - pipeline_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False) + chunk_structure = db.Column(db.String(255), nullable=False) icon = db.Column(db.JSON, nullable=False) + yaml_content = db.Column(db.Text, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) position = db.Column(db.Integer, nullable=False) @@ -1166,9 +1176,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - @property - def pipeline(self): - return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] @@ -1180,11 +1187,12 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) - pipeline_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False) + chunk_structure = db.Column(db.String(255), nullable=False) icon = db.Column(db.JSON, nullable=False) position = db.Column(db.Integer, nullable=False) + yaml_content = db.Column(db.Text, nullable=False) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index 70c72014f2..b0fa54115c 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -23,8 +23,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): result = self.fetch_pipeline_templates_from_builtin(language) return result - def get_pipeline_template_detail(self, pipeline_id: str): - result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id) + def get_pipeline_template_detail(self, template_id: str): + result = self.fetch_pipeline_template_detail_from_builtin(template_id) return result @classmethod @@ -54,11 +54,11 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod - def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]: """ Fetch pipeline template detail from builtin. - :param pipeline_id: Pipeline ID + :param template_id: Template ID :return: """ builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() - return builtin_data.get("pipeline_templates", {}).get(pipeline_id) + return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index de69373ba4..b6670b70cd 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,12 +1,13 @@ from typing import Optional from flask_login import current_user +import yaml from extensions.ext_database import db -from models.dataset import Pipeline, PipelineCustomizedTemplate -from services.app_dsl_service import AppDslService +from models.dataset import PipelineCustomizedTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): @@ -35,13 +36,26 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_templates = ( + pipeline_customized_templates = ( db.session.query(PipelineCustomizedTemplate) .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .all() ) + recommended_pipelines_results = [] + for pipeline_customized_template in pipeline_customized_templates: + + recommended_pipeline_result = { + "id": pipeline_customized_template.id, + "name": pipeline_customized_template.name, + "description": pipeline_customized_template.description, + "icon": pipeline_customized_template.icon, + "position": pipeline_customized_template.position, + "chunk_structure": pipeline_customized_template.chunk_structure, + } + recommended_pipelines_results.append(recommended_pipeline_result) + + return {"pipeline_templates": recommended_pipelines_results} - return {"pipeline_templates": pipeline_templates} @classmethod def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: @@ -57,15 +71,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - # get pipeline detail - pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() - if not pipeline or not pipeline.is_public: - return None - return { - "id": pipeline.id, - "name": pipeline.name, - "icon": pipeline.icon, - "mode": pipeline.mode, - "export_data": AppDslService.export_dsl(app_model=pipeline), + "id": pipeline_template.id, + "name": pipeline_template.name, + "icon": pipeline_template.icon, + "export_data": yaml.safe_load(pipeline_template.yaml_content), } diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 9ea3cc678b..8019dac0a8 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,7 +1,9 @@ from typing import Optional +import yaml + from extensions.ext_database import db -from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate +from models.dataset import PipelineBuiltInTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -36,24 +38,18 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: - pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline - if not pipeline_model: - continue recommended_pipeline_result = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, - "pipeline_id": pipeline_model.id, "description": pipeline_built_in_template.description, "icon": pipeline_built_in_template.icon, "copyright": pipeline_built_in_template.copyright, "privacy_policy": pipeline_built_in_template.privacy_policy, "position": pipeline_built_in_template.position, + "chunk_structure": pipeline_built_in_template.chunk_structure, } - dataset: Dataset | None = pipeline_model.dataset - if dataset: - recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure - recommended_pipelines_results.append(recommended_pipeline_result) + recommended_pipelines_results.append(recommended_pipeline_result) return {"pipeline_templates": recommended_pipelines_results} @@ -64,8 +60,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param pipeline_id: Pipeline ID :return: """ - from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService - # is in public recommended list pipeline_template = ( db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() @@ -74,19 +68,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - # get pipeline detail - pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() - if not pipeline or not pipeline.is_public: - return None - - dataset: Dataset | None = pipeline.dataset - if not dataset: - return None - return { - "id": pipeline.id, - "name": pipeline.name, + "id": pipeline_template.id, + "name": pipeline_template.name, "icon": pipeline_template.icon, - "chunk_structure": dataset.chunk_structure, - "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline), + "chunk_structure": pipeline_template.chunk_structure, + "export_data": yaml.safe_load(pipeline_template.yaml_content), } diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py index aa8a6298d7..7b87ffe75b 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -1,4 +1,5 @@ from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval +from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -12,7 +13,7 @@ class PipelineTemplateRetrievalFactory: case PipelineTemplateType.REMOTE: return RemotePipelineTemplateRetrieval case PipelineTemplateType.CUSTOMIZED: - return DatabasePipelineTemplateRetrieval + return CustomizedPipelineTemplateRetrieval case PipelineTemplateType.DATABASE: return DatabasePipelineTemplateRetrieval case PipelineTemplateType.BUILTIN: diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 79e793118a..b3c32a7c78 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -7,7 +7,7 @@ from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user -from sqlalchemy import select +from sqlalchemy import or_, select from sqlalchemy.orm import Session import contexts @@ -47,16 +47,19 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + PipelineTemplateInfoEntity, +) from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory class RagPipelineService: - @staticmethod + @classmethod def get_pipeline_templates( - type: str = "built-in", language: str = "en-US" - ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: + cls, type: str = "built-in", language: str = "en-US" + ) -> dict: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() @@ -64,14 +67,14 @@ class RagPipelineService: if not result.get("pipeline_templates") and language != "en-US": template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") - return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])] + return result else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) - return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])] + return result - @classmethod + @classmethod def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: """ Get pipeline template detail. @@ -684,7 +687,10 @@ class RagPipelineService: base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == pipeline.tenant_id, WorkflowRun.app_id == pipeline.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + or_( + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value + ) ) if args.get("last_id"): @@ -765,8 +771,26 @@ class RagPipelineService: # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, + order_config=order_config, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN) # Convert domain models to database models workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] return workflow_node_executions + + @classmethod + def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): + """ + Publish customized pipeline template + """ + pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + if not pipeline.workflow_id: + raise ValueError("Pipeline workflow not found") + workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + if not workflow: + raise ValueError("Workflow not found") + + db.session.commit() \ No newline at end of file diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c6751825cc..57e81e6f75 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -1,5 +1,7 @@ import base64 +from datetime import UTC, datetime import hashlib +import json import logging import uuid from collections.abc import Mapping @@ -31,13 +33,12 @@ from extensions.ext_redis import redis_client from factories import variable_factory from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, RagPipelineDatasetCreateEntity, ) from services.plugin.dependencies_analysis import DependenciesAnalysisService -from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) @@ -206,12 +207,12 @@ class RagPipelineDslService: status = _check_version_compatibility(imported_version) # Extract app data - pipeline_data = data.get("pipeline") + pipeline_data = data.get("rag_pipeline") if not pipeline_data: return RagPipelineImportInfo( id=import_id, status=ImportStatus.FAILED, - error="Missing pipeline data in YAML content", + error="Missing rag_pipeline data in YAML content", ) # If app_id is provided, check if it exists @@ -256,7 +257,7 @@ class RagPipelineDslService: if dependencies: check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] - # Create or update app + # Create or update pipeline pipeline = self._create_or_update_pipeline( pipeline=pipeline, data=data, @@ -278,7 +279,9 @@ class RagPipelineDslService: if node.get("data", {}).get("type") == "knowledge_index": knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) - if not dataset: + if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure: + raise ValueError("Chunk structure is not compatible with the published pipeline") + else: dataset = Dataset( tenant_id=account.current_tenant_id, name=name, @@ -295,11 +298,6 @@ class RagPipelineDslService: runtime_mode="rag_pipeline", chunk_structure=knowledge_configuration.chunk_structure, ) - else: - dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique - dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() - dataset.runtime_mode = "rag_pipeline" - dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.index_method.indexing_technique == "high_quality": dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) @@ -540,11 +538,45 @@ class RagPipelineDslService: icon_type = "emoji" icon = str(pipeline_data.get("icon", "")) + + # Initialize pipeline based on mode + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for rag pipeline") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) + + + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, + tenant_id=account.current_tenant_id, + ) + ) + ] + if pipeline: # Update existing pipeline pipeline.name = pipeline_data.get("name", pipeline.name) pipeline.description = pipeline_data.get("description", pipeline.description) pipeline.updated_by = account.id + + else: if account.current_tenant_id is None: raise ValueError("Current tenant is not set") @@ -567,52 +599,44 @@ class RagPipelineDslService: IMPORT_INFO_REDIS_EXPIRY, CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), ) - - # Initialize pipeline based on mode - workflow_data = data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise ValueError("Missing workflow data for rag pipeline") - - environment_variables_list = workflow_data.get("environment_variables", []) - environment_variables = [ - variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables", []) - conversation_variables = [ - variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list - ] - rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) - - rag_pipeline_service = RagPipelineService() - current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - graph = workflow_data.get("graph", {}) - for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ - decrypted_id - for dataset_id in dataset_ids - if ( - decrypted_id := self.decrypt_dataset_id( - encrypted_data=dataset_id, - tenant_id=pipeline.tenant_id, - ) - ) - ] - rag_pipeline_service.sync_draft_workflow( - pipeline=pipeline, - graph=workflow_data.get("graph", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - rag_pipeline_variables=rag_pipeline_variables_list, + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() ) + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + features="{}", + type=WorkflowType.RAG_PIPELINE.value, + version="draft", + graph=json.dumps(graph), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables_list, + ) + db.session.add(workflow) + db.session.flush() + pipeline.workflow_id = workflow.id + else: + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + workflow.rag_pipeline_variables = rag_pipeline_variables_list + # commit db session changes + db.session.commit() + + return pipeline @classmethod @@ -623,16 +647,19 @@ class RagPipelineDslService: :param include_secret: Whether include secret variable :return: """ + dataset = pipeline.dataset + if not dataset: + raise ValueError("Missing dataset for rag pipeline") + icon_info = dataset.icon_info export_data = { "version": CURRENT_DSL_VERSION, "kind": "rag_pipeline", "pipeline": { "name": pipeline.name, - "mode": pipeline.mode, - "icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon, - "icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background, + "icon": icon_info.get("icon", "📙") if icon_info else "📙", + "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji", + "icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5", "description": pipeline.description, - "use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon, }, } @@ -647,8 +674,16 @@ class RagPipelineDslService: :param export_data: export data :param pipeline: Pipeline instance """ - rag_pipeline_service = RagPipelineService() - workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() + ) if not workflow: raise ValueError("Missing draft workflow configuration, please check.") @@ -855,14 +890,6 @@ class RagPipelineDslService: f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) - dataset = Dataset( - name=rag_pipeline_dataset_create_entity.name, - description=rag_pipeline_dataset_create_entity.description, - permission=rag_pipeline_dataset_create_entity.permission, - provider="vendor", - runtime_mode="rag-pipeline", - icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), - ) with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) account = cast(Account, current_user) @@ -870,11 +897,11 @@ class RagPipelineDslService: account=account, import_mode=ImportMode.YAML_CONTENT.value, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, - dataset=dataset, + dataset=None, ) return { "id": rag_pipeline_import_info.id, - "dataset_id": dataset.id, + "dataset_id": rag_pipeline_import_info.dataset_id, "pipeline_id": rag_pipeline_import_info.pipeline_id, "status": rag_pipeline_import_info.status, "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,