# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import weakref

from edgecraftrag.api_schema import PipelineCreateIn
from edgecraftrag.base import IndexerType, InferenceType, ModelType, NodeParserType, PostProcessorType, RetrieverType
from edgecraftrag.components.generator import QnAGenerator
from edgecraftrag.components.indexer import VectorIndexer
from edgecraftrag.components.node_parser import HierarchyNodeParser, SimpleNodeParser, SWindowNodeParser
from edgecraftrag.components.postprocessor import MetadataReplaceProcessor, RerankProcessor
from edgecraftrag.components.retriever import AutoMergeRetriever, SimpleBM25Retriever, VectorSimRetriever
from edgecraftrag.context import ctx
from fastapi import FastAPI

pipeline_app = FastAPI()


# GET Pipelines
@pipeline_app.get(path="/v1/settings/pipelines")
async def get_pipelines():
    return ctx.get_pipeline_mgr().get_pipelines()


# GET Pipeline
@pipeline_app.get(path="/v1/settings/pipelines/{name}")
async def get_pipeline(name):
    return ctx.get_pipeline_mgr().get_pipeline_by_name_or_id(name)


# POST Pipeline
@pipeline_app.post(path="/v1/settings/pipelines")
async def add_pipeline(request: PipelineCreateIn):
    pl = ctx.get_pipeline_mgr().get_pipeline_by_name_or_id(request.name)
    if pl is None:
        pl = ctx.get_pipeline_mgr().create_pipeline(request.name)
    # An active pipeline can only be modified if the request also deactivates it
    active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
    if pl == active_pl and request.active:
        return "Unable to patch an active pipeline..."
    update_pipeline_handler(pl, request)
    return pl


# PATCH Pipeline
@pipeline_app.patch(path="/v1/settings/pipelines/{name}")
async def update_pipeline(name, request: PipelineCreateIn):
    pl = ctx.get_pipeline_mgr().get_pipeline_by_name_or_id(name)
    if pl is None:
        return None
    # An active pipeline can only be modified if the request also deactivates it
    active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
    if pl == active_pl and request.active:
        return "Unable to patch an active pipeline..."
    async with ctx.get_pipeline_mgr()._lock:
        update_pipeline_handler(pl, request)
    return pl
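# For reference, a request that exercises most branches of
# update_pipeline_handler() below could look like the sketch here. This is an
# illustration only: the authoritative field names come from PipelineCreateIn
# in edgecraftrag.api_schema, and the enum string values, port, and model ids
# shown are assumptions, not guaranteed by this module.
#
#   curl -X POST http://localhost:16010/v1/settings/pipelines \
#     -H "Content-Type: application/json" \
#     -d '{
#           "name": "rag_demo",
#           "node_parser": {"parser_type": "simple", "chunk_size": 400, "chunk_overlap": 48},
#           "indexer": {"indexer_type": "faiss_vector",
#                       "embedding_model": {"model_id": "BAAI/bge-small-en-v1.5"}},
#           "retriever": {"retriever_type": "vectorsimilarity", "retrieve_topk": 30},
#           "postprocessor": [{"processor_type": "reranker", "top_n": 2,
#                              "reranker_model": {"model_id": "BAAI/bge-reranker-large"}}],
#           "generator": {"inference_type": "local",
#                         "model": {"model_id": "Qwen/Qwen2-7B-Instruct"},
#                         "prompt_path": "./templates/prompt_template.txt"},
#           "active": true
#         }'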
def update_pipeline_handler(pl, req):
    if req.node_parser is not None:
        np = req.node_parser
        found_parser = ctx.get_node_parser_mgr().search_parser(np)
        if found_parser is not None:
            pl.node_parser = found_parser
        else:
            match np.parser_type:
                case NodeParserType.SIMPLE:
                    pl.node_parser = SimpleNodeParser(chunk_size=np.chunk_size, chunk_overlap=np.chunk_overlap)
                case NodeParserType.HIERARCHY:
                    """
                    HierarchyNodeParser is for the Auto Merging Retriever
                    (https://docs.llamaindex.ai/en/stable/examples/retrievers/auto_merging_retriever/).
                    By default, the hierarchy is:
                    1st level: chunk size 2048
                    2nd level: chunk size 512
                    3rd level: chunk size 128
                    Please set the chunk sizes as a list, e.g. chunk_sizes=[2048, 512, 128]
                    """
                    pl.node_parser = HierarchyNodeParser.from_defaults(
                        chunk_sizes=np.chunk_sizes, chunk_overlap=np.chunk_overlap
                    )
                case NodeParserType.SENTENCEWINDOW:
                    pl.node_parser = SWindowNodeParser.from_defaults(window_size=np.window_size)
            ctx.get_node_parser_mgr().add(pl.node_parser)

    if req.indexer is not None:
        ind = req.indexer
        found_indexer = ctx.get_indexer_mgr().search_indexer(ind)
        if found_indexer is not None:
            pl.indexer = found_indexer
        else:
            # Reuse a loaded embedding model if available, otherwise load and register it
            embed_model = None
            if ind.embedding_model:
                embed_model = ctx.get_model_mgr().search_model(ind.embedding_model)
                if embed_model is None:
                    ind.embedding_model.model_type = ModelType.EMBEDDING
                    embed_model = ctx.get_model_mgr().load_model(ind.embedding_model)
                    ctx.get_model_mgr().add(embed_model)
            match ind.indexer_type:
                case IndexerType.DEFAULT_VECTOR | IndexerType.FAISS_VECTOR:
                    # TODO: **RISK** if considering 2 pipelines with different
                    # nodes, but same indexer, what will happen?
                    pl.indexer = VectorIndexer(embed_model, ind.indexer_type)
                case _:
                    pass
            ctx.get_indexer_mgr().add(pl.indexer)

    if req.retriever is not None:
        retr = req.retriever
        match retr.retriever_type:
            case RetrieverType.VECTORSIMILARITY:
                if pl.indexer is not None:
                    pl.retriever = VectorSimRetriever(pl.indexer, similarity_top_k=retr.retrieve_topk)
                else:
                    return "No indexer"
            case RetrieverType.AUTOMERGE:
                # AutoMergeRetriever looks at a set of leaf nodes and recursively "merges"
                # subsets of leaf nodes that reference a parent node
                if pl.indexer is not None:
                    pl.retriever = AutoMergeRetriever(pl.indexer, similarity_top_k=retr.retrieve_topk)
                else:
                    return "No indexer"
            case RetrieverType.BM25:
                if pl.indexer is not None:
                    pl.retriever = SimpleBM25Retriever(pl.indexer, similarity_top_k=retr.retrieve_topk)
                else:
                    return "No indexer"
            case _:
                pass

    if req.postprocessor is not None:
        pp = req.postprocessor
        pl.postprocessor = []
        for processor in pp:
            match processor.processor_type:
                case PostProcessorType.RERANKER:
                    if processor.reranker_model:
                        prm = processor.reranker_model
                        reranker_model = ctx.get_model_mgr().search_model(prm)
                        if reranker_model is None:
                            prm.model_type = ModelType.RERANKER
                            reranker_model = ctx.get_model_mgr().load_model(prm)
                            ctx.get_model_mgr().add(reranker_model)
                        postprocessor = RerankProcessor(reranker_model, processor.top_n)
                        pl.postprocessor.append(postprocessor)
                    else:
                        return "No reranker model"
                case PostProcessorType.METADATAREPLACE:
                    postprocessor = MetadataReplaceProcessor(target_metadata_key="window")
                    pl.postprocessor.append(postprocessor)

    if req.generator:
        gen = req.generator
        if gen.model is None:
            return "No ChatQnA Model"
        if gen.inference_type == InferenceType.VLLM:
            # For vLLM, the generator only needs a model identifier, not a loaded model
            model_ref = gen.model.model_id if gen.model.model_id else gen.model.model_path
            pl.generator = QnAGenerator(model_ref, gen.prompt_path, gen.inference_type)
        elif gen.inference_type == InferenceType.LOCAL:
            model = ctx.get_model_mgr().search_model(gen.model)
            if model is None:
                gen.model.model_type = ModelType.LLM
                model = ctx.get_model_mgr().load_model(gen.model)
                ctx.get_model_mgr().add(model)
            # Use a weakref so the model can be deleted and its memory released
            model_ref = weakref.ref(model)
            pl.generator = QnAGenerator(model_ref, gen.prompt_path, gen.inference_type)
        else:
            return "Inference Type Not Supported"

    if pl.status.active != req.active:
        ctx.get_pipeline_mgr().activate_pipeline(pl.name, req.active, ctx.get_node_mgr())

    return pl
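if __name__ == "__main__":
    # Convenience entry point for serving this sub-app standalone during local
    # testing. A sketch only: in the full EdgeCraftRAG service pipeline_app is
    # mounted by the main server, and the host/port here are assumptions.
    import uvicorn

    uvicorn.run(pipeline_app, host="0.0.0.0", port=16010)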