Add EdgeCraftRag as a GenAIExample (#1072)

Signed-off-by: ZePan110 <ze.pan@intel.com>
Signed-off-by: chensuyue <suyue.chen@intel.com>
Signed-off-by: Zhu, Yongbo <yongbo.zhu@intel.com>
Signed-off-by: Wang, Xigui <xigui.wang@intel.com>
Co-authored-by: ZePan110 <ze.pan@intel.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: xiguiw <111278656+xiguiw@users.noreply.github.com>
Co-authored-by: lvliang-intel <liang1.lv@intel.com>
Authored by Zhu Yongbo on 2024-11-08 21:07:24 +08:00, committed by GitHub
parent 9c3023a12e · commit c9088eb824
45 changed files with 4048 additions and 0 deletions


@@ -0,0 +1,2 @@
ModelIn
modelin

EdgeCraftRAG/Dockerfile Normal file

@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
FROM python:3.11-slim
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
libgl1-mesa-glx \
libjemalloc-dev
RUN useradd -m -s /bin/bash user && \
mkdir -p /home/user && \
chown -R user /home/user/
COPY ./edgecraftrag /home/user/edgecraftrag
COPY ./chatqna.py /home/user/chatqna.py
WORKDIR /home/user/edgecraftrag
RUN pip install --no-cache-dir -r requirements.txt
WORKDIR /home/user
USER user
RUN echo 'ulimit -S -n 999999' >> ~/.bashrc
ENTRYPOINT ["python", "chatqna.py"]


@@ -0,0 +1,35 @@
FROM python:3.11-slim
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
libgl1-mesa-glx \
libjemalloc-dev
RUN apt-get update && apt-get install -y gnupg wget
RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
gpg --yes --dearmor --output /usr/share/keyrings/intel-graphics.gpg
RUN echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \
tee /etc/apt/sources.list.d/intel-gpu-jammy.list
RUN apt-get update
RUN apt-get install -y \
intel-opencl-icd intel-level-zero-gpu level-zero intel-level-zero-gpu-raytracing \
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
RUN useradd -m -s /bin/bash user && \
mkdir -p /home/user && \
chown -R user /home/user/
COPY ./edgecraftrag /home/user/edgecraftrag
WORKDIR /home/user/edgecraftrag
RUN pip install --no-cache-dir -r requirements.txt
WORKDIR /home/user/
USER user
ENTRYPOINT ["python", "-m", "edgecraftrag.server"]

EdgeCraftRAG/README.md Normal file

@@ -0,0 +1,274 @@
# Edge Craft Retrieval-Augmented Generation
Edge Craft RAG (EC-RAG) is a customizable, tunable, and production-ready
Retrieval-Augmented Generation system for edge solutions. It is designed to
curate the RAG pipeline to meet the hardware requirements at the edge with
guaranteed quality and performance.
## Quick Start Guide
### Run Containers with Docker Compose
```bash
cd GenAIExamples/EdgeCraftRAG/docker_compose/intel/gpu/arc
export MODEL_PATH="your model path for all your models"
export DOC_PATH="your doc path for uploading a dir of files"
export HOST_IP="your host ip"
export UI_SERVICE_PORT="port for UI service"
# Optional: vLLM endpoint
export vLLM_ENDPOINT="http://${HOST_IP}:8008"
# If you have a proxy configured, uncomment the line below
# export no_proxy=$no_proxy,${HOST_IP},edgecraftrag,edgecraftrag-server
# If you have an HF mirror configured, it will be imported into the container
# export HF_ENDPOINT="your HF mirror endpoint"
# The container ports are set by default; uncomment the lines below to change them
# export MEGA_SERVICE_PORT=16011
# export PIPELINE_SERVICE_PORT=16010
docker compose up -d
```
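To verify that the deployment came up, you can check the container status (the container names below come from the compose file):

```bash
# Quick sanity check after startup
docker compose ps
# edgecraftrag, edgecraftrag-server and edgecraftrag-ui should all be in the "Up" state
```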
### (Optional) Build Docker Images for Mega Service, Server and UI on Your Own
```bash
cd GenAIExamples/EdgeCraftRAG
docker build --build-arg http_proxy=$HTTP_PROXY --build-arg https_proxy=$HTTPS_PROXY --build-arg no_proxy=$NO_PROXY -t opea/edgecraftrag:latest -f Dockerfile .
docker build --build-arg http_proxy=$HTTP_PROXY --build-arg https_proxy=$HTTPS_PROXY --build-arg no_proxy=$NO_PROXY -t opea/edgecraftrag-server:latest -f Dockerfile.server .
docker build --build-arg http_proxy=$HTTP_PROXY --build-arg https_proxy=$HTTPS_PROXY --build-arg no_proxy=$NO_PROXY -t opea/edgecraftrag-ui:latest -f ui/docker/Dockerfile.ui .
```
### ChatQnA with LLM Example (Command Line)
```bash
cd GenAIExamples/EdgeCraftRAG
# Activate pipeline test_pipeline_local_llm
curl -X POST http://${HOST_IP}:16010/v1/settings/pipelines -H "Content-Type: application/json" -d @tests/test_pipeline_local_llm.json | jq '.'
# This will take several minutes to finish
# Expected output:
# {
# "idx": "3214cf25-8dff-46e6-b7d1-1811f237cf8c",
# "name": "rag_test",
# "comp_type": "pipeline",
# "node_parser": {
# "idx": "ababed12-c192-4cbb-b27e-e49c76a751ca",
# "parser_type": "simple",
# "chunk_size": 400,
# "chunk_overlap": 48
# },
# "indexer": {
# "idx": "46969b63-8a32-4142-874d-d5c86ee9e228",
# "indexer_type": "faiss_vector",
# "model": {
# "idx": "7aae57c0-13a4-4a15-aecb-46c2ec8fe738",
# "type": "embedding",
# "model_id": "BAAI/bge-small-en-v1.5",
# "model_path": "/home/user/models/bge_ov_embedding",
# "device": "auto"
# }
# },
# "retriever": {
# "idx": "3747fa59-ff9b-49b6-a8e8-03cdf8c979a4",
# "retriever_type": "vectorsimilarity",
# "retrieve_topk": 30
# },
# "postprocessor": [
# {
# "idx": "d46a6cae-ba7a-412e-85b7-d334f175efaa",
# "postprocessor_type": "reranker",
# "model": {
# "idx": "374e7471-bd7d-41d0-b69d-a749a052b4b0",
# "type": "reranker",
# "model_id": "BAAI/bge-reranker-large",
# "model_path": "/home/user/models/bge_ov_reranker",
# "device": "auto"
# },
# "top_n": 2
# }
# ],
# "generator": {
# "idx": "52d8f112-6290-4dd3-bc28-f9bd5deeb7c8",
# "generator_type": "local",
# "model": {
# "idx": "fa0c11e1-46d1-4df8-a6d8-48cf6b99eff3",
# "type": "llm",
# "model_id": "qwen2-7b-instruct",
# "model_path": "/home/user/models/qwen2-7b-instruct/INT4_compressed_weights",
# "device": "auto"
# }
# },
# "status": {
# "active": true
# }
# }
# Prepare data from local directory
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d '{"local_path":"#REPLACE WITH YOUR LOCAL DOC DIR#"}' | jq '.'
# Validate Mega Service
curl -X POST http://${HOST_IP}:16011/v1/chatqna -H "Content-Type: application/json" -d '{"messages":"#REPLACE WITH YOUR QUESTION HERE#", "top_n":5, "max_tokens":512}' | jq '.'
```
### ChatQnA with LLM Example (UI)
Open your browser and access http://${HOST_IP}:8082
> Your browser should be running on the same host as your console; otherwise, you will need to access the UI with your host's domain name instead of ${HOST_IP}.
### (Optional) Launch vLLM with OpenVINO service
```bash
# 1. export LLM_MODEL
export LLM_MODEL="your model id"
# 2. Uncomment the vllm-service section shown below in 'GenAIExamples/EdgeCraftRAG/docker_compose/intel/gpu/arc/compose.yaml'
# vllm-service:
# image: vllm:openvino
# container_name: vllm-openvino-server
# ports:
# - "8008:80"
# environment:
# no_proxy: ${no_proxy}
# http_proxy: ${http_proxy}
# https_proxy: ${https_proxy}
# vLLM_ENDPOINT: ${vLLM_ENDPOINT}
# LLM_MODEL: ${LLM_MODEL}
# entrypoint: /bin/bash -c "\
# cd / && \
# export VLLM_CPU_KVCACHE_SPACE=50 && \
# python3 -m vllm.entrypoints.openai.api_server \
# --model '${LLM_MODEL}' \
# --host 0.0.0.0 \
# --port 80"
```
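After uncommenting the service, re-run the deployment so the additional container is created:

```bash
cd GenAIExamples/EdgeCraftRAG/docker_compose/intel/gpu/arc
docker compose up -d
```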
## Advanced User Guide
### Pipeline Management
#### Create a pipeline
```bash
curl -X POST http://${HOST_IP}:16010/v1/settings/pipelines -H "Content-Type: application/json" -d @examples/test_pipeline.json | jq '.'
```
It will take some time to prepare the embedding model.
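If you prefer not to edit the bundled example file, the same request can be sent with an inline body. The sketch below follows the `PipelineCreateIn` schema (`name`, `node_parser`, `indexer`, `retriever`, `postprocessor`, `generator`, `active`); the model IDs, model paths and the prompt path are illustrative placeholders rather than the exact contents of `examples/test_pipeline.json`.

```bash
# Illustrative inline payload; adjust model and prompt paths to your ${MODEL_PATH} layout
curl -X POST http://${HOST_IP}:16010/v1/settings/pipelines \
  -H "Content-Type: application/json" \
  -d '{
        "name": "test1",
        "node_parser": {"parser_type": "simple", "chunk_size": 400, "chunk_overlap": 48},
        "indexer": {
          "indexer_type": "faiss_vector",
          "embedding_model": {"model_id": "BAAI/bge-small-en-v1.5", "model_path": "/home/user/models/bge_ov_embedding", "device": "auto"}
        },
        "retriever": {"retriever_type": "vectorsimilarity", "retrieve_topk": 30},
        "postprocessor": [
          {"processor_type": "reranker", "top_n": 2,
           "reranker_model": {"model_id": "BAAI/bge-reranker-large", "model_path": "/home/user/models/bge_ov_reranker", "device": "auto"}}
        ],
        "generator": {
          "inference_type": "local",
          "prompt_path": "/home/user/your_prompt_template.txt",
          "model": {"model_id": "qwen2-7b-instruct", "model_path": "/home/user/models/qwen2-7b-instruct/INT4_compressed_weights", "device": "auto"}
        },
        "active": true
      }' | jq '.'
```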
#### Upload a text
```bash
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d @examples/test_data.json | jq '.'
```
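A text snippet can also be posted inline following the `DataIn` schema; the sentence below is only a placeholder for the contents of `examples/test_data.json`.

```bash
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d '{"text":"EC-RAG runs retrieval-augmented generation entirely at the edge."}' | jq '.'
```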
#### Provide a query to retrieve context with similarity search
```bash
curl -X POST http://${HOST_IP}:16010/v1/retrieval -H "Content-Type: application/json" -d @examples/test_query.json | jq '.'
```
#### Create the second pipeline test2
```bash
curl -X POST http://${HOST_IP}:16010/v1/settings/pipelines -H "Content-Type: application/json" -d @examples/test_pipeline2.json | jq '.'
```
#### Check all pipelines
```bash
curl -X GET http://${HOST_IP}:16010/v1/settings/pipelines -H "Content-Type: application/json" | jq '.'
```
#### Compare similarity retrieval (test1) and keyword retrieval (test2)
```bash
# Activate pipeline test1
curl -X PATCH http://${HOST_IP}:16010/v1/settings/pipelines/test1 -H "Content-Type: application/json" -d '{"active": "true"}' | jq '.'
# Similarity retrieval
curl -X POST http://${HOST_IP}:16010/v1/retrieval -H "Content-Type: application/json" -d '{"messages":"number"}' | jq '.'
# Activate pipeline test2
curl -X PATCH http://${HOST_IP}:16010/v1/settings/pipelines/test2 -H "Content-Type: application/json" -d '{"active": "true"}' | jq '.'
# Keyword retrieval
curl -X POST http://${HOST_IP}:16010/v1/retrieval -H "Content-Type: application/json" -d '{"messages":"number"}' | jq '.'
```
### Model Management
#### Load a model
```bash
curl -X POST http://${HOST_IP}:16010/v1/settings/models -H "Content-Type: application/json" -d @examples/test_model_load.json | jq '.'
```
It will take some time to load the model.
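For reference, the request body follows the `ModelIn` schema (`model_type`, `model_id`, `model_path`, `device`); the values below are illustrative and assume the converted models live under the mounted `/home/user/models` directory, not the exact contents of `examples/test_model_load.json`.

```bash
# Illustrative inline payload
curl -X POST http://${HOST_IP}:16010/v1/settings/models \
  -H "Content-Type: application/json" \
  -d '{"model_type": "reranker", "model_id": "BAAI/bge-reranker-large", "model_path": "/home/user/models/bge_ov_reranker", "device": "auto"}' | jq '.'
```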
#### Check all models
```bash
curl -X GET http://${HOST_IP}:16010/v1/settings/models -H "Content-Type: application/json" | jq '.'
```
#### Update a model
```bash
curl -X PATCH http://${HOST_IP}:16010/v1/settings/models/BAAI/bge-reranker-large -H "Content-Type: application/json" -d @examples/test_model_update.json | jq '.'
```
#### Check a certain model
```bash
curl -X GET http://${HOST_IP}:16010/v1/settings/models/BAAI/bge-reranker-large -H "Content-Type: application/json" | jq '.'
```
#### Delete a model
```bash
curl -X DELETE http://${HOST_IP}:16010/v1/settings/models/BAAI/bge-reranker-large -H "Content-Type: application/json" | jq '.'
```
### File Management
#### Add a text
```bash
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d @examples/test_data.json | jq '.'
```
#### Add files from an existing file path
```bash
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d @examples/test_data_dir.json | jq '.'
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d @examples/test_data_file.json | jq '.'
```
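Equivalently, a directory or file path can be passed inline following the `DataIn` schema; the paths below assume your documents are mounted at `/home/user/docs` inside the container (the `DOC_PATH` volume) and are placeholders, not the bundled example contents.

```bash
# Illustrative inline payloads; replace the paths with your own mounted documents
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d '{"local_path":"/home/user/docs"}' | jq '.'
curl -X POST http://${HOST_IP}:16010/v1/data -H "Content-Type: application/json" -d '{"local_path":"/home/user/docs/test2.docx"}' | jq '.'
```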
#### Check all files
```bash
curl -X GET http://${HOST_IP}:16010/v1/data/files -H "Content-Type: application/json" | jq '.'
```
#### Check one file
```bash
curl -X GET http://${HOST_IP}:16010/v1/data/files/test2.docx -H "Content-Type: application/json" | jq '.'
```
#### Delete a file
```bash
curl -X DELETE http://${HOST_IP}:16010/v1/data/files/test2.docx -H "Content-Type: application/json" | jq '.'
```
#### Update a file
```bash
curl -X PATCH http://${HOST_IP}:16010/v1/data/files/test.pdf -H "Content-Type: application/json" -d @examples/test_data_file.json | jq '.'
```

EdgeCraftRAG/chatqna.py Normal file

@@ -0,0 +1,72 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
from comps import MicroService, ServiceOrchestrator, ServiceType
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "127.0.0.1")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 16011))
PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1")
PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))
from comps import Gateway, MegaServiceEndpoint
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from fastapi import Request
from fastapi.responses import StreamingResponse
class EdgeCraftRagGateway(Gateway):
def __init__(self, megaservice, host="0.0.0.0", port=16011):
super().__init__(
megaservice, host, port, str(MegaServiceEndpoint.CHAT_QNA), ChatCompletionRequest, ChatCompletionResponse
)
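    # Dispatch the incoming ChatQnA request through the ServiceOrchestrator graph; pass
    # streaming responses through unchanged and wrap the final answer in an OpenAI-style
    # ChatCompletionResponse.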
async def handle_request(self, request: Request):
input = await request.json()
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs=input)
for node, response in result_dict.items():
if isinstance(response, StreamingResponse):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="edgecraftrag", choices=choices, usage=usage)
class EdgeCraftRagService:
def __init__(self, host="0.0.0.0", port=16010):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()
def add_remote_service(self):
edgecraftrag = MicroService(
name="pipeline",
host=PIPELINE_SERVICE_HOST_IP,
port=PIPELINE_SERVICE_PORT,
endpoint="/v1/chatqna",
use_remote_service=True,
service_type=ServiceType.UNDEFINED,
)
self.megaservice.add(edgecraftrag)
self.gateway = EdgeCraftRagGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
if __name__ == "__main__":
edgecraftrag = EdgeCraftRagService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
edgecraftrag.add_remote_service()


@@ -0,0 +1,78 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
services:
server:
image: ${REGISTRY:-opea}/edgecraftrag-server:${TAG:-latest}
container_name: edgecraftrag-server
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
HF_ENDPOINT: ${HF_ENDPOINT}
vLLM_ENDPOINT: ${vLLM_ENDPOINT}
volumes:
- ${MODEL_PATH:-${PWD}}:/home/user/models
- ${DOC_PATH:-${PWD}}:/home/user/docs
ports:
- ${PIPELINE_SERVICE_PORT:-16010}:${PIPELINE_SERVICE_PORT:-16010}
devices:
- /dev/dri:/dev/dri
group_add:
- video
ecrag:
image: ${REGISTRY:-opea}/edgecraftrag:${TAG:-latest}
container_name: edgecraftrag
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
MEGA_SERVICE_PORT: ${MEGA_SERVICE_PORT:-16011}
MEGA_SERVICE_HOST_IP: ${MEGA_SERVICE_HOST_IP:-${HOST_IP}}
PIPELINE_SERVICE_PORT: ${PIPELINE_SERVICE_PORT:-16010}
PIPELINE_SERVICE_HOST_IP: ${PIPELINE_SERVICE_HOST_IP:-${HOST_IP}}
ports:
- ${MEGA_SERVICE_PORT:-16011}:${MEGA_SERVICE_PORT:-16011}
depends_on:
- server
ui:
image: ${REGISTRY:-opea}/edgecraftrag-ui:${TAG:-latest}
container_name: edgecraftrag-ui
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
MEGA_SERVICE_PORT: ${MEGA_SERVICE_PORT:-16011}
MEGA_SERVICE_HOST_IP: ${MEGA_SERVICE_HOST_IP:-${HOST_IP}}
PIPELINE_SERVICE_PORT: ${PIPELINE_SERVICE_PORT:-16010}
PIPELINE_SERVICE_HOST_IP: ${PIPELINE_SERVICE_HOST_IP:-${HOST_IP}}
UI_SERVICE_PORT: ${UI_SERVICE_PORT:-8082}
UI_SERVICE_HOST_IP: ${UI_SERVICE_HOST_IP:-0.0.0.0}
ports:
- ${UI_SERVICE_PORT:-8082}:${UI_SERVICE_PORT:-8082}
restart: always
depends_on:
- server
- ecrag
# vllm-service:
# image: vllm:openvino
# container_name: vllm-openvino-server
# ports:
# - "8008:80"
# environment:
# no_proxy: ${no_proxy}
# http_proxy: ${http_proxy}
# https_proxy: ${https_proxy}
# vLLM_ENDPOINT: ${vLLM_ENDPOINT}
# LLM_MODEL: ${LLM_MODEL}
# entrypoint: /bin/bash -c "\
# cd / && \
# export VLLM_CPU_KVCACHE_SPACE=50 && \
# python3 -m vllm.entrypoints.openai.api_server \
# --model '${LLM_MODEL}' \
# --host 0.0.0.0 \
# --port 80"
networks:
default:
driver: bridge


@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
services:
server:
build:
context: ..
args:
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
dockerfile: ./Dockerfile.server
image: ${REGISTRY:-opea}/edgecraftrag-server:${TAG:-latest}
ui:
build:
context: ..
args:
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
dockerfile: ./ui/docker/Dockerfile.ui
image: ${REGISTRY:-opea}/edgecraftrag-ui:${TAG:-latest}
ecrag:
build:
context: ..
args:
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
dockerfile: ./Dockerfile
image: ${REGISTRY:-opea}/edgecraftrag:${TAG:-latest}


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,29 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from comps.cores.proto.api_protocol import ChatCompletionRequest
from edgecraftrag.context import ctx
from fastapi import FastAPI
chatqna_app = FastAPI()
# Retrieval
@chatqna_app.post(path="/v1/retrieval")
async def retrieval(request: ChatCompletionRequest):
nodeswithscore = ctx.get_pipeline_mgr().run_retrieve(chat_request=request)
print(nodeswithscore)
if nodeswithscore is not None:
ret = []
for n in nodeswithscore:
ret.append((n.node.node_id, n.node.text, n.score))
return ret
return "Not found"
# ChatQnA
@chatqna_app.post(path="/v1/chatqna")
async def chatqna(request: ChatCompletionRequest):
ret = ctx.get_pipeline_mgr().run_pipeline(chat_request=request)
return str(ret)


@@ -0,0 +1,102 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from edgecraftrag.api_schema import DataIn, FilesIn
from edgecraftrag.context import ctx
from fastapi import FastAPI
data_app = FastAPI()
# Upload a text or files
@data_app.post(path="/v1/data")
async def add_data(request: DataIn):
nodelist = None
docs = []
if request.text is not None:
docs.extend(ctx.get_file_mgr().add_text(text=request.text))
if request.local_path is not None:
docs.extend(ctx.get_file_mgr().add_files(docs=request.local_path))
nodelist = ctx.get_pipeline_mgr().run_data_prepare(docs=docs)
if nodelist is None:
return "Error"
pl = ctx.get_pipeline_mgr().get_active_pipeline()
# TODO: Need bug fix, when node_parser is None
ctx.get_node_mgr().add_nodes(pl.node_parser.idx, nodelist)
return "Done"
# Upload files by a list of file_path
@data_app.post(path="/v1/data/files")
async def add_files(request: FilesIn):
nodelist = None
docs = []
if request.local_paths is not None:
docs.extend(ctx.get_file_mgr().add_files(docs=request.local_paths))
nodelist = ctx.get_pipeline_mgr().run_data_prepare(docs=docs)
if nodelist is None:
return "Error"
pl = ctx.get_pipeline_mgr().get_active_pipeline()
# TODO: Need bug fix, when node_parser is None
ctx.get_node_mgr().add_nodes(pl.node_parser.idx, nodelist)
return "Done"
# GET files
@data_app.get(path="/v1/data/files")
async def get_files():
return ctx.get_file_mgr().get_files()
# GET a file
@data_app.get(path="/v1/data/files")
async def get_file_docs(name):
return ctx.get_file_mgr().get_docs_by_file(name)
# DELETE a file
@data_app.delete(path="/v1/data/files/{name}")
async def delete_file(name):
if ctx.get_file_mgr().del_file(name):
# TODO: delete the nodes related to the file
all_docs = ctx.get_file_mgr().get_all_docs()
nodelist = ctx.get_pipeline_mgr().run_data_prepare(docs=all_docs)
if nodelist is None:
return "Error"
pl = ctx.get_pipeline_mgr().get_active_pipeline()
ctx.get_node_mgr().del_nodes_by_np_idx(pl.node_parser.idx)
ctx.get_node_mgr().add_nodes(pl.node_parser.idx, nodelist)
return f"File {name} is deleted"
else:
return f"File {name} not found"
# UPDATE a file
@data_app.patch(path="/v1/data/files/{name}")
async def update_file(name, request: DataIn):
# 1. Delete
if ctx.get_file_mgr().del_file(name):
# 2. Add
docs = []
if request.text is not None:
docs.extend(ctx.get_file_mgr().add_text(text=request.text))
if request.local_path is not None:
docs.extend(ctx.get_file_mgr().add_files(docs=request.local_path))
# 3. Re-run the pipeline
# TODO: update the nodes related to the file
all_docs = ctx.get_file_mgr().get_all_docs()
nodelist = ctx.get_pipeline_mgr().run_data_prepare(docs=all_docs)
if nodelist is None:
return "Error"
pl = ctx.get_pipeline_mgr().get_active_pipeline()
ctx.get_node_mgr().del_nodes_by_np_idx(pl.node_parser.idx)
ctx.get_node_mgr().add_nodes(pl.node_parser.idx, nodelist)
return f"File {name} is updated"
else:
return f"File {name} not found"


@@ -0,0 +1,76 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import gc
from edgecraftrag.api_schema import ModelIn
from edgecraftrag.context import ctx
from fastapi import FastAPI
model_app = FastAPI()
# GET Models
@model_app.get(path="/v1/settings/models")
async def get_models():
return ctx.get_model_mgr().get_models()
# GET Model
@model_app.get(path="/v1/settings/models/{model_id:path}")
async def get_model_by_name(model_id):
return ctx.get_model_mgr().get_model_by_name(model_id)
# POST Model
@model_app.post(path="/v1/settings/models")
async def add_model(request: ModelIn):
modelmgr = ctx.get_model_mgr()
    # Currently use asyncio.Lock() to serialize concurrent requests
async with modelmgr._lock:
model = modelmgr.search_model(request)
if model is None:
model = modelmgr.load_model(request)
modelmgr.add(model)
return model.model_id + " model loaded"
# PATCH Model
@model_app.patch(path="/v1/settings/models/{model_id:path}")
async def update_model(model_id, request: ModelIn):
    # Patching a model is a two-step process: 1. delete the old model 2. create the new one
active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
modelmgr = ctx.get_model_mgr()
if active_pl and active_pl.model_existed(model_id):
return "Model is being used by active pipeline, unable to update model"
else:
async with modelmgr._lock:
if modelmgr.get_model_by_name(model_id) is None:
# Need to make sure original model still exists before updating model
# to prevent memory leak in concurrent requests situation
return "Model " + model_id + " not exists"
model = modelmgr.search_model(request)
if model is None:
modelmgr.del_model_by_name(model_id)
# Clean up memory occupation
gc.collect()
# load new model
model = modelmgr.load_model(request)
modelmgr.add(model)
return model
# DELETE Model
@model_app.delete(path="/v1/settings/models/{model_id:path}")
async def delete_model(model_id):
active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
if active_pl and active_pl.model_existed(model_id):
return "Model is being used by active pipeline, unable to remove"
else:
modelmgr = ctx.get_model_mgr()
        # Currently use asyncio.Lock() to serialize concurrent requests
async with modelmgr._lock:
response = modelmgr.del_model_by_name(model_id)
# Clean up memory occupation
gc.collect()
return response


@@ -0,0 +1,180 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import weakref
from edgecraftrag.api_schema import PipelineCreateIn
from edgecraftrag.base import IndexerType, InferenceType, ModelType, NodeParserType, PostProcessorType, RetrieverType
from edgecraftrag.components.generator import QnAGenerator
from edgecraftrag.components.indexer import VectorIndexer
from edgecraftrag.components.node_parser import HierarchyNodeParser, SimpleNodeParser, SWindowNodeParser
from edgecraftrag.components.postprocessor import MetadataReplaceProcessor, RerankProcessor
from edgecraftrag.components.retriever import AutoMergeRetriever, SimpleBM25Retriever, VectorSimRetriever
from edgecraftrag.context import ctx
from fastapi import FastAPI
pipeline_app = FastAPI()
# GET Pipelines
@pipeline_app.get(path="/v1/settings/pipelines")
async def get_pipelines():
return ctx.get_pipeline_mgr().get_pipelines()
# GET Pipeline
@pipeline_app.get(path="/v1/settings/pipelines/{name}")
async def get_pipeline(name):
return ctx.get_pipeline_mgr().get_pipeline_by_name_or_id(name)
# POST Pipeline
@pipeline_app.post(path="/v1/settings/pipelines")
async def add_pipeline(request: PipelineCreateIn):
pl = ctx.get_pipeline_mgr().get_pipeline_by_name_or_id(request.name)
if pl is None:
pl = ctx.get_pipeline_mgr().create_pipeline(request.name)
active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
if pl == active_pl:
if not request.active:
pass
else:
return "Unable to patch an active pipeline..."
update_pipeline_handler(pl, request)
return pl
# PATCH Pipeline
@pipeline_app.patch(path="/v1/settings/pipelines/{name}")
async def update_pipeline(name, request: PipelineCreateIn):
pl = ctx.get_pipeline_mgr().get_pipeline_by_name_or_id(name)
if pl is None:
return None
active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
if pl == active_pl:
if not request.active:
pass
else:
return "Unable to patch an active pipeline..."
async with ctx.get_pipeline_mgr()._lock:
update_pipeline_handler(pl, request)
return pl
def update_pipeline_handler(pl, req):
if req.node_parser is not None:
np = req.node_parser
found_parser = ctx.get_node_parser_mgr().search_parser(np)
if found_parser is not None:
pl.node_parser = found_parser
else:
match np.parser_type:
case NodeParserType.SIMPLE:
pl.node_parser = SimpleNodeParser(chunk_size=np.chunk_size, chunk_overlap=np.chunk_overlap)
case NodeParserType.HIERARCHY:
"""
HierarchyNodeParser is for Auto Merging Retriever
(https://docs.llamaindex.ai/en/stable/examples/retrievers/auto_merging_retriever/)
By default, the hierarchy is:
1st level: chunk size 2048
2nd level: chunk size 512
3rd level: chunk size 128
                    Please set chunk sizes with a list, e.g. chunk_sizes=[2048,512,128]
"""
pl.node_parser = HierarchyNodeParser.from_defaults(
chunk_sizes=np.chunk_sizes, chunk_overlap=np.chunk_overlap
)
case NodeParserType.SENTENCEWINDOW:
pl.node_parser = SWindowNodeParser.from_defaults(window_size=np.window_size)
ctx.get_node_parser_mgr().add(pl.node_parser)
if req.indexer is not None:
ind = req.indexer
found_indexer = ctx.get_indexer_mgr().search_indexer(ind)
if found_indexer is not None:
pl.indexer = found_indexer
else:
embed_model = None
if ind.embedding_model:
embed_model = ctx.get_model_mgr().search_model(ind.embedding_model)
if embed_model is None:
ind.embedding_model.model_type = ModelType.EMBEDDING
embed_model = ctx.get_model_mgr().load_model(ind.embedding_model)
ctx.get_model_mgr().add(embed_model)
match ind.indexer_type:
case IndexerType.DEFAULT_VECTOR | IndexerType.FAISS_VECTOR:
# TODO: **RISK** if considering 2 pipelines with different
# nodes, but same indexer, what will happen?
pl.indexer = VectorIndexer(embed_model, ind.indexer_type)
case _:
pass
ctx.get_indexer_mgr().add(pl.indexer)
if req.retriever is not None:
retr = req.retriever
match retr.retriever_type:
case RetrieverType.VECTORSIMILARITY:
if pl.indexer is not None:
pl.retriever = VectorSimRetriever(pl.indexer, similarity_top_k=retr.retrieve_topk)
else:
return "No indexer"
case RetrieverType.AUTOMERGE:
# AutoMergeRetriever looks at a set of leaf nodes and recursively "merges" subsets of leaf nodes that reference a parent node
if pl.indexer is not None:
pl.retriever = AutoMergeRetriever(pl.indexer, similarity_top_k=retr.retrieve_topk)
else:
return "No indexer"
case RetrieverType.BM25:
if pl.indexer is not None:
pl.retriever = SimpleBM25Retriever(pl.indexer, similarity_top_k=retr.retrieve_topk)
else:
return "No indexer"
case _:
pass
if req.postprocessor is not None:
pp = req.postprocessor
pl.postprocessor = []
for processor in pp:
match processor.processor_type:
case PostProcessorType.RERANKER:
if processor.reranker_model:
prm = processor.reranker_model
reranker_model = ctx.get_model_mgr().search_model(prm)
if reranker_model is None:
prm.model_type = ModelType.RERANKER
reranker_model = ctx.get_model_mgr().load_model(prm)
ctx.get_model_mgr().add(reranker_model)
postprocessor = RerankProcessor(reranker_model, processor.top_n)
pl.postprocessor.append(postprocessor)
else:
return "No reranker model"
case PostProcessorType.METADATAREPLACE:
postprocessor = MetadataReplaceProcessor(target_metadata_key="window")
pl.postprocessor.append(postprocessor)
if req.generator:
gen = req.generator
if gen.model is None:
return "No ChatQnA Model"
if gen.inference_type == InferenceType.VLLM:
if gen.model.model_id:
model_ref = gen.model.model_id
else:
model_ref = gen.model.model_path
pl.generator = QnAGenerator(model_ref, gen.prompt_path, gen.inference_type)
elif gen.inference_type == InferenceType.LOCAL:
model = ctx.get_model_mgr().search_model(gen.model)
if model is None:
gen.model.model_type = ModelType.LLM
model = ctx.get_model_mgr().load_model(gen.model)
ctx.get_model_mgr().add(model)
# Use weakref to achieve model deletion and memory release
model_ref = weakref.ref(model)
pl.generator = QnAGenerator(model_ref, gen.prompt_path, gen.inference_type)
else:
return "Inference Type Not Supported"
if pl.status.active != req.active:
ctx.get_pipeline_mgr().activate_pipeline(pl.name, req.active, ctx.get_node_mgr())
return pl


@@ -0,0 +1,62 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from pydantic import BaseModel
class ModelIn(BaseModel):
model_type: Optional[str] = "LLM"
model_id: Optional[str]
model_path: Optional[str] = "./"
device: Optional[str] = "cpu"
class NodeParserIn(BaseModel):
chunk_size: Optional[int] = None
chunk_overlap: Optional[int] = None
chunk_sizes: Optional[list] = None
parser_type: str
window_size: Optional[int] = None
class IndexerIn(BaseModel):
indexer_type: str
embedding_model: Optional[ModelIn] = None
class RetrieverIn(BaseModel):
retriever_type: str
retrieve_topk: Optional[int] = 3
class PostProcessorIn(BaseModel):
processor_type: str
reranker_model: Optional[ModelIn] = None
top_n: Optional[int] = 5
class GeneratorIn(BaseModel):
prompt_path: Optional[str] = None
model: Optional[ModelIn] = None
inference_type: Optional[str] = "local"
class PipelineCreateIn(BaseModel):
name: Optional[str] = None
node_parser: Optional[NodeParserIn] = None
indexer: Optional[IndexerIn] = None
retriever: Optional[RetrieverIn] = None
postprocessor: Optional[list[PostProcessorIn]] = None
generator: Optional[GeneratorIn] = None
active: Optional[bool] = False
class DataIn(BaseModel):
text: Optional[str] = None
local_path: Optional[str] = None
class FilesIn(BaseModel):
local_paths: Optional[list[str]] = None


@@ -0,0 +1,128 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import abc
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_serializer
class CompType(str, Enum):
DEFAULT = "default"
MODEL = "model"
PIPELINE = "pipeline"
NODEPARSER = "node_parser"
INDEXER = "indexer"
RETRIEVER = "retriever"
POSTPROCESSOR = "postprocessor"
GENERATOR = "generator"
FILE = "file"
class ModelType(str, Enum):
EMBEDDING = "embedding"
RERANKER = "reranker"
LLM = "llm"
class FileType(str, Enum):
TEXT = "text"
VISUAL = "visual"
AURAL = "aural"
VIRTUAL = "virtual"
OTHER = "other"
class NodeParserType(str, Enum):
DEFAULT = "default"
SIMPLE = "simple"
HIERARCHY = "hierarchical"
SENTENCEWINDOW = "sentencewindow"
class IndexerType(str, Enum):
DEFAULT = "default"
FAISS_VECTOR = "faiss_vector"
DEFAULT_VECTOR = "vector"
class RetrieverType(str, Enum):
DEFAULT = "default"
VECTORSIMILARITY = "vectorsimilarity"
AUTOMERGE = "auto_merge"
BM25 = "bm25"
class PostProcessorType(str, Enum):
RERANKER = "reranker"
METADATAREPLACE = "metadata_replace"
class GeneratorType(str, Enum):
CHATQNA = "chatqna"
class InferenceType(str, Enum):
LOCAL = "local"
VLLM = "vllm"
class CallbackType(str, Enum):
DATAPREP = "dataprep"
RETRIEVE = "retrieve"
PIPELINE = "pipeline"
class BaseComponent(BaseModel):
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
idx: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: Optional[str] = Field(default="")
comp_type: str = Field(default="")
comp_subtype: Optional[str] = Field(default="")
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"name": self.name,
"comp_type": self.comp_type,
"comp_subtype": self.comp_subtype,
}
return set
@abc.abstractmethod
def run(self, **kwargs) -> Any:
pass
class BaseMgr:
def __init__(self):
self.components = {}
def add(self, comp: BaseComponent):
self.components[comp.idx] = comp
def get(self, idx: str) -> BaseComponent:
if idx in self.components:
return self.components[idx]
else:
return None
def remove(self, idx):
        # Drop this manager's reference to the component;
        # once the reference count reaches 0, the object can be freed by the garbage collector
del self.components[idx]


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,65 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Any, List, Optional
from edgecraftrag.base import BaseComponent, CompType, FileType
from llama_index.core.schema import Document
from pydantic import BaseModel, Field, model_serializer
class File(BaseComponent):
file_path: str = Field(default="")
comp_subtype: str = Field(default="")
documents: List[Document] = Field(default=[])
def __init__(self, file_name: Optional[str] = None, file_path: Optional[str] = None, content: Optional[str] = None):
super().__init__(comp_type=CompType.FILE)
if not file_name and not file_path:
raise ValueError("File name or path must be provided")
_path = Path(file_path) if file_path else None
if file_name:
self.name = file_name
else:
self.name = _path.name
self.file_path = _path
self.comp_subtype = FileType.TEXT
if _path and _path.exists():
self.documents.extend(convert_file_to_documents(_path))
if content:
self.documents.extend(convert_text_to_documents(content))
def run(self, **kwargs) -> Any:
pass
@model_serializer
def ser_model(self):
set = {
"file_name": self.name,
"file_id": self.idx,
"file_type": self.comp_subtype,
"file_path": str(self.file_path),
"docs_count": len(self.documents),
}
return set
def convert_text_to_documents(text) -> List[Document]:
return [Document(text=text, metadata={"file_name": "text"})]
def convert_file_to_documents(file_path) -> List[Document]:
from llama_index.core import SimpleDirectoryReader
supported_exts = [".pdf", ".txt", ".doc", ".docx", ".pptx", ".ppt", ".csv", ".md", ".html", ".rst"]
if file_path.is_dir():
docs = SimpleDirectoryReader(input_dir=file_path, recursive=True, required_exts=supported_exts).load_data()
elif file_path.is_file():
docs = SimpleDirectoryReader(input_files=[file_path], required_exts=supported_exts).load_data()
else:
docs = []
return docs


@@ -0,0 +1,194 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import os
from comps import GeneratedDoc, opea_telemetry
from edgecraftrag.base import BaseComponent, CompType, GeneratorType
from fastapi.responses import StreamingResponse
from langchain_core.prompts import PromptTemplate
from llama_index.llms.openai_like import OpenAILike
from pydantic import model_serializer
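# Encode one streamed text chunk as an SSE 'data:' frame: bare spaces become the '@#$'
# placeholder, newlines become '<br/>', and whitespace-only chunks are dropped.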
@opea_telemetry
def post_process_text(text: str):
if text == " ":
return "data: @#$\n\n"
if text == "\n":
return "data: <br/>\n\n"
if text.isspace():
return None
new_text = text.replace(" ", "@#$")
return f"data: {new_text}\n\n"
class QnAGenerator(BaseComponent):
def __init__(self, llm_model, prompt_template, inference_type, **kwargs):
BaseComponent.__init__(
self,
comp_type=CompType.GENERATOR,
comp_subtype=GeneratorType.CHATQNA,
)
self.inference_type = inference_type
self._REPLACE_PAIRS = (
("\n\n", "\n"),
("\t\n", "\n"),
)
template = prompt_template
self.prompt = (
DocumentedContextRagPromptTemplate.from_file(template)
if os.path.isfile(template)
else DocumentedContextRagPromptTemplate.from_template(template)
)
self.llm = llm_model
if isinstance(llm_model, str):
self.model_id = llm_model
else:
self.model_id = llm_model().model_id
def clean_string(self, string):
ret = string
for p in self._REPLACE_PAIRS:
ret = ret.replace(*p)
return ret
def run(self, chat_request, retrieved_nodes, **kwargs):
if self.llm() is None:
            # This could happen when the user deletes all LLMs through the RESTful API
return "No LLM available, please load LLM"
# query transformation
text_gen_context = ""
for n in retrieved_nodes:
origin_text = n.node.get_text()
text_gen_context += self.clean_string(origin_text.strip())
query = chat_request.messages
prompt_str = self.prompt.format(input=query, context=text_gen_context)
generate_kwargs = dict(
temperature=chat_request.temperature,
do_sample=chat_request.temperature > 0.0,
top_p=chat_request.top_p,
top_k=chat_request.top_k,
typical_p=chat_request.typical_p,
repetition_penalty=chat_request.repetition_penalty,
)
self.llm().generate_kwargs = generate_kwargs
return self.llm().complete(prompt_str)
def run_vllm(self, chat_request, retrieved_nodes, **kwargs):
if self.llm is None:
return "No LLM provided, please provide model_id_or_path"
# query transformation
text_gen_context = ""
for n in retrieved_nodes:
origin_text = n.node.get_text()
text_gen_context += self.clean_string(origin_text.strip())
query = chat_request.messages
prompt_str = self.prompt.format(input=query, context=text_gen_context)
llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
model_name = self.llm
llm = OpenAILike(
api_key="fake",
api_base=llm_endpoint + "/v1",
max_tokens=chat_request.max_tokens,
model=model_name,
top_p=chat_request.top_p,
temperature=chat_request.temperature,
streaming=chat_request.stream,
)
if chat_request.stream:
async def stream_generator():
response = await llm.astream_complete(prompt_str)
async for text in response:
output = text.text
yield f"data: {output}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
response = llm.complete(prompt_str)
response = response.text
return GeneratedDoc(text=response, prompt=prompt_str)
@model_serializer
def ser_model(self):
set = {"idx": self.idx, "generator_type": self.comp_subtype, "model": self.model_id}
return set
@dataclasses.dataclass
class INSTRUCTIONS:
IM_START = "You are an AI assistant that helps users answer questions given a specific context."
SUCCINCT = "Ensure your response is succinct"
ACCURATE = "Ensure your response is accurate."
SUCCINCT_AND_ACCURATE = "Ensure your response is succinct. Try to be accurate if possible."
ACCURATE_AND_SUCCINCT = "Ensure your response is accurate. Try to be succinct if possible."
NO_RAMBLING = "Avoid posing new questions or self-questioning and answering, and refrain from repeating words in your response."
    SAY_SOMETHING = "Avoid meaningless answers such as a random symbol or blanks."
ENCOURAGE = "If you cannot well understand the question, try to translate it into English, and translate the answer back to the language of the question."
NO_IDEA = (
'If the answer is not discernible, please respond with "Sorry. I have no idea" in the language of the question.'
)
CLOZE_TEST = """The task is a fill-in-the-blank/cloze test."""
NO_MEANINGLESS_SYMBOLS = "Meaningless symbols and ``` should not be included in your response."
    ADAPT_NATIVE_LANGUAGE = "Please try to think like a person who speaks the same language as the question."
def _is_cloze(question):
return ("()" in question or "" in question) and ("" in question or "fill" in question or "cloze" in question)
# deprecated
def get_instructions(question):
# naive pre-retrieval rewrite
# cloze
if _is_cloze(question):
instructions = [
INSTRUCTIONS.CLOZE_TEST,
]
else:
instructions = [
INSTRUCTIONS.ACCURATE_AND_SUCCINCT,
INSTRUCTIONS.NO_RAMBLING,
INSTRUCTIONS.NO_MEANINGLESS_SYMBOLS,
]
return ["System: {}".format(_) for _ in instructions]
def preprocess_question(question):
if _is_cloze(question):
        question = question.replace(" ", "").replace("（", "(").replace("）", ")")
# .replace("()", " <|blank|> ")
ret = "User: Please finish the following fill-in-the-blank question marked by $$$ at the beginning and end. Make sure all the () are filled.\n$$$\n{}\n$$$\nAssistant: ".format(
question
)
else:
ret = "User: {}\nAssistant: 从上下文提供的信息中可以知道,".format(question)
return ret
class DocumentedContextRagPromptTemplate(PromptTemplate):
def format(self, **kwargs) -> str:
# context = '\n'.join([clean_string(f"{_.page_content}".strip()) for i, _ in enumerate(kwargs["context"])])
context = kwargs["context"]
question = kwargs["input"]
preprocessed_question = preprocess_question(question)
if "instructions" in self.template:
instructions = get_instructions(question)
prompt_str = self.template.format(
context=context, instructions="\n".join(instructions), input=preprocessed_question
)
else:
prompt_str = self.template.format(context=context, input=preprocessed_question)
return prompt_str


@@ -0,0 +1,45 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import faiss
from edgecraftrag.base import BaseComponent, CompType, IndexerType
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from pydantic import model_serializer
class VectorIndexer(BaseComponent, VectorStoreIndex):
def __init__(self, embed_model, vector_type):
BaseComponent.__init__(
self,
comp_type=CompType.INDEXER,
comp_subtype=vector_type,
)
self.model = embed_model
if not embed_model:
            # Settings.embed_model must be set to None when no embed_model is given to avoid the 'no OpenAI key' error
from llama_index.core import Settings
Settings.embed_model = None
match vector_type:
case IndexerType.DEFAULT_VECTOR:
VectorStoreIndex.__init__(self, embed_model=embed_model, nodes=[])
case IndexerType.FAISS_VECTOR:
if embed_model:
d = embed_model._model.request.outputs[0].get_partial_shape()[2].get_length()
else:
d = 128
faiss_index = faiss.IndexFlatL2(d)
faiss_store = StorageContext.from_defaults(vector_store=FaissVectorStore(faiss_index=faiss_index))
VectorStoreIndex.__init__(self, embed_model=embed_model, nodes=[], storage_context=faiss_store)
def run(self, **kwargs) -> Any:
pass
@model_serializer
def ser_model(self):
set = {"idx": self.idx, "indexer_type": self.comp_subtype, "model": self.model}
return set


@@ -0,0 +1,74 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
from edgecraftrag.base import BaseComponent, CompType, ModelType
from llama_index.embeddings.huggingface_openvino import OpenVINOEmbedding
from llama_index.llms.openvino import OpenVINOLLM
from llama_index.postprocessor.openvino_rerank import OpenVINORerank
from pydantic import Field, model_serializer
class BaseModelComponent(BaseComponent):
model_id: Optional[str] = Field(default="")
model_path: Optional[str] = Field(default="")
device: Optional[str] = Field(default="cpu")
def run(self, **kwargs) -> Any:
pass
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"type": self.comp_subtype,
"model_id": self.model_id,
"model_path": self.model_path,
"device": self.device,
}
return set
class OpenVINOEmbeddingModel(BaseModelComponent, OpenVINOEmbedding):
def __init__(self, model_id, model_path, device):
OpenVINOEmbedding.create_and_save_openvino_model(model_id, model_path)
OpenVINOEmbedding.__init__(self, model_id_or_path=model_path, device=device)
self.comp_type = CompType.MODEL
self.comp_subtype = ModelType.EMBEDDING
self.model_id = model_id
self.model_path = model_path
self.device = device
class OpenVINORerankModel(BaseModelComponent, OpenVINORerank):
def __init__(self, model_id, model_path, device):
OpenVINORerank.create_and_save_openvino_model(model_id, model_path)
OpenVINORerank.__init__(
self,
model_id_or_path=model_path,
device=device,
)
self.comp_type = CompType.MODEL
self.comp_subtype = ModelType.RERANKER
self.model_id = model_id
self.model_path = model_path
self.device = device
class OpenVINOLLMModel(BaseModelComponent, OpenVINOLLM):
def __init__(self, model_id, model_path, device):
OpenVINOLLM.__init__(
self,
model_id_or_path=model_path,
device_map=device,
)
self.comp_type = CompType.MODEL
self.comp_subtype = ModelType.LLM
self.model_id = model_id
self.model_path = model_path
self.device = device


@@ -0,0 +1,85 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from edgecraftrag.base import BaseComponent, CompType, NodeParserType
from llama_index.core.node_parser import HierarchicalNodeParser, SentenceSplitter, SentenceWindowNodeParser
from pydantic import model_serializer
class SimpleNodeParser(BaseComponent, SentenceSplitter):
    # Use super() for SentenceSplitter since its __init__ will clean up
    # BaseComponent fields
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.comp_type = CompType.NODEPARSER
self.comp_subtype = NodeParserType.SIMPLE
def run(self, **kwargs) -> Any:
for k, v in kwargs.items():
if k == "docs":
return self.get_nodes_from_documents(v, show_progress=False)
return None
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"parser_type": self.comp_subtype,
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap,
}
return set
class HierarchyNodeParser(BaseComponent, HierarchicalNodeParser):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.comp_type = CompType.NODEPARSER
self.comp_subtype = NodeParserType.HIERARCHY
def run(self, **kwargs) -> Any:
for k, v in kwargs.items():
if k == "docs":
return self.get_nodes_from_documents(v, show_progress=False)
return None
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"parser_type": self.comp_subtype,
"chunk_size": self.chunk_sizes,
"chunk_overlap": None,
}
return set
class SWindowNodeParser(BaseComponent, SentenceWindowNodeParser):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.comp_type = CompType.NODEPARSER
self.comp_subtype = NodeParserType.SENTENCEWINDOW
def run(self, **kwargs) -> Any:
for k, v in kwargs.items():
if k == "docs":
return self.get_nodes_from_documents(v, show_progress=False)
return None
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"parser_type": self.comp_subtype,
"chunk_size": None,
"chunk_overlap": None,
}
return set


@@ -0,0 +1,160 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, List, Optional
from comps.cores.proto.api_protocol import ChatCompletionRequest
from edgecraftrag.base import BaseComponent, CallbackType, CompType, InferenceType
from edgecraftrag.components.postprocessor import RerankProcessor
from llama_index.core.schema import Document, QueryBundle
from pydantic import BaseModel, Field, model_serializer
class PipelineStatus(BaseModel):
active: bool = False
class Pipeline(BaseComponent):
node_parser: Optional[BaseComponent] = Field(default=None)
indexer: Optional[BaseComponent] = Field(default=None)
retriever: Optional[BaseComponent] = Field(default=None)
postprocessor: Optional[List[BaseComponent]] = Field(default=None)
generator: Optional[BaseComponent] = Field(default=None)
status: PipelineStatus = Field(default=PipelineStatus())
run_pipeline_cb: Optional[Callable[..., Any]] = Field(default=None)
run_retriever_cb: Optional[Callable[..., Any]] = Field(default=None)
run_data_prepare_cb: Optional[Callable[..., Any]] = Field(default=None)
def __init__(
self,
name,
):
super().__init__(name=name, comp_type=CompType.PIPELINE)
if self.name == "" or self.name is None:
self.name = self.idx
self.run_pipeline_cb = run_test_generator
self.run_retriever_cb = run_test_retrieve
self.run_data_prepare_cb = run_simple_doc
self._node_changed = True
# TODO: consider race condition
@property
def node_changed(self) -> bool:
return self._node_changed
# TODO: update doc changes
# TODO: more operations needed, add, del, modify
def update_nodes(self, nodes):
print("updating nodes ", nodes)
if self.indexer is not None:
self.indexer.insert_nodes(nodes)
# TODO: check more conditions
def check_active(self, nodelist):
if self._node_changed and nodelist is not None:
self.update_nodes(nodelist)
# Implement abstract run function
# callback dispatcher
def run(self, **kwargs) -> Any:
print(kwargs)
if "cbtype" in kwargs:
if kwargs["cbtype"] == CallbackType.DATAPREP:
if "docs" in kwargs:
return self.run_data_prepare_cb(self, docs=kwargs["docs"])
if kwargs["cbtype"] == CallbackType.RETRIEVE:
if "chat_request" in kwargs:
return self.run_retriever_cb(self, chat_request=kwargs["chat_request"])
if kwargs["cbtype"] == CallbackType.PIPELINE:
if "chat_request" in kwargs:
return self.run_pipeline_cb(self, chat_request=kwargs["chat_request"])
def update(self, node_parser=None, indexer=None, retriever=None, postprocessor=None, generator=None):
if node_parser is not None:
self.node_parser = node_parser
if indexer is not None:
self.indexer = indexer
if retriever is not None:
self.retriever = retriever
if postprocessor is not None:
self.postprocessor = postprocessor
if generator is not None:
self.generator = generator
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"name": self.name,
"comp_type": self.comp_type,
"node_parser": self.node_parser,
"indexer": self.indexer,
"retriever": self.retriever,
"postprocessor": self.postprocessor,
"generator": self.generator,
"status": self.status,
}
return set
def model_existed(self, model_id: str) -> bool:
        # Check whether the model identified by model_id is used anywhere in this pipeline
if self.indexer:
if hasattr(self.indexer, "_embed_model") and self.indexer._embed_model.model_id == model_id:
return True
if hasattr(self.indexer, "_llm") and self.indexer._llm.model_id == model_id:
return True
if self.postprocessor:
for processor in self.postprocessor:
if hasattr(processor, "model_id") and processor.model_id == model_id:
return True
if self.generator:
llm = self.generator.llm
if llm() and llm().model_id == model_id:
return True
return False
# Test callback to retrieve nodes from query
def run_test_retrieve(pl: Pipeline, chat_request: ChatCompletionRequest) -> Any:
query = chat_request.messages
retri_res = pl.retriever.run(query=query)
query_bundle = QueryBundle(query)
if pl.postprocessor:
for processor in pl.postprocessor:
if (
isinstance(processor, RerankProcessor)
and chat_request.top_n != ChatCompletionRequest.model_fields["top_n"].default
):
processor.top_n = chat_request.top_n
retri_res = processor.run(retri_res=retri_res, query_bundle=query_bundle)
return retri_res
def run_simple_doc(pl: Pipeline, docs: List[Document]) -> Any:
n = pl.node_parser.run(docs=docs)
if pl.indexer is not None:
pl.indexer.insert_nodes(n)
print(pl.indexer._index_struct)
return n
def run_test_generator(pl: Pipeline, chat_request: ChatCompletionRequest) -> Any:
query = chat_request.messages
retri_res = pl.retriever.run(query=query)
query_bundle = QueryBundle(query)
if pl.postprocessor:
for processor in pl.postprocessor:
if (
isinstance(processor, RerankProcessor)
and chat_request.top_n != ChatCompletionRequest.model_fields["top_n"].default
):
processor.top_n = chat_request.top_n
retri_res = processor.run(retri_res=retri_res, query_bundle=query_bundle)
if pl.generator is None:
return "No Generator Specified"
if pl.generator.inference_type == InferenceType.LOCAL:
answer = pl.generator.run(chat_request, retri_res)
elif pl.generator.inference_type == InferenceType.VLLM:
answer = pl.generator.run_vllm(chat_request, retri_res)
return answer


@@ -0,0 +1,64 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from edgecraftrag.base import BaseComponent, CompType, PostProcessorType
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from pydantic import model_serializer
class RerankProcessor(BaseComponent):
def __init__(self, rerank_model, top_n):
BaseComponent.__init__(
self,
comp_type=CompType.POSTPROCESSOR,
comp_subtype=PostProcessorType.RERANKER,
)
self.model = rerank_model
self.top_n = top_n
def run(self, **kwargs) -> Any:
self.model.top_n = self.top_n
query_bundle = None
query_str = None
if "retri_res" in kwargs:
nodes = kwargs["retri_res"]
if "query_bundle" in kwargs:
query_bundle = kwargs["query_bundle"]
if "query_str" in kwargs:
query_str = kwargs["query_str"]
return self.model.postprocess_nodes(nodes, query_bundle=query_bundle, query_str=query_str)
@model_serializer
def ser_model(self):
set = {"idx": self.idx, "postprocessor_type": self.comp_subtype, "model": self.model, "top_n": self.top_n}
return set
class MetadataReplaceProcessor(BaseComponent, MetadataReplacementPostProcessor):
def __init__(self, target_metadata_key="window"):
BaseComponent.__init__(
self,
target_metadata_key=target_metadata_key,
comp_type=CompType.POSTPROCESSOR,
comp_subtype=PostProcessorType.METADATAREPLACE,
)
def run(self, **kwargs) -> Any:
query_bundle = None
query_str = None
if "retri_res" in kwargs:
nodes = kwargs["retri_res"]
if "query_bundle" in kwargs:
query_bundle = kwargs["query_bundle"]
if "query_str" in kwargs:
query_str = kwargs["query_str"]
return self.postprocess_nodes(nodes, query_bundle=query_bundle, query_str=query_str)
@model_serializer
def ser_model(self):
set = {"idx": self.idx, "postprocessor_type": self.comp_subtype, "model": None, "top_n": None}
return set


@@ -0,0 +1,104 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, cast
from edgecraftrag.base import BaseComponent, CompType, RetrieverType
from llama_index.core.indices.vector_store.retrievers import VectorIndexRetriever
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.schema import BaseNode
from llama_index.retrievers.bm25 import BM25Retriever
from pydantic import model_serializer
class VectorSimRetriever(BaseComponent, VectorIndexRetriever):
def __init__(self, indexer, **kwargs):
BaseComponent.__init__(
self,
comp_type=CompType.RETRIEVER,
comp_subtype=RetrieverType.VECTORSIMILARITY,
)
VectorIndexRetriever.__init__(
self,
index=indexer,
node_ids=list(indexer.index_struct.nodes_dict.values()),
callback_manager=indexer._callback_manager,
object_map=indexer._object_map,
**kwargs,
)
        # This might be a bug in the llamaindex retriever.
        # The node_ids are never updated after the retriever is
        # created, but node_ids determines which node ids are
        # eligible for retrieval, which means the set of retrievable
        # nodes is frozen at the time the retriever is created.
self._node_ids = None
def run(self, **kwargs) -> Any:
for k, v in kwargs.items():
if k == "query":
return self.retrieve(v)
return None
@model_serializer
def ser_model(self):
set = {
"idx": self.idx,
"retriever_type": self.comp_subtype,
"retrieve_topk": self.similarity_top_k,
}
return set
class AutoMergeRetriever(BaseComponent, AutoMergingRetriever):
def __init__(self, indexer, **kwargs):
BaseComponent.__init__(
self,
comp_type=CompType.RETRIEVER,
comp_subtype=RetrieverType.AUTOMERGE,
)
self._index = indexer
self.topk = kwargs["similarity_top_k"]
AutoMergingRetriever.__init__(
self,
vector_retriever=indexer.as_retriever(**kwargs),
storage_context=indexer._storage_context,
object_map=indexer._object_map,
callback_manager=indexer._callback_manager,
)
def run(self, **kwargs) -> Any:
for k, v in kwargs.items():
if k == "query":
# vector_retriever needs to be updated
self._vector_retriever = self._index.as_retriever(similarity_top_k=self.topk)
return self.retrieve(v)
return None
class SimpleBM25Retriever(BaseComponent):
    # The nodes parameter of BM25Retriever is not taken from the index, and
    # the nodes held by BM25Retriever cannot be updated through 'indexer.insert_nodes()',
    # which means nodes should be passed to BM25Retriever at the data-preparation stage, not at init time
def __init__(self, indexer, **kwargs):
BaseComponent.__init__(
self,
comp_type=CompType.RETRIEVER,
comp_subtype=RetrieverType.BM25,
)
self._docstore = indexer._docstore
self.topk = kwargs["similarity_top_k"]
def run(self, **kwargs) -> Any:
for k, v in kwargs.items():
if k == "query":
nodes = cast(List[BaseNode], list(self._docstore.docs.values()))
bm25_retr = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=self.topk)
return bm25_retr.retrieve(v)
return None


@@ -0,0 +1,52 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from edgecraftrag.controllers.compmgr import GeneratorMgr, IndexerMgr, NodeParserMgr, PostProcessorMgr, RetrieverMgr
from edgecraftrag.controllers.filemgr import FilelMgr
from edgecraftrag.controllers.modelmgr import ModelMgr
from edgecraftrag.controllers.nodemgr import NodeMgr
from edgecraftrag.controllers.pipelinemgr import PipelineMgr
class Context:
def __init__(self):
self.plmgr = PipelineMgr()
self.nodemgr = NodeMgr()
self.npmgr = NodeParserMgr()
self.idxmgr = IndexerMgr()
self.rtvmgr = RetrieverMgr()
self.ppmgr = PostProcessorMgr()
self.modmgr = ModelMgr()
self.genmgr = GeneratorMgr()
self.filemgr = FilelMgr()
def get_pipeline_mgr(self):
return self.plmgr
def get_node_mgr(self):
return self.nodemgr
def get_node_parser_mgr(self):
return self.npmgr
def get_indexer_mgr(self):
return self.idxmgr
def get_retriever_mgr(self):
return self.rtvmgr
def get_postprocessor_mgr(self):
return self.ppmgr
def get_model_mgr(self):
return self.modmgr
def get_generator_mgr(self):
return self.genmgr
def get_file_mgr(self):
return self.filemgr
ctx = Context()

View File

@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

View File

@@ -0,0 +1,66 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from edgecraftrag.api_schema import IndexerIn, ModelIn, NodeParserIn
from edgecraftrag.base import BaseComponent, BaseMgr, CallbackType, ModelType, NodeParserType
class NodeParserMgr(BaseMgr):
def __init__(self):
super().__init__()
def search_parser(self, npin: NodeParserIn) -> BaseComponent:
for _, v in self.components.items():
v_parser_type = v.comp_subtype
if v_parser_type == npin.parser_type:
if v_parser_type == NodeParserType.HIERARCHY and v.chunk_sizes == npin.chunk_sizes:
return v
elif v_parser_type == NodeParserType.SENTENCEWINDOW and v.window_size == npin.window_size:
return v
elif (
v_parser_type == NodeParserType.SIMPLE
and v.chunk_size == npin.chunk_size
and v.chunk_overlap == npin.chunk_overlap
):
return v
return None
class IndexerMgr(BaseMgr):
def __init__(self):
super().__init__()
def search_indexer(self, indin: IndexerIn) -> BaseComponent:
for _, v in self.components.items():
if v.comp_subtype == indin.indexer_type:
if (
hasattr(v, "model")
and v.model
and indin.embedding_model
and (
(v.model.model_id_or_path == indin.embedding_model.model_id)
or (v.model.model_id_or_path == indin.embedding_model.model_path)
)
):
return v
return None
class RetrieverMgr(BaseMgr):
def __init__(self):
super().__init__()
class PostProcessorMgr(BaseMgr):
def __init__(self):
super().__init__()
class GeneratorMgr(BaseMgr):
def __init__(self):
super().__init__()
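
The `search_parser` and `search_indexer` helpers above are how the managers deduplicate components: when a new pipeline asks for settings that already exist, the existing instance is returned instead of creating another one. A small sketch of that lookup, using the `NodeParserIn` fields seen in the UI client later in this commit:

```python
# Hypothetical sketch: reuse an existing node parser when one with the
# same settings is already registered; None means the caller must create it.
from edgecraftrag.api_schema import NodeParserIn

def find_existing_parser(npmgr, parser_type="simple", chunk_size=400, chunk_overlap=48):
    npin = NodeParserIn(parser_type=parser_type, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return npmgr.search_parser(npin)
```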

View File

@@ -0,0 +1,83 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from typing import Any, Callable, List, Optional
from edgecraftrag.base import BaseMgr
from edgecraftrag.components.data import File
from llama_index.core.schema import Document
class FilelMgr(BaseMgr):
def __init__(self):
super().__init__()
def add_text(self, text: str):
file = File(file_name="text", content=text)
self.add(file)
return file.documents
def add_files(self, docs: Any):
if not isinstance(docs, list):
docs = [docs]
input_docs = []
for doc in docs:
if not os.path.exists(doc):
continue
if os.path.isfile(doc):
files = [doc]
elif os.path.isdir(doc):
files = [os.path.join(root, f) for root, _, files in os.walk(doc) for f in files]
else:
continue
if not files:
continue
for file_path in files:
file = File(file_path=file_path)
self.add(file)
input_docs.extend(file.documents)
return input_docs
def get_file_by_name_or_id(self, name: str):
for _, file in self.components.items():
if file.name == name or file.idx == name:
return file
return None
def get_files(self):
return [file for _, file in self.components.items()]
def get_all_docs(self) -> List[Document]:
all_docs = []
for _, file in self.components.items():
all_docs.extend(file.documents)
return all_docs
def get_docs_by_file(self, name) -> List[Document]:
file = self.get_file_by_name_or_id(name)
return file.documents if file else []
def del_file(self, name):
file = self.get_file_by_name_or_id(name)
if file:
self.remove(file.idx)
return True
else:
return False
def update_file(self, name):
file = self.get_file_by_name_or_id(name)
if file:
self.remove(file.idx)
self.add_files(docs=name)
return True
else:
return False
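
`add_files` accepts either a single path or a list, walks directories recursively, and silently skips paths that do not exist. A short sketch of how the data API is assumed to feed it (the directory path is hypothetical):

```python
# Hypothetical sketch: register a directory of documents and collect the
# llama-index Document objects produced for the dataprep stage.
from edgecraftrag.controllers.filemgr import FilelMgr

file_mgr = FilelMgr()
docs = file_mgr.add_files(docs="/home/user/docs")  # hypothetical directory
print(f"{len(docs)} documents loaded from {len(file_mgr.get_files())} files")
```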

View File

@@ -0,0 +1,94 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import asyncio
from edgecraftrag.api_schema import IndexerIn, ModelIn, NodeParserIn
from edgecraftrag.base import BaseComponent, BaseMgr, CallbackType, ModelType
from edgecraftrag.components.model import OpenVINOEmbeddingModel, OpenVINOLLMModel, OpenVINORerankModel
class ModelMgr(BaseMgr):
def __init__(self):
self._lock = asyncio.Lock()
super().__init__()
def get_model_by_name(self, name: str):
for _, v in self.components.items():
if v.model_id == name:
model_type = v.comp_subtype.value
model_info = {
"model_type": model_type,
"model_id": getattr(v, "model_id", "Unknown"),
}
if model_type == ModelType.LLM:
model_info["model_path"] = getattr(v, "model_name", "Unknown")
model_info["device"] = getattr(v, "device_map", "Unknown")
else:
model_info["model_path"] = getattr(v, "model_id_or_path", "Unknown")
model_info["device"] = getattr(v, "device", getattr(v, "_device", "Unknown"))
return model_info
return None
def get_models(self):
model = {}
for k, v in self.components.items():
            # Collect descriptive information for each model
model_type = v.comp_subtype.value
model_info = {
"model_type": model_type,
"model_id": getattr(v, "model_id", "Unknown"),
}
if model_type == ModelType.LLM:
model_info["model_path"] = getattr(v, "model_name", "Unknown")
model_info["device"] = getattr(v, "device_map", "Unknown")
else:
model_info["model_path"] = getattr(v, "model_id_or_path", "Unknown")
model_info["device"] = getattr(v, "device", getattr(v, "_device", "Unknown"))
model[k] = model_info
return model
def search_model(self, modelin: ModelIn) -> BaseComponent:
        # Match on model_path and device to find an existing model
for _, v in self.components.items():
model_path = v.model_name if v.comp_subtype.value == "llm" else v.model_id_or_path
model_dev = (
v.device_map
if v.comp_subtype.value == "llm"
else getattr(v, "device", getattr(v, "_device", "Unknown"))
)
if model_path == modelin.model_path and model_dev == modelin.device:
return v
return None
def del_model_by_name(self, name: str):
for key, v in self.components.items():
if v and v.model_id == name:
self.remove(key)
return "Model deleted"
return "Model not found"
@staticmethod
def load_model(model_para: ModelIn):
model = None
match model_para.model_type:
case ModelType.EMBEDDING:
model = OpenVINOEmbeddingModel(
model_id=model_para.model_id,
model_path=model_para.model_path,
device=model_para.device,
)
case ModelType.RERANKER:
model = OpenVINORerankModel(
model_id=model_para.model_id,
model_path=model_para.model_path,
device=model_para.device,
)
case ModelType.LLM:
model = OpenVINOLLMModel(
model_id=model_para.model_id,
model_path=model_para.model_path,
device=model_para.device,
)
return model
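
`load_model` dispatches on `ModelIn.model_type` and wraps the weights in the matching OpenVINO component. A hedged sketch of loading the embedding model from the sample pipeline configuration later in this commit (passing `model_type` explicitly to `ModelIn` is assumed to be supported by the schema):

```python
# Hypothetical sketch: let ModelMgr pick the OpenVINO wrapper by model_type.
from edgecraftrag.api_schema import ModelIn
from edgecraftrag.base import ModelType
from edgecraftrag.controllers.modelmgr import ModelMgr

embed_req = ModelIn(
    model_type=ModelType.EMBEDDING,  # explicit model_type assumed to be accepted
    model_id="BAAI/bge-small-en-v1.5",
    model_path="./models/bge_ov_embedding",
    device="AUTO",
)
embedding_model = ModelMgr.load_model(embed_req)
```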

View File

@@ -0,0 +1,34 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import List
from edgecraftrag.api_schema import IndexerIn, ModelIn, NodeParserIn
from edgecraftrag.base import BaseComponent, BaseMgr, CallbackType, ModelType
from llama_index.core.schema import BaseNode
class NodeMgr:
def __init__(self):
self.nodes = {}
# idx: index of node_parser
def add_nodes(self, np_idx, nodes):
if np_idx in self.nodes:
self.nodes[np_idx].append(nodes)
else:
self.nodes[np_idx] = nodes
# TODO: to be implemented
def del_nodes(self, nodes):
pass
def del_nodes_by_np_idx(self, np_idx):
del self.nodes[np_idx]
def get_nodes(self, np_idx) -> List[BaseNode]:
if np_idx in self.nodes:
return self.nodes[np_idx]
else:
return []

View File

@@ -0,0 +1,79 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Any, List
from comps.cores.proto.api_protocol import ChatCompletionRequest
from edgecraftrag.base import BaseMgr, CallbackType
from edgecraftrag.components.pipeline import Pipeline
from edgecraftrag.controllers.nodemgr import NodeMgr
from llama_index.core.schema import Document
class PipelineMgr(BaseMgr):
def __init__(self):
self._active_pipeline = None
self._lock = asyncio.Lock()
super().__init__()
def create_pipeline(self, name: str):
pl = Pipeline(name)
self.add(pl)
return pl
def get_pipeline_by_name_or_id(self, name: str):
for _, pl in self.components.items():
if pl.name == name or pl.idx == name:
return pl
return None
def get_pipelines(self):
return [pl for _, pl in self.components.items()]
def activate_pipeline(self, name: str, active: bool, nm: NodeMgr):
pl = self.get_pipeline_by_name_or_id(name)
nodelist = None
if pl is not None:
if not active:
pl.status.active = False
self._active_pipeline = None
return
if pl.node_changed:
nodelist = nm.get_nodes(pl.node_parser.idx)
pl.check_active(nodelist)
prevactive = self._active_pipeline
if prevactive:
prevactive.status.active = False
pl.status.active = True
self._active_pipeline = pl
def get_active_pipeline(self) -> Pipeline:
return self._active_pipeline
def notify_node_change(self):
for _, pl in self.components.items():
pl.set_node_change()
def run_pipeline(self, chat_request: ChatCompletionRequest) -> Any:
ap = self.get_active_pipeline()
out = None
if ap is not None:
out = ap.run(cbtype=CallbackType.PIPELINE, chat_request=chat_request)
return out
return -1
def run_retrieve(self, chat_request: ChatCompletionRequest) -> Any:
ap = self.get_active_pipeline()
out = None
if ap is not None:
out = ap.run(cbtype=CallbackType.RETRIEVE, chat_request=chat_request)
return out
return -1
def run_data_prepare(self, docs: List[Document]) -> Any:
ap = self.get_active_pipeline()
if ap is not None:
return ap.run(cbtype=CallbackType.DATAPREP, docs=docs)
return -1
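
`PipelineMgr` keeps exactly one active pipeline: activation deactivates the previous one and refreshes nodes if the parser output changed, after which `run_pipeline`, `run_retrieve` and `run_data_prepare` all route to it. A hedged sketch of that lifecycle, assuming a pipeline named `rag_test_local_llm` has already been created, that the OPEA `ChatCompletionRequest` accepts a plain string for `messages` (as the UI request later in this commit suggests), and that the shared context lives at `edgecraftrag.context`:

```python
# Hypothetical sketch of the pipeline lifecycle handled by PipelineMgr.
from comps.cores.proto.api_protocol import ChatCompletionRequest
from edgecraftrag.context import ctx  # module path assumed

plmgr = ctx.get_pipeline_mgr()
plmgr.activate_pipeline("rag_test_local_llm", active=True, nm=ctx.get_node_mgr())

request = ChatCompletionRequest(messages="What is EC-RAG?")
answer = plmgr.run_pipeline(chat_request=request)  # -1 if no pipeline is active
```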

View File

@@ -0,0 +1,8 @@
<|im_start|>System: You are an AI assistant. Your task is to learn from the following context. Then answer the user's question based on what you learned from the context but not your own knowledge.<|im_end|>
<|im_start|>{context}<|im_end|>
<|im_start|>System: Pay attention to your formatting of response. If you need to reference content from context, try to keep the formatting.<|im_end|>
<|im_start|>System: Try to summarize from the context, do some reasoning before response, then response. Make sure your response is logically sound and self-consistent.<|im_end|>
<|im_start|>{input}
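
The `{context}` and `{input}` placeholders in this template are presumably filled with the retrieved passages and the user question before the prompt reaches the LLM. A minimal sketch of that substitution, assuming plain `str.format` semantics rather than the generator's actual templating:

```python
# Hypothetical sketch: render the default prompt with retrieved context
# and the user question (plain str.format semantics are assumed here).
with open("./edgecraftrag/prompt_template/default_prompt.txt") as f:
    template = f.read()

prompt = template.format(
    context="EC-RAG is a customizable RAG system for edge solutions.",
    input="What is EC-RAG?",
)
print(prompt)
```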

View File

@@ -0,0 +1,16 @@
docx2txt
faiss-cpu>=1.8.0.post1
gradio>=4.44.1
langchain-core==0.2.29
llama-index>=0.11.0
llama-index-embeddings-openvino>=0.4.0
llama-index-llms-openai-like>=0.2.0
llama-index-llms-openvino>=0.3.1
llama-index-postprocessor-openvino-rerank>=0.3.0
llama-index-retrievers-bm25>=0.3.0
llama-index-vector-stores-faiss>=0.2.1
loguru>=0.7.2
omegaconf>=2.3.0
opea-comps>=0.9
py-cpuinfo>=9.0.0
uvicorn>=0.30.6

View File

@@ -0,0 +1,27 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import uvicorn
from edgecraftrag.api.v1.chatqna import chatqna_app
from edgecraftrag.api.v1.data import data_app
from edgecraftrag.api.v1.model import model_app
from edgecraftrag.api.v1.pipeline import pipeline_app
from fastapi import FastAPI
from llama_index.core.settings import Settings
app = FastAPI()
sub_apps = [data_app, model_app, pipeline_app, chatqna_app]
for sub_app in sub_apps:
for route in sub_app.routes:
app.router.routes.append(route)
if __name__ == "__main__":
Settings.llm = None
host = os.getenv("PIPELINE_SERVICE_HOST_IP", "0.0.0.0")
port = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))
uvicorn.run(app, host=host, port=port)
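
Once the sub-app routes are mounted, the server exposes the data, model, pipeline and chatqna endpoints on `PIPELINE_SERVICE_PORT` (16010 by default). A brief example of listing configured pipelines, using the same endpoint and response fields as the UI client later in this commit:

```python
# Query the pipeline settings endpoint of a locally running server.
import requests

server_addr = "http://127.0.0.1:16010"
res = requests.get(f"{server_addr}/v1/settings/pipelines", proxies={"http": None})
for pl in res.json():
    print(pl["idx"], pl["name"], "active" if pl["status"]["active"] else "inactive")
```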

View File

@@ -0,0 +1,41 @@
{
"name": "rag_test_local_llm",
"node_parser": {
"chunk_size": 400,
"chunk_overlap": 48,
"parser_type": "simple"
},
"indexer": {
"indexer_type": "faiss_vector",
"embedding_model": {
"model_id": "BAAI/bge-small-en-v1.5",
"model_path": "./models/bge_ov_embedding",
"device": "auto"
}
},
"retriever": {
"retriever_type": "vectorsimilarity",
"retrieve_topk": 30
},
"postprocessor": [
{
"processor_type": "reranker",
"top_n": 2,
"reranker_model": {
"model_id": "BAAI/bge-reranker-large",
"model_path": "./models/bge_ov_reranker",
"device": "auto"
}
}
],
"generator": {
"model": {
"model_id": "Qwen/Qwen2-7B-Instruct",
"model_path": "./models/qwen2-7b-instruct/INT4_compressed_weights",
"device": "cpu"
},
"prompt_path": "./edgecraftrag/prompt_template/default_prompt.txt",
"inference_type": "local"
},
"active": "True"
}
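
This sample configuration can be registered by POSTing it to the pipeline settings endpoint, which is what `create_update_pipeline` in the UI client does with an equivalent payload; a short sketch (the JSON file name is assumed):

```python
# Hypothetical sketch: register the sample pipeline JSON with the server.
import json
import requests

with open("test_pipeline_local_llm.json") as f:  # file name assumed
    payload = json.load(f)

res = requests.post("http://127.0.0.1:16010/v1/settings/pipelines", json=payload, proxies={"http": None})
print(res.text)
```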

View File

@@ -0,0 +1,23 @@
FROM python:3.11-slim
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
libgl1-mesa-glx \
libjemalloc-dev
RUN useradd -m -s /bin/bash user && \
mkdir -p /home/user && \
chown -R user /home/user/
COPY ./ui/gradio /home/user/ui
COPY ./edgecraftrag /home/user/edgecraftrag
WORKDIR /home/user/edgecraftrag
RUN pip install --no-cache-dir -r requirements.txt
WORKDIR /home/user/ui
USER user
RUN echo 'ulimit -S -n 999999' >> ~/.bashrc
ENTRYPOINT ["python", "ecragui.py"]

View File

@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

Binary file not shown (image asset, 47 KiB).

Binary file not shown (image asset, 48 KiB).

View File

@@ -0,0 +1,358 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
DEFAULT_SYSTEM_PROMPT_CHINESE = """\
你是一个乐于助人、尊重他人以及诚实可靠的助手。在安全的情况下,始终尽可能有帮助地回答。 您的回答不应包含任何有害、不道德、种族主义、性别歧视、有毒、危险或非法的内容。请确保您的回答在社会上是公正的和积极的。
如果一个问题没有任何意义或与事实不符,请解释原因,而不是回答错误的问题。如果您不知道问题的答案,请不要分享虚假信息。另外,答案请使用中文。\
"""
DEFAULT_SYSTEM_PROMPT_JAPANESE = """\
あなたは親切で、礼儀正しく、誠実なアシスタントです。 常に安全を保ちながら、できるだけ役立つように答えてください。 回答には、有害、非倫理的、人種差別的、性差別的、有毒、危険、または違法なコンテンツを含めてはいけません。 回答は社会的に偏見がなく、本質的に前向きなものであることを確認してください。
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことに答えるのではなく、その理由を説明してください。 質問の答えがわからない場合は、誤った情報を共有しないでください。\
"""
DEFAULT_RAG_PROMPT = """\
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\
"""
DEFAULT_RAG_PROMPT_CHINESE = """\
基于以下已知信息,请简洁并专业地回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题""没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。\
"""
def red_pijama_partial_text_processor(partial_text, new_text):
if new_text == "<":
return partial_text
partial_text += new_text
return partial_text.split("<bot>:")[-1]
def llama_partial_text_processor(partial_text, new_text):
new_text = new_text.replace("[INST]", "").replace("[/INST]", "")
partial_text += new_text
return partial_text
def chatglm_partial_text_processor(partial_text, new_text):
new_text = new_text.strip()
new_text = new_text.replace("[[训练时间]]", "2023年")
partial_text += new_text
return partial_text
def youri_partial_text_processor(partial_text, new_text):
new_text = new_text.replace("システム:", "")
partial_text += new_text
return partial_text
def internlm_partial_text_processor(partial_text, new_text):
partial_text += new_text
return partial_text.split("<|im_end|>")[0]
SUPPORTED_LLM_MODELS = {
"English": {
"tiny-llama-1b-chat": {
"model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"remote_code": False,
"start_message": f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}</s>\n",
"history_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}</s> \n",
"current_message_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}",
"rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }</s>"""
+ """
<|user|>
Question: {input}
Context: {context}
Answer: </s>
<|assistant|>""",
},
"gemma-2b-it": {
"model_id": "google/gemma-2b-it",
"remote_code": False,
"start_message": DEFAULT_SYSTEM_PROMPT + ", ",
"history_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}<end_of_turn>",
"current_message_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}",
"rag_prompt_template": f"""{DEFAULT_RAG_PROMPT},"""
+ """<start_of_turn>user{input}<end_of_turn><start_of_turn>context{context}<end_of_turn><start_of_turn>model""",
},
"red-pajama-3b-chat": {
"model_id": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
"remote_code": False,
"start_message": "",
"history_template": "\n<human>:{user}\n<bot>:{assistant}",
"stop_tokens": [29, 0],
"partial_text_processor": red_pijama_partial_text_processor,
"current_message_template": "\n<human>:{user}\n<bot>:{assistant}",
"rag_prompt_template": f"""{DEFAULT_RAG_PROMPT }"""
+ """
<human>: Question: {input}
Context: {context}
Answer: <bot>""",
},
"gemma-7b-it": {
"model_id": "google/gemma-7b-it",
"remote_code": False,
"start_message": DEFAULT_SYSTEM_PROMPT + ", ",
"history_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}<end_of_turn>",
"current_message_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}",
"rag_prompt_template": f"""{DEFAULT_RAG_PROMPT},"""
+ """<start_of_turn>user{input}<end_of_turn><start_of_turn>context{context}<end_of_turn><start_of_turn>model""",
},
"llama-2-chat-7b": {
"model_id": "meta-llama/Llama-2-7b-chat-hf",
"remote_code": False,
"start_message": f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT }\n<</SYS>>\n\n",
"history_template": "{user}[/INST]{assistant}</s><s>[INST]",
"current_message_template": "{user} [/INST]{assistant}",
"tokenizer_kwargs": {"add_special_tokens": False},
"partial_text_processor": llama_partial_text_processor,
"rag_prompt_template": f"""[INST]Human: <<SYS>> {DEFAULT_RAG_PROMPT }<</SYS>>"""
+ """
Question: {input}
Context: {context}
Answer: [/INST]""",
},
"mpt-7b-chat": {
"model_id": "mosaicml/mpt-7b-chat",
"remote_code": False,
"start_message": f"<|im_start|>system\n {DEFAULT_SYSTEM_PROMPT }<|im_end|>",
"history_template": "<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}<|im_end|>",
"current_message_template": '"<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}',
"stop_tokens": ["<|im_end|>", "<|endoftext|>"],
"rag_prompt_template": f"""<|im_start|>system
{DEFAULT_RAG_PROMPT }<|im_end|>"""
+ """
<|im_start|>user
Question: {input}
Context: {context}
Answer: <im_end><|im_start|>assistant""",
},
"mistral-7b": {
"model_id": "mistralai/Mistral-7B-v0.1",
"remote_code": False,
"start_message": f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT }\n<</SYS>>\n\n",
"history_template": "{user}[/INST]{assistant}</s><s>[INST]",
"current_message_template": "{user} [/INST]{assistant}",
"tokenizer_kwargs": {"add_special_tokens": False},
"partial_text_processor": llama_partial_text_processor,
"rag_prompt_template": f"""<s> [INST] {DEFAULT_RAG_PROMPT } [/INST] </s>"""
+ """
[INST] Question: {input}
Context: {context}
Answer: [/INST]""",
},
"zephyr-7b-beta": {
"model_id": "HuggingFaceH4/zephyr-7b-beta",
"remote_code": False,
"start_message": f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}</s>\n",
"history_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}</s> \n",
"current_message_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}",
"rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }</s>"""
+ """
<|user|>
Question: {input}
Context: {context}
Answer: </s>
<|assistant|>""",
},
"notus-7b-v1": {
"model_id": "argilla/notus-7b-v1",
"remote_code": False,
"start_message": f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}</s>\n",
"history_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}</s> \n",
"current_message_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}",
"rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }</s>"""
+ """
<|user|>
Question: {input}
Context: {context}
Answer: </s>
<|assistant|>""",
},
"neural-chat-7b-v3-1": {
"model_id": "Intel/neural-chat-7b-v3-3",
"remote_code": False,
"start_message": f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT }\n<</SYS>>\n\n",
"history_template": "{user}[/INST]{assistant}</s><s>[INST]",
"current_message_template": "{user} [/INST]{assistant}",
"tokenizer_kwargs": {"add_special_tokens": False},
"partial_text_processor": llama_partial_text_processor,
"rag_prompt_template": f"""<s> [INST] {DEFAULT_RAG_PROMPT } [/INST] </s>"""
+ """
[INST] Question: {input}
Context: {context}
Answer: [/INST]""",
},
},
"Chinese": {
"qwen1.5-0.5b-chat": {
"model_id": "Qwen/Qwen1.5-0.5B-Chat",
"remote_code": False,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"stop_tokens": ["<|im_end|>", "<|endoftext|>"],
},
"qwen1.5-7b-chat": {
"model_id": "Qwen/Qwen1.5-7B-Chat",
"remote_code": False,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"stop_tokens": ["<|im_end|>", "<|endoftext|>"],
"summarization_prompt_template": """
<|im_start|>user
问题: 总结下文内容,不少于{character_num}字.
已知内容: {text}
回答: <|im_end|><|im_start|>assistant""",
"split_summary_template": """
<|im_start|>user
问题: 根据已知内容写一篇简短的摘要.
已知内容: {text}
回答: <|im_end|><|im_start|>assistant""",
"combine_summary_template": """
<|im_start|>user
问题: 根据已知内容写一篇摘要,不少于{character_num}字.
已知内容: {text}
回答: <|im_end|><|im_start|>assistant""",
"rag_prompt_template": f"""<|im_start|>system
{DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
+ """
<|im_start|>user
问题: {input}
已知内容: {context}
回答: <|im_end|><|im_start|>assistant""",
},
"qwen-7b-chat": {
"model_id": "Qwen/Qwen-7B-Chat",
"remote_code": True,
"start_message": f"<|im_start|>system\n {DEFAULT_SYSTEM_PROMPT_CHINESE }<|im_end|>",
"history_template": "<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}<|im_end|>",
"current_message_template": '"<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}',
"stop_tokens": ["<|im_end|>", "<|endoftext|>"],
"revision": "2abd8e5777bb4ce9c8ab4be7dbbd0fe4526db78d",
"rag_prompt_template": f"""<|im_start|>system
{DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
+ """
<|im_start|>user
问题: {input}
已知内容: {context}
回答: <|im_end|><|im_start|>assistant""",
},
"qwen2-7b-instruct": {
"model_id": "Qwen/Qwen2-7B-Instruct",
"remote_code": True,
"start_message": f"<|im_start|>system\n {DEFAULT_SYSTEM_PROMPT_CHINESE }<|im_end|>",
"history_template": "<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}<|im_end|>",
"current_message_template": '"<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}',
"stop_tokens": ["<|im_end|>", "<|endoftext|>"],
"revision": "2abd8e5777bb4ce9c8ab4be7dbbd0fe4526db78d",
"rag_prompt_template": f"""<|im_start|>system
{DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
+ """
<|im_start|>user
问题: {input}
已知内容: {context}
回答: <|im_end|><|im_start|>assistant""",
},
"chatglm3-6b": {
"model_id": "THUDM/chatglm3-6b",
"remote_code": True,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"tokenizer_kwargs": {"add_special_tokens": False},
"stop_tokens": [0, 2],
"rag_prompt_template": f"""{DEFAULT_RAG_PROMPT_CHINESE }"""
+ """
问题: {input}
已知内容: {context}
回答:
""",
},
"baichuan2-7b-chat": {
"model_id": "baichuan-inc/Baichuan2-7B-Chat",
"remote_code": True,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"tokenizer_kwargs": {"add_special_tokens": False},
"stop_tokens": [0, 2],
"rag_prompt_template": f"""{DEFAULT_RAG_PROMPT_CHINESE }"""
+ """
问题: {input}
已知内容: {context}
回答:
""",
},
"minicpm-2b-dpo": {
"model_id": "openbmb/MiniCPM-2B-dpo-fp16",
"remote_code": True,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"stop_tokens": [2],
},
"internlm2-chat-1.8b": {
"model_id": "internlm/internlm2-chat-1_8b",
"remote_code": True,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"stop_tokens": [2, 92542],
"partial_text_processor": internlm_partial_text_processor,
},
"qwen1.5-1.8b-chat": {
"model_id": "Qwen/Qwen1.5-1.8B-Chat",
"remote_code": False,
"start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
"stop_tokens": ["<|im_end|>", "<|endoftext|>"],
"rag_prompt_template": f"""<|im_start|>system
{DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
+ """
<|im_start|>user
问题: {input}
已知内容: {context}
回答: <|im_end|><|im_start|>assistant""",
},
},
"Japanese": {
"youri-7b-chat": {
"model_id": "rinna/youri-7b-chat",
"remote_code": False,
"start_message": f"設定: {DEFAULT_SYSTEM_PROMPT_JAPANESE}\n",
"history_template": "ユーザー: {user}\nシステム: {assistant}\n",
"current_message_template": "ユーザー: {user}\nシステム: {assistant}",
"tokenizer_kwargs": {"add_special_tokens": False},
"partial_text_processor": youri_partial_text_processor,
},
},
}
SUPPORTED_EMBEDDING_MODELS = {
"English": {
"bge-small-en-v1.5": {
"model_id": "BAAI/bge-small-en-v1.5",
"mean_pooling": False,
"normalize_embeddings": True,
},
"bge-large-en-v1.5": {
"model_id": "BAAI/bge-large-en-v1.5",
"mean_pooling": False,
"normalize_embeddings": True,
},
},
"Chinese": {
"bge-small-zh-v1.5": {
"model_id": "BAAI/bge-small-zh-v1.5",
"mean_pooling": False,
"normalize_embeddings": True,
},
"bge-large-zh-v1.5": {
"model_id": "bge-large-zh-v1.5",
"mean_pooling": False,
"normalize_embeddings": True,
},
},
}
SUPPORTED_RERANK_MODELS = {
"bge-reranker-large": {"model_id": "BAAI/bge-reranker-large"},
"bge-reranker-base": {"model_id": "BAAI/bge-reranker-base"},
}

View File

@@ -0,0 +1,49 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Model language for LLM
model_language: "Chinese"
vector_db: "FAISS"
splitter_name: "RecursiveCharacter"
k_rerank: 5
search_method: "similarity"
score_threshold: 0.5
bm25_weight: 0
# Pipeline
name: "default"
# Node parser
node_parser: "simple"
chunk_size: 192
chunk_overlap: 48
# Indexer
indexer: "faiss_vector"
# Retriever
retriever: "vectorsimilarity"
k_retrieval: 30
# Post Processor
postprocessor: "reranker"
# Generator
generator: "local"
prompt_path: "./data/default_prompt.txt"
# Models
embedding_model_id: "BAAI/bge-small-en-v1.5"
embedding_model_path: "./bge_ov_embedding"
# Device for embedding model inference
embedding_device: "AUTO"
rerank_model_id: "BAAI/bge-reranker-large"
rerank_model_path: "./bge_ov_reranker"
# Device for reranking model inference
rerank_device: "AUTO"
llm_model_id: "qwen2-7b-instruct"
llm_model_path: "./qwen2-7b-instruct/INT4_compressed_weights"
# Device for LLM model inference
llm_device: "AUTO"

View File

@@ -0,0 +1,124 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import sys
import requests
sys.path.append("..")
import os
from edgecraftrag import api_schema
PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1")
PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))
server_addr = f"http://{PIPELINE_SERVICE_HOST_IP}:{PIPELINE_SERVICE_PORT}"
def get_current_pipelines():
res = requests.get(f"{server_addr}/v1/settings/pipelines", proxies={"http": None})
pls = []
for pl in res.json():
if pl["status"]["active"]:
pls.append((pl["idx"], pl["name"] + " (active)"))
else:
pls.append((pl["idx"], pl["name"]))
return pls
def get_pipeline(name):
res = requests.get(f"{server_addr}/v1/settings/pipelines/{name}", proxies={"http": None})
return res.json()
def create_update_pipeline(
name,
active,
node_parser,
chunk_size,
chunk_overlap,
indexer,
retriever,
vector_search_top_k,
postprocessor,
generator,
llm_id,
llm_device,
llm_weights,
embedding_id,
embedding_device,
rerank_id,
rerank_device,
):
req_dict = api_schema.PipelineCreateIn(
name=name,
active=active,
node_parser=api_schema.NodeParserIn(
parser_type=node_parser, chunk_size=chunk_size, chunk_overlap=chunk_overlap
),
indexer=api_schema.IndexerIn(
indexer_type=indexer,
embedding_model=api_schema.ModelIn(
model_id=embedding_id,
# TODO: remove hardcoding
model_path="./bge_ov_embedding",
device=embedding_device,
),
),
retriever=api_schema.RetrieverIn(retriever_type=retriever, retriever_topk=vector_search_top_k),
postprocessor=[
api_schema.PostProcessorIn(
processor_type=postprocessor[0],
reranker_model=api_schema.ModelIn(
model_id=rerank_id,
# TODO: remove hardcoding
model_path="./bge_ov_reranker",
device=rerank_device,
),
)
],
generator=api_schema.GeneratorIn(
# TODO: remove hardcoding
prompt_path="./edgecraftrag/prompt_template/default_prompt.txt",
model=api_schema.ModelIn(
model_id=llm_id,
# TODO: remove hardcoding
model_path="./models/qwen2-7b-instruct/INT4_compressed_weights",
device=llm_device,
),
),
)
# hard code only for test
print(req_dict)
res = requests.post(f"{server_addr}/v1/settings/pipelines", json=req_dict.dict(), proxies={"http": None})
return res.text
def activate_pipeline(name):
active_dict = {"active": "True"}
res = requests.patch(f"{server_addr}/v1/settings/pipelines/{name}", json=active_dict, proxies={"http": None})
status = False
restext = f"Activate pipeline {name} failed."
if res.ok:
status = True
restext = f"Activate pipeline {name} successfully."
return restext, status
def create_vectordb(docs, spliter, vector_db):
req_dict = api_schema.FilesIn(local_paths=docs)
res = requests.post(f"{server_addr}/v1/data/files", json=req_dict.dict(), proxies={"http": None})
return res.text
def get_files():
res = requests.get(f"{server_addr}/v1/data/files", proxies={"http": None})
files = []
for file in res.json():
files.append((file["file_name"], file["file_id"]))
return files
def delete_file(file_name_or_id):
res = requests.delete(f"{server_addr}/v1/data/files/{file_name_or_id}", proxies={"http": None})
return res.text

View File

@@ -0,0 +1,983 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import platform
import re
from datetime import datetime
from pathlib import Path
import cpuinfo
import distro # if running Python 3.8 or above
import ecrag_client as cli
import gradio as gr
import httpx
import platform_config as pconf
import psutil
import requests
from loguru import logger
from omegaconf import OmegaConf
from platform_config import get_available_devices, get_available_weights, get_local_available_models
pipeline_df = []
import os
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "127.0.0.1")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 16011))
UI_SERVICE_HOST_IP = os.getenv("UI_SERVICE_HOST_IP", "0.0.0.0")
UI_SERVICE_PORT = int(os.getenv("UI_SERVICE_PORT", 8084))
def get_llm_model_dir(llm_model_id, weights_compression):
model_dirs = {
"fp16_model_dir": Path(llm_model_id) / "FP16",
"int8_model_dir": Path(llm_model_id) / "INT8_compressed_weights",
"int4_model_dir": Path(llm_model_id) / "INT4_compressed_weights",
}
if weights_compression == "INT4":
model_dir = model_dirs["int4_model_dir"]
elif weights_compression == "INT8":
model_dir = model_dirs["int8_model_dir"]
else:
model_dir = model_dirs["fp16_model_dir"]
if not model_dir.exists():
raise FileNotFoundError(f"The model directory {model_dir} does not exist.")
elif not model_dir.is_dir():
raise NotADirectoryError(f"The path {model_dir} is not a directory.")
return model_dir
def get_system_status():
cpu_usage = psutil.cpu_percent(interval=1)
memory_info = psutil.virtual_memory()
memory_usage = memory_info.percent
memory_total_gb = memory_info.total / (1024**3)
memory_used_gb = memory_info.used / (1024**3)
# uptime_seconds = time.time() - psutil.boot_time()
# uptime_hours, uptime_minutes = divmod(uptime_seconds // 60, 60)
disk_usage = psutil.disk_usage("/").percent
# net_io = psutil.net_io_counters()
os_info = platform.uname()
kernel_version = os_info.release
processor = cpuinfo.get_cpu_info()["brand_raw"]
dist_name = distro.name(pretty=True)
now = datetime.now()
current_time_str = now.strftime("%Y-%m-%d %H:%M")
status = (
f"{current_time_str} \t"
f"CPU Usage: {cpu_usage}% \t"
f"Memory Usage: {memory_usage}% {memory_used_gb:.2f}GB / {memory_total_gb:.2f}GB \t"
# f"System Uptime: {int(uptime_hours)} hours, {int(uptime_minutes)} minutes \t"
f"Disk Usage: {disk_usage}% \t"
# f"Bytes Sent: {net_io.bytes_sent}\n"
# f"Bytes Received: {net_io.bytes_recv}\n"
f"Kernel: {kernel_version} \t"
f"Processor: {processor} \t"
f"OS: {dist_name} \n"
)
return status
def build_demo(cfg, args):
def load_chatbot_models(
llm_id,
llm_device,
llm_weights,
embedding_id,
embedding_device,
rerank_id,
rerank_device,
):
req_dict = {
"llm_id": llm_id,
"llm_device": llm_device,
"llm_weights": llm_weights,
"embedding_id": embedding_id,
"embedding_device": embedding_device,
"rerank_id": rerank_id,
"rerank_device": rerank_device,
}
# hard code only for test
worker_addr = "http://127.0.0.1:8084"
print(req_dict)
result = requests.post(f"{worker_addr}/load", json=req_dict, proxies={"http": None})
return result.text
def user(message, history):
"""Callback function for updating user messages in interface on submit button click.
Params:
message: current message
history: conversation history
Returns:
          An empty string (clearing the message box) and the updated conversation history.
"""
# Append the user's message to the conversation history
return "", history + [[message, ""]]
async def bot(
history,
temperature,
top_p,
top_k,
repetition_penalty,
hide_full_prompt,
do_rag,
docs,
spliter_name,
vector_db,
chunk_size,
chunk_overlap,
vector_search_top_k,
vector_search_top_n,
run_rerank,
search_method,
score_threshold,
):
"""Callback function for running chatbot on submit button click.
Params:
history: conversation history
          temperature: controls the level of creativity in the generated text;
            higher values make the output more diverse, lower values more focused.
          top_p: restricts sampling to the smallest set of tokens whose cumulative probability exceeds top_p.
          top_k: restricts sampling to the k tokens with the highest probability.
          repetition_penalty: penalizes tokens based on how frequently they already occur in the text.
          The remaining parameters carry the RAG, retrieval and rerank settings forwarded from the UI controls.
"""
# req_dict = {
# "history": history,
# "temperature": temperature,
# "top_p": top_p,
# "top_k": top_k,
# "repetition_penalty": repetition_penalty,
# "hide_full_prompt": hide_full_prompt,
# "do_rag": do_rag,
# "docs": docs,
# "spliter_name": spliter_name,
# "vector_db": vector_db,
# "chunk_size": chunk_size,
# "chunk_overlap": chunk_overlap,
# "vector_search_top_k": vector_search_top_k,
# "vector_search_top_n": vector_search_top_n,
# "run_rerank": run_rerank,
# "search_method": search_method,
# "score_threshold": score_threshold,
# "streaming": True
# }
print(history)
new_req = {"messages": history[-1][0]}
server_addr = f"http://{MEGA_SERVICE_HOST_IP}:{MEGA_SERVICE_PORT}"
# Async for streaming response
partial_text = ""
async with httpx.AsyncClient() as client:
async with client.stream("POST", f"{server_addr}/v1/chatqna", json=new_req, timeout=None) as response:
partial_text = ""
async for chunk in response.aiter_lines():
new_text = chunk
if new_text.startswith("data"):
new_text = re.sub(r"\r\n", "", chunk.split("data: ")[-1])
                        new_text = json.loads(new_text)["choices"][0]["message"]["content"]
partial_text = partial_text + new_text
history[-1][1] = partial_text
yield history
avail_llms = get_local_available_models("llm")
avail_embed_models = get_local_available_models("embed")
avail_rerank_models = get_local_available_models("rerank")
avail_devices = get_available_devices()
avail_weights_compression = get_available_weights()
avail_node_parsers = pconf.get_available_node_parsers()
avail_indexers = pconf.get_available_indexers()
avail_retrievers = pconf.get_available_retrievers()
avail_postprocessors = pconf.get_available_postprocessors()
avail_generators = pconf.get_available_generators()
css = """
.feedback textarea {font-size: 18px; !important }
#blude_border {border: 1px solid #0000FF}
#white_border {border: 2px solid #FFFFFF}
.test textarea {color: E0E0FF; border: 1px solid #0000FF}
.disclaimer {font-variant-caps: all-small-caps}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
gr.HTML(
"""
<!DOCTYPE html>
<html>
<head>
<style>
.container {
display: flex; /* Establish a flex container */
align-items: center; /* Vertically align everything in the middle */
width: 100%; /* Take the full width of the container */
}
.title-container {
flex-grow: 1; /* Allow the title to grow and occupy the available space */
text-align: center; /* Center the text block inside the title container */
}
.title-line {
display: block; /* Makes the span behave like a div in terms of layout */
line-height: 1.2; /* Adjust this value as needed for better appearance */
}
img {
/* Consider setting a specific width or height if necessary */
}
</style>
</head>
<body>
<div class="container">
<!-- Image aligned to the left -->
<a href="https://www.intel.cn/content/www/cn/zh/artificial-intelligence/overview.html"><img src="/file/assets/ai-logo-inline-onlight-3000.png" alt="Sample Image" width="200"></a>
<!-- Title centered in the remaining space -->
<!-- Title container centered in the remaining space -->
<div class="title-container">
<span class="title-line"><h1 >Edge Craft RAG based Q&A Chatbot</h1></span>
<span class="title-line"><h5 style="margin: 0;">Powered by Intel NEXC Edge AI solutions</h5></span>
</div>
</div>
</body>
</html>
"""
)
_ = gr.Textbox(
label="System Status",
value=get_system_status,
max_lines=1,
every=1,
info="",
elem_id="white_border",
)
def get_pipeline_df():
global pipeline_df
pipeline_df = cli.get_current_pipelines()
return pipeline_df
# -------------------
# RAG Settings Layout
# -------------------
with gr.Tab("RAG Settings"):
with gr.Row():
with gr.Column(scale=2):
u_pipelines = gr.Dataframe(
headers=["ID", "Name"],
column_widths=[70, 30],
value=get_pipeline_df,
label="Pipelines",
show_label=True,
interactive=False,
every=5,
)
u_rag_pipeline_status = gr.Textbox(label="Status", value="", interactive=False)
with gr.Column(scale=3):
with gr.Accordion("Pipeline Configuration"):
with gr.Row():
rag_create_pipeline = gr.Button("Create Pipeline")
rag_activate_pipeline = gr.Button("Activate Pipeline")
rag_remove_pipeline = gr.Button("Remove Pipeline")
with gr.Column(variant="panel"):
u_pipeline_name = gr.Textbox(
label="Name",
value=cfg.name,
interactive=True,
)
u_active = gr.Checkbox(
value=True,
label="Activated",
interactive=True,
)
with gr.Column(variant="panel"):
with gr.Accordion("Node Parser"):
u_node_parser = gr.Dropdown(
choices=avail_node_parsers,
label="Node Parser",
value=cfg.node_parser,
info="Select a parser to split documents.",
multiselect=False,
interactive=True,
)
u_chunk_size = gr.Slider(
label="Chunk size",
value=cfg.chunk_size,
minimum=100,
maximum=2000,
step=50,
interactive=True,
info="Size of sentence chunk",
)
u_chunk_overlap = gr.Slider(
label="Chunk overlap",
value=cfg.chunk_overlap,
minimum=0,
maximum=400,
step=1,
interactive=True,
info=("Overlap between 2 chunks"),
)
with gr.Column(variant="panel"):
with gr.Accordion("Indexer"):
u_indexer = gr.Dropdown(
choices=avail_indexers,
label="Indexer",
value=cfg.indexer,
info="Select an indexer for indexing content of the documents.",
multiselect=False,
interactive=True,
)
with gr.Accordion("Embedding Model Configuration"):
u_embed_model_id = gr.Dropdown(
choices=avail_embed_models,
value=cfg.embedding_model_id,
label="Embedding Model",
# info="Select a Embedding Model",
multiselect=False,
allow_custom_value=True,
)
u_embed_device = gr.Dropdown(
choices=avail_devices,
value=cfg.embedding_device,
label="Embedding run device",
# info="Run embedding model on which device?",
multiselect=False,
)
with gr.Column(variant="panel"):
with gr.Accordion("Retriever"):
u_retriever = gr.Dropdown(
choices=avail_retrievers,
value=cfg.retriever,
label="Retriever",
info="Select a retriever for retrieving context.",
multiselect=False,
interactive=True,
)
u_vector_search_top_k = gr.Slider(
1,
50,
value=cfg.k_retrieval,
step=1,
label="Search top k",
info="Number of searching results, must >= Rerank top n",
interactive=True,
)
with gr.Column(variant="panel"):
with gr.Accordion("Postprocessor"):
u_postprocessor = gr.Dropdown(
choices=avail_postprocessors,
value=cfg.postprocessor,
label="Postprocessor",
info="Select postprocessors for post-processing of the context.",
multiselect=True,
interactive=True,
)
with gr.Accordion("Rerank Model Configuration", open=True):
u_rerank_model_id = gr.Dropdown(
choices=avail_rerank_models,
value=cfg.rerank_model_id,
label="Rerank Model",
# info="Select a Rerank Model",
multiselect=False,
allow_custom_value=True,
)
u_rerank_device = gr.Dropdown(
choices=avail_devices,
value=cfg.rerank_device,
label="Rerank run device",
# info="Run rerank model on which device?",
multiselect=False,
)
with gr.Column(variant="panel"):
with gr.Accordion("Generator"):
u_generator = gr.Dropdown(
choices=avail_generators,
value=cfg.generator,
label="Generator",
info="Select a generator for AI inference.",
multiselect=False,
interactive=True,
)
with gr.Accordion("LLM Configuration", open=True):
u_llm_model_id = gr.Dropdown(
choices=avail_llms,
value=cfg.llm_model_id,
label="Large Language Model",
# info="Select a Large Language Model",
multiselect=False,
allow_custom_value=True,
)
u_llm_device = gr.Dropdown(
choices=avail_devices,
value=cfg.llm_device,
label="LLM run device",
# info="Run LLM on which device?",
multiselect=False,
)
u_llm_weights = gr.Radio(
avail_weights_compression,
label="Weights",
info="weights compression",
)
# -------------------
# RAG Settings Events
# -------------------
# Event handlers
def show_pipeline_detail(evt: gr.SelectData):
# get selected pipeline id
            # Dataframe: {'headers': '', 'data': [[x00, x01], [x10, x11]]}
# SelectData.index: [i, j]
print(u_pipelines.value["data"])
print(evt.index)
# always use pipeline id for indexing
selected_id = pipeline_df[evt.index[0]][0]
pl = cli.get_pipeline(selected_id)
            # TODO: change to json format
# pl["postprocessor"][0]["processor_type"]
# pl["postprocessor"]["model"]["model_id"], pl["postprocessor"]["model"]["device"]
return (
pl["name"],
pl["status"]["active"],
pl["node_parser"]["parser_type"],
pl["node_parser"]["chunk_size"],
pl["node_parser"]["chunk_overlap"],
pl["indexer"]["indexer_type"],
pl["retriever"]["retriever_type"],
pl["retriever"]["retrieve_topk"],
pl["generator"]["generator_type"],
pl["generator"]["model"]["model_id"],
pl["generator"]["model"]["device"],
"",
pl["indexer"]["model"]["model_id"],
pl["indexer"]["model"]["device"],
)
def modify_create_pipeline_button():
return "Create Pipeline"
def modify_update_pipeline_button():
return "Update Pipeline"
def create_update_pipeline(
name,
active,
node_parser,
chunk_size,
chunk_overlap,
indexer,
retriever,
vector_search_top_k,
postprocessor,
generator,
llm_id,
llm_device,
llm_weights,
embedding_id,
embedding_device,
rerank_id,
rerank_device,
):
res = cli.create_update_pipeline(
name,
active,
node_parser,
chunk_size,
chunk_overlap,
indexer,
retriever,
vector_search_top_k,
postprocessor,
generator,
llm_id,
llm_device,
llm_weights,
embedding_id,
embedding_device,
rerank_id,
rerank_device,
)
return res, get_pipeline_df()
# Events
u_pipelines.select(
show_pipeline_detail,
inputs=None,
outputs=[
u_pipeline_name,
u_active,
# node parser
u_node_parser,
u_chunk_size,
u_chunk_overlap,
# indexer
u_indexer,
# retriever
u_retriever,
u_vector_search_top_k,
# postprocessor
# u_postprocessor,
# generator
u_generator,
# models
u_llm_model_id,
u_llm_device,
u_llm_weights,
u_embed_model_id,
u_embed_device,
# u_rerank_model_id,
# u_rerank_device
],
)
u_pipeline_name.input(modify_create_pipeline_button, inputs=None, outputs=rag_create_pipeline)
# Create pipeline button will change to update pipeline button if any
# of the listed fields changed
gr.on(
triggers=[
u_active.input,
# node parser
u_node_parser.input,
u_chunk_size.input,
u_chunk_overlap.input,
# indexer
u_indexer.input,
# retriever
u_retriever.input,
u_vector_search_top_k.input,
# postprocessor
u_postprocessor.input,
# generator
u_generator.input,
# models
u_llm_model_id.input,
u_llm_device.input,
u_llm_weights.input,
u_embed_model_id.input,
u_embed_device.input,
u_rerank_model_id.input,
u_rerank_device.input,
],
fn=modify_update_pipeline_button,
inputs=None,
outputs=rag_create_pipeline,
)
rag_create_pipeline.click(
create_update_pipeline,
inputs=[
u_pipeline_name,
u_active,
u_node_parser,
u_chunk_size,
u_chunk_overlap,
u_indexer,
u_retriever,
u_vector_search_top_k,
u_postprocessor,
u_generator,
u_llm_model_id,
u_llm_device,
u_llm_weights,
u_embed_model_id,
u_embed_device,
u_rerank_model_id,
u_rerank_device,
],
outputs=[u_rag_pipeline_status, u_pipelines],
queue=False,
)
rag_activate_pipeline.click(
cli.activate_pipeline,
inputs=[u_pipeline_name],
outputs=[u_rag_pipeline_status, u_active],
queue=False,
)
# --------------
# Chatbot Layout
# --------------
def get_files():
return cli.get_files()
def create_vectordb(docs, spliter, vector_db):
res = cli.create_vectordb(docs, spliter, vector_db)
return gr.update(value=get_files()), res
global u_files_selected_row
u_files_selected_row = None
def select_file(data, evt: gr.SelectData):
if not evt.selected or len(evt.index) == 0:
return "No file selected"
global u_files_selected_row
row_index = evt.index[0]
u_files_selected_row = data.iloc[row_index]
file_name, file_id = u_files_selected_row
return f"File Name: {file_name}\nFile ID: {file_id}"
def deselect_file():
global u_files_selected_row
u_files_selected_row = None
return gr.update(value=get_files()), "Selection cleared"
def delete_file():
global u_files_selected_row
if u_files_selected_row is None:
res = "Please select a file first."
else:
file_name, file_id = u_files_selected_row
u_files_selected_row = None
res = cli.delete_file(file_id)
return gr.update(value=get_files()), res
with gr.Tab("Chatbot"):
with gr.Row():
with gr.Column(scale=1):
docs = gr.File(
label="Step 1: Load text files",
file_count="multiple",
file_types=[
".csv",
".doc",
".docx",
".enex",
".epub",
".html",
".md",
".odt",
".pdf",
".ppt",
".pptx",
".txt",
],
)
retriever_argument = gr.Accordion("Vector Store Configuration", open=False)
with retriever_argument:
spliter = gr.Dropdown(
["Character", "RecursiveCharacter", "Markdown", "Chinese"],
value=cfg.splitter_name,
label="Text Spliter",
info="Method used to split the documents",
multiselect=False,
)
vector_db = gr.Dropdown(
["FAISS", "Chroma"],
value=cfg.vector_db,
label="Vector Stores",
info="Stores embedded data and performs vector search.",
multiselect=False,
)
load_docs = gr.Button("Upload files")
u_files_status = gr.Textbox(label="File Processing Status", value="", interactive=False)
u_files = gr.Dataframe(
headers=["Loaded File Name", "File ID"],
value=get_files,
label="Loaded Files",
show_label=False,
interactive=False,
every=5,
)
with gr.Accordion("Delete File", open=False):
selected_files = gr.Textbox(label="Click file to select", value="", interactive=False)
with gr.Row():
with gr.Column():
delete_button = gr.Button("Delete Selected File")
with gr.Column():
deselect_button = gr.Button("Clear Selection")
do_rag = gr.Checkbox(
value=True,
label="RAG is ON",
interactive=True,
info="Whether to do RAG for generation",
)
with gr.Accordion("Generation Configuration", open=False):
with gr.Row():
with gr.Column():
with gr.Row():
temperature = gr.Slider(
label="Temperature",
value=0.1,
minimum=0.0,
maximum=1.0,
step=0.1,
interactive=True,
info="Higher values produce more diverse outputs",
)
with gr.Column():
with gr.Row():
top_p = gr.Slider(
label="Top-p (nucleus sampling)",
value=1.0,
minimum=0.0,
maximum=1,
step=0.01,
interactive=True,
info=(
"Sample from the smallest possible set of tokens whose cumulative probability "
"exceeds top_p. Set to 1 to disable and sample from all tokens."
),
)
with gr.Column():
with gr.Row():
top_k = gr.Slider(
label="Top-k",
value=50,
minimum=0.0,
maximum=200,
step=1,
interactive=True,
info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
)
with gr.Column():
with gr.Row():
repetition_penalty = gr.Slider(
label="Repetition Penalty",
value=1.1,
minimum=1.0,
maximum=2.0,
step=0.1,
interactive=True,
info="Penalize repetition — 1.0 to disable.",
)
with gr.Column(scale=4):
chatbot = gr.Chatbot(
height=600,
label="Step 2: Input Query",
show_copy_button=True,
)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="QA Message Box",
placeholder="Chat Message Box",
show_label=False,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
stop = gr.Button("Stop")
clear = gr.Button("Clear")
retriever_argument = gr.Accordion("Retriever Configuration", open=True)
with retriever_argument:
with gr.Row():
with gr.Row():
do_rerank = gr.Checkbox(
value=True,
label="Rerank searching result",
interactive=True,
)
hide_context = gr.Checkbox(
value=True,
label="Hide searching result in prompt",
interactive=True,
)
with gr.Row():
search_method = gr.Dropdown(
["similarity_score_threshold", "similarity", "mmr"],
value=cfg.search_method,
label="Searching Method",
info="Method used to search vector store",
multiselect=False,
interactive=True,
)
with gr.Row():
score_threshold = gr.Slider(
0.01,
0.99,
value=cfg.score_threshold,
step=0.01,
label="Similarity Threshold",
info="Only working for 'similarity score threshold' method",
interactive=True,
)
with gr.Row():
vector_rerank_top_n = gr.Slider(
1,
10,
value=cfg.k_rerank,
step=1,
label="Rerank top n",
info="Number of rerank results",
interactive=True,
)
load_docs.click(
create_vectordb,
inputs=[
docs,
spliter,
vector_db,
],
outputs=[u_files, u_files_status],
queue=True,
)
# TODO: Need to de-select the dataframe,
# otherwise every time the dataframe is updated, a select event is triggered
u_files.select(select_file, inputs=[u_files], outputs=selected_files, queue=True)
delete_button.click(
delete_file,
outputs=[u_files, u_files_status],
queue=True,
)
deselect_button.click(
deselect_file,
outputs=[u_files, selected_files],
queue=True,
)
submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot,
[
chatbot,
temperature,
top_p,
top_k,
repetition_penalty,
hide_context,
do_rag,
docs,
spliter,
vector_db,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
do_rerank,
search_method,
score_threshold,
],
chatbot,
queue=True,
)
submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot,
[
chatbot,
temperature,
top_p,
top_k,
repetition_penalty,
hide_context,
do_rag,
docs,
spliter,
vector_db,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
do_rerank,
search_method,
score_threshold,
],
chatbot,
queue=True,
)
# stop.click(
# fn=request_cancel,
# inputs=None,
# outputs=None,
# cancels=[submit_event, submit_click_event],
# queue=False,
# )
clear.click(lambda: None, None, chatbot, queue=False)
return demo
def main():
# Create the parser
parser = argparse.ArgumentParser(description="Load Embedding and LLM Models with OpenVino.")
# Add the arguments
parser.add_argument("--prompt_template", type=str, required=False, help="User specific template")
# parser.add_argument("--server_name", type=str, default="0.0.0.0")
# parser.add_argument("--server_port", type=int, default=8082)
parser.add_argument("--config", type=str, default="./default.yaml", help="configuration file path")
parser.add_argument("--share", action="store_true", help="share model")
parser.add_argument("--debug", action="store_true", help="enable debugging")
# Execute the parse_args() method to collect command line arguments
args = parser.parse_args()
logger.info(args)
cfg = OmegaConf.load(args.config)
init_cfg_(cfg)
logger.info(cfg)
demo = build_demo(cfg, args)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# if you have any issue to launch on your platform, you can pass share=True to launch method:
# demo.launch(share=True)
# it creates a publicly shareable link for the interface. Read more in the docs: https://gradio.app/docs/
# demo.launch(share=True)
demo.queue().launch(
server_name=UI_SERVICE_HOST_IP, server_port=UI_SERVICE_PORT, share=args.share, allowed_paths=["."]
)
# %%
# please run this cell for stopping gradio interface
demo.close()
def init_cfg_(cfg):
if "name" not in cfg:
cfg.name = "default"
if "embedding_device" not in cfg:
cfg.embedding_device = "CPU"
if "rerank_device" not in cfg:
cfg.rerank_device = "CPU"
if "llm_device" not in cfg:
cfg.llm_device = "CPU"
if "model_language" not in cfg:
cfg.model_language = "Chinese"
if "vector_db" not in cfg:
cfg.vector_db = "FAISS"
if "splitter_name" not in cfg:
cfg.splitter_name = "RecursiveCharacter" # or "Chinese"
if "search_method" not in cfg:
cfg.search_method = "similarity"
if "score_threshold" not in cfg:
cfg.score_threshold = 0.5
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,114 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import sys
from enum import Enum
import openvino.runtime as ov
from config import SUPPORTED_EMBEDDING_MODELS, SUPPORTED_LLM_MODELS, SUPPORTED_RERANK_MODELS
sys.path.append("..")
from edgecraftrag.base import GeneratorType, IndexerType, NodeParserType, PostProcessorType, RetrieverType
def _get_llm_model_ids(supported_models, model_language=None):
if model_language is None:
model_ids = [model_id for model_id, _ in supported_models.items()]
return model_ids
if model_language not in supported_models:
print("Invalid model language! Please choose from the available options.")
return None
# Create a list of model IDs based on the selected language
llm_model_ids = [
model_id
for model_id, model_config in supported_models[model_language].items()
if model_config.get("rag_prompt_template") or model_config.get("normalize_embeddings")
]
return llm_model_ids
def _list_subdirectories(parent_directory):
"""List all subdirectories under the given parent directory using os.listdir.
Parameters:
parent_directory (str): The path to the parent directory from which to list subdirectories.
Returns:
list: A list of subdirectory names found in the parent directory.
"""
# Get a list of all entries in the parent directory
entries = os.listdir(parent_directory)
# Filter out the entries to only keep directories
subdirectories = [entry for entry in entries if os.path.isdir(os.path.join(parent_directory, entry))]
return sorted(subdirectories)
def _get_available_models(model_ids, local_dirs):
"""Filters and sorts model IDs based on their presence in the local directories.
Parameters:
model_ids (list): A list of model IDs to check.
local_dirs (list): A list of local directory names to check against.
Returns:
list: A sorted list of available model IDs.
"""
# Filter model_ids for those that are present in local directories
return sorted([model_id for model_id in model_ids if model_id in local_dirs])
def get_local_available_models(model_type: str, local_path: str = "./"):
local_dirs = _list_subdirectories(local_path)
if model_type == "llm":
model_ids = _get_llm_model_ids(SUPPORTED_LLM_MODELS, "Chinese")
elif model_type == "embed":
model_ids = _get_llm_model_ids(SUPPORTED_EMBEDDING_MODELS, "Chinese")
elif model_type == "rerank":
model_ids = _get_llm_model_ids(SUPPORTED_RERANK_MODELS)
else:
print("Unknown model type")
avail_models = _get_available_models(model_ids, local_dirs)
return avail_models
def get_available_devices():
core = ov.Core()
avail_devices = core.available_devices + ["AUTO"]
if "NPU" in avail_devices:
avail_devices.remove("NPU")
return avail_devices
def get_available_weights():
avail_weights_compression = ["FP16", "INT8", "INT4"]
return avail_weights_compression
def get_enum_values(c: Enum):
return [v.value for k, v in vars(c).items() if not callable(v) and not k.startswith("__") and not k.startswith("_")]
def get_available_node_parsers():
return get_enum_values(NodeParserType)
def get_available_indexers():
return get_enum_values(IndexerType)
def get_available_retrievers():
return get_enum_values(RetrieverType)
def get_available_postprocessors():
return get_enum_values(PostProcessorType)
def get_available_generators():
return get_enum_values(GeneratorType)