# GenAIExamples/EdgeCraftRAG/ui/gradio/ecragui.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
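# Gradio-based Q&A chatbot UI for EdgeCraftRAG.
# Typical launch (assuming the default config file that ships alongside this script):
#   python ecragui.py --config ./default.yaml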
import argparse
import base64
import os
import platform
import re
import time
from datetime import datetime
import cpuinfo
import distro  # platform.linux_distribution() was removed in Python 3.8; use distro instead
import ecrag_client as cli
import gradio as gr
import httpx
import platform_config as pconf
import psutil
from loguru import logger
from omegaconf import OmegaConf
from platform_config import (
get_avail_llm_inference_type,
get_available_devices,
get_available_weights,
get_local_available_models,
)
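# Cached pipeline rows shown in the "Pipelines" dataframe; refreshed by get_pipeline_df()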
pipeline_df = []
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "127.0.0.1")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 16011))
UI_SERVICE_HOST_IP = os.getenv("UI_SERVICE_HOST_IP", "0.0.0.0")
UI_SERVICE_PORT = int(os.getenv("UI_SERVICE_PORT", 8082))
def get_image_base64(image_path):
"""Get the Base64 encoding of a PNG image from a local file path.
:param image_path: The file path of the image.
:return: The Base64 encoded string of the image.
"""
with open(image_path, "rb") as image_file:
image_data = image_file.read()
# Encode the image data to Base64
image_base64 = base64.b64encode(image_data).decode("utf-8")
return image_base64
def get_system_status():
cpu_usage = psutil.cpu_percent(interval=1)
memory_info = psutil.virtual_memory()
memory_usage = memory_info.percent
memory_total_gb = memory_info.total / (1024**3)
memory_used_gb = memory_info.used / (1024**3)
# uptime_seconds = time.time() - psutil.boot_time()
# uptime_hours, uptime_minutes = divmod(uptime_seconds // 60, 60)
disk_usage = psutil.disk_usage("/").percent
# net_io = psutil.net_io_counters()
os_info = platform.uname()
kernel_version = os_info.release
processor = cpuinfo.get_cpu_info()["brand_raw"]
dist_name = distro.name(pretty=True)
now = datetime.now()
current_time_str = now.strftime("%Y-%m-%d %H:%M")
status = (
f"{current_time_str} \t"
f"CPU Usage: {cpu_usage}% \t"
f"Memory Usage: {memory_usage}% {memory_used_gb:.2f}GB / {memory_total_gb:.2f}GB \t"
# f"System Uptime: {int(uptime_hours)} hours, {int(uptime_minutes)} minutes \t"
f"Disk Usage: {disk_usage}% \t"
# f"Bytes Sent: {net_io.bytes_sent}\n"
# f"Bytes Received: {net_io.bytes_recv}\n"
f"Kernel: {kernel_version} \t"
f"Processor: {processor} \t"
f"OS: {dist_name} \n"
)
return status
def get_benchmark():
    time.sleep(0.5)
    active_pipeline_name = get_actived_pipeline()
    if active_pipeline_name:
        data = cli.get_benchmark(active_pipeline_name)
        if data:
            return gr.update(
                visible=True,
                value=data,
            )
    # No active pipeline or no benchmark data available: hide the benchmark label
    return gr.update(visible=False)
def get_actived_pipeline():
return cli.get_actived_pipeline()
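# Build the Gradio Blocks app: a RAG pipeline settings tab plus the chatbot tab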
def build_app(cfg, args):
def user(message, history):
"""Callback function for updating user messages in interface on submit button click.
Params:
message: current message
history: conversation history
Returns:
None
"""
# Append the user's message to the conversation history
return "", history + [[message, ""]]
async def bot(
history,
temperature,
top_p,
top_k,
repetition_penalty,
max_tokens,
docs,
chunk_size,
chunk_overlap,
vector_search_top_k,
vector_rerank_top_n,
):
"""Callback function for running chatbot on submit button click.
Params:
history: conversation history
temperature: parameter for control the level of creativity in AI-generated text.
By adjusting the `temperature`, you can influence the AI model's probability distribution, making the text more focused or diverse.
top_p: parameter for control the range of tokens considered by the AI model based on their cumulative probability.
top_k: parameter for control the range of tokens considered by the AI model based on their cumulative probability, selecting number of tokens with highest probability.
repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
conversation_id: unique conversation identifier.
"""
if history[-1][0] == "" or len(history[-1][0]) == 0:
yield history[:-1]
return
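        # Compose the request for the ChatQnA megaservice; stream=True returns the answer in chunks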
stream_opt = True
new_req = {
"messages": history[-1][0],
"stream": stream_opt,
"max_tokens": max_tokens,
"top_n": vector_rerank_top_n,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
}
server_addr = f"http://{MEGA_SERVICE_HOST_IP}:{MEGA_SERVICE_PORT}"
# Async for stream response
partial_text = ""
async with httpx.AsyncClient() as client:
async with client.stream("POST", f"{server_addr}/v1/chatqna", json=new_req, timeout=None) as response:
                async for chunk in response.aiter_text():
                    wrap_text = chunk
                    # Inline referenced images ("参考图片" = "reference images") as base64 so the chat window can render them
                    if "参考图片" in chunk:
                        image_paths = re.compile(r"!\[\]\((.*?)\)").findall(chunk)
                        for image_path in image_paths:
                            image_base64 = get_image_base64(image_path)
                            # Replace on wrap_text, not chunk, so multiple images in one chunk are all inlined
                            wrap_text = wrap_text.replace(
                                f"![]({image_path})", f'<img src="data:image/png;base64,{image_base64}">'
                            )
                    partial_text = partial_text + wrap_text
                    history[-1][1] = partial_text
                    yield history
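    # Enumerate locally available models, devices and pipeline components to populate the dropdowns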
avail_llms = get_local_available_models("llm")
avail_embed_models = get_local_available_models("embed")
avail_rerank_models = get_local_available_models("rerank")
avail_devices = get_available_devices()
avail_weights_compression = get_available_weights()
avail_llm_inference_type = get_avail_llm_inference_type()
avail_node_parsers = pconf.get_available_node_parsers()
avail_indexers = pconf.get_available_indexers()
avail_retrievers = pconf.get_available_retrievers()
avail_postprocessors = pconf.get_available_postprocessors()
avail_generators = pconf.get_available_generators()
css = """
.feedback textarea {font-size: 18px; !important }
#blude_border {border: 1px solid #0000FF}
#white_border {border: 2px solid #FFFFFF}
.test textarea {color: E0E0FF; border: 1px solid #0000FF}
.disclaimer {font-variant-caps: all-small-caps}
html body gradio-app{margin: 0px;}
footer{display: none !important;}
.gradio-container{font-weight: 400;font-size: 14px;line-height: 24px;font-family: "PingFang SC", "Microsoft YaHei", SimHei !important;}
.custom-header{position: relative;}
.custom-log{position: absolute;top: 0px;left: 0px;min-width:0 !important;z-index:20;}
.custom-title{background-color: var(--body-background-fill);h2 {padding:0;background-color: var(--body-background-fill)}}
.custom-des {position: relative;top:-24px;background-color: var(--body-background-fill);
.container{padding-top:0;}h2{font-size:14px;padding:0;color:var(--neutral-500);}}
.benchmark-wrap {position: absolute;top: -40px;height: 36px;z-index: 20;.container {padding: 8px 16px;}h2 {font-size: 14px;padding: 0;font-weight: 500;color: #666666;justify-content: end;}}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css, title="Edge Craft RAG based Q&A Chatbot") as app:
with gr.Column(elem_classes="custom-header"):
gr.Image(
value="./assets/ai-logo-inline-onlight-3000.png",
show_label=False,
show_download_button=False,
container=False,
show_fullscreen_button=False,
width="160px",
height="45px",
elem_classes="custom-log",
)
gr.Label(
"Edge Craft RAG based Q&A Chatbot",
show_label=False,
elem_classes="custom-title",
)
gr.Label("Powered by Intel", show_label=False, elem_classes="custom-des")
_ = gr.Textbox(
label="System Status",
value=get_system_status,
max_lines=1,
every=1,
info="",
elem_id="white_border",
)
def get_pipeline_df():
global pipeline_df
pipeline_df = cli.get_current_pipelines()
return pipeline_df
# -------------------
# RAG Settings Layout
# -------------------
with gr.Tab("RAG Settings"):
with gr.Row():
with gr.Column(scale=2):
u_pipelines = gr.Dataframe(
headers=["ID", "Name"],
column_widths=[70, 30],
value=get_pipeline_df,
label="Pipelines",
show_label=True,
interactive=False,
every=5,
)
u_rag_pipeline_status = gr.JSON(label="Status")
with gr.Column(scale=3):
with gr.Accordion("Pipeline Configuration"):
with gr.Row():
rag_create_pipeline = gr.Button("Create Pipeline")
rag_activate_pipeline = gr.Button("Activate Pipeline")
rag_remove_pipeline = gr.Button("Remove Pipeline")
with gr.Column(variant="panel"):
u_pipeline_name = gr.Textbox(
label="Name",
value=cfg.name,
interactive=True,
)
u_active = gr.Checkbox(
value=True,
label="Activated",
interactive=True,
)
with gr.Column(variant="panel"):
with gr.Accordion("Node Parser"):
u_node_parser = gr.Dropdown(
choices=avail_node_parsers,
label="Node Parser",
value=cfg.node_parser,
info="Select a parser to split documents.",
multiselect=False,
interactive=True,
)
u_chunk_size = gr.Slider(
label="Chunk size",
value=cfg.chunk_size,
minimum=100,
maximum=2000,
step=50,
interactive=True,
info="Size of sentence chunk",
)
u_chunk_overlap = gr.Slider(
label="Chunk overlap",
value=cfg.chunk_overlap,
minimum=0,
maximum=400,
step=1,
interactive=True,
info=("Overlap between 2 chunks"),
)
with gr.Column(variant="panel"):
with gr.Accordion("Indexer"):
u_indexer = gr.Dropdown(
choices=avail_indexers,
label="Indexer",
value=cfg.indexer,
info="Select an indexer for indexing content of the documents.",
multiselect=False,
interactive=True,
)
with gr.Accordion("Embedding Model Configuration"):
u_embed_model_id = gr.Dropdown(
choices=avail_embed_models,
value=cfg.embedding_model_id,
label="Embedding Model",
# info="Select a Embedding Model",
multiselect=False,
allow_custom_value=True,
)
u_embed_device = gr.Dropdown(
choices=avail_devices,
value=cfg.embedding_device,
label="Embedding run device",
# info="Run embedding model on which device?",
multiselect=False,
interactive=True,
)
with gr.Column(variant="panel"):
with gr.Accordion("Retriever"):
u_retriever = gr.Dropdown(
choices=avail_retrievers,
value=cfg.retriever,
label="Retriever",
info="Select a retriever for retrieving context.",
multiselect=False,
interactive=True,
)
u_vector_search_top_k = gr.Slider(
1,
50,
value=cfg.k_retrieval,
step=1,
label="Search top k",
info="Number of searching results, must >= Rerank top n",
interactive=True,
)
with gr.Column(variant="panel"):
with gr.Accordion("Postprocessor"):
u_postprocessor = gr.Dropdown(
choices=avail_postprocessors,
value=cfg.postprocessor,
label="Postprocessor",
info="Select postprocessors for post-processing of the context.",
multiselect=True,
interactive=True,
)
with gr.Accordion("Rerank Model Configuration", open=True) as rerank_model:
u_rerank_model_id = gr.Dropdown(
choices=avail_rerank_models,
value=cfg.rerank_model_id,
label="Rerank Model",
# info="Select a Rerank Model",
multiselect=False,
allow_custom_value=True,
)
u_rerank_device = gr.Dropdown(
choices=avail_devices,
value=cfg.rerank_device,
label="Rerank run device",
# info="Run rerank model on which device?",
multiselect=False,
interactive=True,
)
with gr.Column(variant="panel"):
with gr.Accordion("Generator"):
u_generator = gr.Dropdown(
choices=avail_generators,
value=cfg.generator,
label="Generator",
info="Select a generator for AI inference.",
multiselect=False,
interactive=True,
)
u_llm_infertype = gr.Radio(
choices=avail_llm_inference_type, label="LLM Inference Type", value="local"
)
with gr.Accordion("LLM Configuration", open=True) as accordion:
u_llm_model_id = gr.Dropdown(
choices=avail_llms,
value=cfg.llm_model_id,
label="Large Language Model",
# info="Select a Large Language Model",
multiselect=False,
allow_custom_value=True,
)
u_llm_device = gr.Dropdown(
choices=avail_devices,
value=cfg.llm_device,
label="LLM run device",
# info="Run LLM on which device?",
multiselect=False,
interactive=True,
)
u_llm_weights = gr.Radio(
avail_weights_compression,
label="Weights",
info="weights compression",
value=cfg.llm_weights,
interactive=True,
)
# -------------------
# RAG Settings Events
# -------------------
# Event handlers
        def update_visibility(selected_value):  # Accept the event argument, even if not used
            # The local LLM configuration is irrelevant when inference is served by vLLM
            if selected_value == "vllm":
                return gr.Accordion(visible=False)
            else:
                return gr.Accordion(visible=True)

        def update_rerank_model(selected_list):  # Accept the event argument, even if not used
            # Show the rerank model settings only when a reranker postprocessor is selected
            if "reranker" in selected_list:
                return gr.Accordion(visible=True)
            else:
                return gr.Accordion(visible=False)
def show_pipeline_detail(evt: gr.SelectData):
# get selected pipeline id
            # Dataframe: {'headers': '', 'data': [[x00, x01], [x10, x11]]}
# SelectData.index: [i, j]
# always use pipeline id for indexing
selected_id = pipeline_df[evt.index[0]][0]
pl = cli.get_pipeline(selected_id)
return (
pl["name"],
pl["status"]["active"],
pl["node_parser"]["parser_type"],
pl["node_parser"]["chunk_size"],
pl["node_parser"]["chunk_overlap"],
pl["indexer"]["indexer_type"],
pl["retriever"]["retriever_type"],
pl["retriever"]["retrieve_topk"],
pl["postprocessor"][0]["processor_type"],
pl["generator"]["generator_type"],
pl["generator"]["inference_type"],
pl["generator"]["model"]["model_id"],
pl["generator"]["model"]["device"],
pl["generator"]["model"]["weight"],
pl["indexer"]["model"]["model_id"],
pl["indexer"]["model"]["device"],
pl["postprocessor"][0]["model"]["model_id"] if pl["postprocessor"][0]["model"] is not None else "",
pl["postprocessor"][0]["model"]["device"] if pl["postprocessor"][0]["model"] is not None else "",
)
def modify_create_pipeline_button():
return "Create Pipeline"
def modify_update_pipeline_button():
return "Update Pipeline"
def create_update_pipeline(
name,
active,
node_parser,
chunk_size,
chunk_overlap,
indexer,
retriever,
vector_search_top_k,
postprocessor,
generator,
llm_infertype,
llm_id,
llm_device,
llm_weights,
embedding_id,
embedding_device,
rerank_id,
rerank_device,
):
res = cli.create_update_pipeline(
name,
active,
node_parser,
chunk_size,
chunk_overlap,
indexer,
retriever,
vector_search_top_k,
postprocessor,
generator,
llm_infertype,
llm_id,
llm_device,
llm_weights,
embedding_id,
embedding_device,
rerank_id,
rerank_device,
)
return res, get_pipeline_df()
# Events
u_llm_infertype.change(update_visibility, inputs=u_llm_infertype, outputs=accordion)
u_postprocessor.change(update_rerank_model, inputs=u_postprocessor, outputs=rerank_model)
u_pipelines.select(
show_pipeline_detail,
inputs=None,
outputs=[
u_pipeline_name,
u_active,
# node parser
u_node_parser,
u_chunk_size,
u_chunk_overlap,
# indexer
u_indexer,
# retriever
u_retriever,
u_vector_search_top_k,
# postprocessor
u_postprocessor,
# generator
u_generator,
u_llm_infertype,
# models
u_llm_model_id,
u_llm_device,
u_llm_weights,
u_embed_model_id,
u_embed_device,
u_rerank_model_id,
u_rerank_device,
],
)
u_pipeline_name.input(modify_create_pipeline_button, inputs=None, outputs=rag_create_pipeline)
# Create pipeline button will change to update pipeline button if any
# of the listed fields changed
gr.on(
triggers=[
u_active.input,
# node parser
u_node_parser.input,
u_chunk_size.input,
u_chunk_overlap.input,
# indexer
u_indexer.input,
# retriever
u_retriever.input,
u_vector_search_top_k.input,
# postprocessor
u_postprocessor.input,
# generator
u_generator.input,
# models
u_llm_model_id.input,
u_llm_device.input,
u_llm_weights.input,
u_llm_infertype.input,
u_embed_model_id.input,
u_embed_device.input,
u_rerank_model_id.input,
u_rerank_device.input,
],
fn=modify_update_pipeline_button,
inputs=None,
outputs=rag_create_pipeline,
)
rag_create_pipeline.click(
create_update_pipeline,
inputs=[
u_pipeline_name,
u_active,
u_node_parser,
u_chunk_size,
u_chunk_overlap,
u_indexer,
u_retriever,
u_vector_search_top_k,
u_postprocessor,
u_generator,
u_llm_infertype,
u_llm_model_id,
u_llm_device,
u_llm_weights,
u_embed_model_id,
u_embed_device,
u_rerank_model_id,
u_rerank_device,
],
outputs=[u_rag_pipeline_status, u_pipelines],
queue=False,
)
rag_activate_pipeline.click(
cli.activate_pipeline,
inputs=[u_pipeline_name],
outputs=[u_rag_pipeline_status, u_active],
queue=False,
)
rag_remove_pipeline.click(
cli.remove_pipeline,
inputs=[u_pipeline_name],
outputs=[u_rag_pipeline_status],
queue=False,
)
# --------------
# Chatbot Layout
# --------------
def get_files():
return cli.get_files()
def create_vectordb(docs, spliter):
res = cli.create_vectordb(docs, spliter)
return gr.update(value=get_files()), res, None
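        # Track the row currently selected in the file dataframe (None when nothing is selected)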
global u_files_selected_row
u_files_selected_row = None
def select_file(data, evt: gr.SelectData):
if not evt.selected or len(evt.index) == 0:
return "No file selected"
global u_files_selected_row
row_index = evt.index[0]
u_files_selected_row = data.iloc[row_index]
file_name, file_id = u_files_selected_row
return f"File Name: {file_name}\nFile ID: {file_id}"
def deselect_file():
global u_files_selected_row
u_files_selected_row = None
return gr.update(value=get_files()), "Selection cleared"
def delete_file():
global u_files_selected_row
if u_files_selected_row is None:
res = "Please select a file first."
else:
file_name, file_id = u_files_selected_row
u_files_selected_row = None
res = cli.delete_file(file_id)
return gr.update(value=get_files()), res
with gr.Tab("Chatbot"):
with gr.Row():
with gr.Column(scale=1):
docs = gr.File(
label="Step 1: Load text files",
file_count="multiple",
file_types=[
".csv",
".doc",
".docx",
".enex",
".epub",
".html",
".md",
".odt",
".pdf",
".ppt",
".pptx",
".txt",
],
)
retriever_argument = gr.Accordion("Vector Store Configuration", open=False)
with retriever_argument:
spliter = gr.Dropdown(
["Character", "RecursiveCharacter", "Markdown", "Chinese"],
value=cfg.splitter_name,
label="Text Spliter",
info="Method used to split the documents",
multiselect=False,
)
load_docs = gr.Button("Upload files")
u_files_status = gr.Textbox(label="File Processing Status", value="", interactive=False)
u_files = gr.Dataframe(
headers=["Loaded File Name", "File ID"],
value=get_files,
label="Loaded Files",
show_label=False,
interactive=False,
every=5,
)
with gr.Accordion("Delete File", open=False):
selected_files = gr.Textbox(label="Click file to select", value="", interactive=False)
with gr.Row():
with gr.Column():
delete_button = gr.Button("Delete Selected File")
with gr.Column():
deselect_button = gr.Button("Clear Selection")
with gr.Accordion("Generation Configuration", open=False):
with gr.Row():
with gr.Column():
with gr.Row():
temperature = gr.Slider(
label="Temperature",
value=0.1,
minimum=0.0,
maximum=1.0,
step=0.1,
interactive=True,
info="Higher values produce more diverse outputs",
)
with gr.Column():
with gr.Row():
top_p = gr.Slider(
label="Top-p (nucleus sampling)",
value=1.0,
minimum=0.0,
maximum=1,
step=0.01,
interactive=True,
info=(
"Sample from the smallest possible set of tokens whose cumulative probability "
"exceeds top_p. Set to 1 to disable and sample from all tokens."
),
)
with gr.Column():
with gr.Row():
top_k = gr.Slider(
label="Top-k",
value=50,
minimum=0.0,
maximum=200,
step=1,
interactive=True,
info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
)
with gr.Column():
with gr.Row():
repetition_penalty = gr.Slider(
label="Repetition Penalty",
value=1.1,
minimum=1.0,
maximum=2.0,
step=0.1,
interactive=True,
info="Penalize repetition — 1.0 to disable.",
)
with gr.Column():
with gr.Row():
u_max_tokens = gr.Slider(
label="Max Token Number",
value=512,
minimum=1,
maximum=8192,
step=10,
interactive=True,
info="Set Max Output Token",
)
with gr.Column(scale=4):
chatbot = gr.Chatbot(
height=600,
label="Step 2: Input Query",
show_copy_button=True,
)
with gr.Row():
benchmark = gr.Label(
show_label=False,
visible=False,
elem_classes="benchmark-wrap",
)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="QA Message Box",
placeholder="Chat Message Box",
show_label=False,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
clear = gr.Button("Clear")
retriever_argument = gr.Accordion("Retriever Configuration", open=False)
with retriever_argument:
with gr.Row():
with gr.Row():
vector_rerank_top_n = gr.Slider(
1,
10,
value=cfg.k_rerank,
step=1,
label="Rerank top n",
info="Number of rerank results",
interactive=True,
)
load_docs.click(
create_vectordb,
inputs=[
docs,
spliter,
],
outputs=[u_files, u_files_status, docs],
queue=True,
)
# TODO: Need to de-select the dataframe,
# otherwise every time the dataframe is updated, a select event is triggered
u_files.select(select_file, inputs=[u_files], outputs=selected_files, queue=True)
delete_button.click(
delete_file,
outputs=[u_files, u_files_status],
queue=True,
)
deselect_button.click(
deselect_file,
outputs=[u_files, selected_files],
queue=True,
)
submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot,
[
chatbot,
temperature,
top_p,
top_k,
repetition_penalty,
u_max_tokens,
docs,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
],
chatbot,
queue=True,
)
submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot,
[
chatbot,
temperature,
top_p,
top_k,
repetition_penalty,
u_max_tokens,
docs,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
],
chatbot,
queue=True,
)
clear.click(lambda: None, None, chatbot, queue=False)
chatbot.change(
get_benchmark,
inputs=None,
outputs=benchmark,
)
return app
def main():
# Create the parser
    parser = argparse.ArgumentParser(description="Load Embedding and LLM Models with OpenVINO.")
# Add the arguments
parser.add_argument("--prompt_template", type=str, required=False, help="User specific template")
parser.add_argument("--config", type=str, default="./default.yaml", help="configuration file path")
parser.add_argument("--share", action="store_true", help="share model")
parser.add_argument("--debug", action="store_true", help="enable debugging")
# Execute the parse_args() method to collect command line arguments
args = parser.parse_args()
logger.info(args)
cfg = OmegaConf.load(args.config)
init_cfg_(cfg)
logger.info(cfg)
app = build_app(cfg, args)
    # If you are launching remotely, specify server_name and server_port:
    #   app.launch(server_name='your server name', server_port='server port in int')
    # If you have any issue launching on your platform, you can pass share=True to
    # the launch method: it creates a publicly shareable link for the interface.
    # Read more in the docs: https://gradio.app/docs/
app.queue().launch(
server_name=UI_SERVICE_HOST_IP, server_port=UI_SERVICE_PORT, share=args.share, allowed_paths=["."]
)
    # Release server resources once launch() returns (e.g. on shutdown)
    app.close()
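# Fill in defaults for any settings missing from the YAML config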
def init_cfg_(cfg):
if "name" not in cfg:
cfg.name = "default"
if "embedding_device" not in cfg:
cfg.embedding_device = "CPU"
if "rerank_device" not in cfg:
cfg.rerank_device = "CPU"
if "llm_device" not in cfg:
cfg.llm_device = "CPU"
if "model_language" not in cfg:
cfg.model_language = "Chinese"
if "splitter_name" not in cfg:
cfg.splitter_name = "RecursiveCharacter" # or "Chinese"
if "search_method" not in cfg:
cfg.search_method = "similarity"
if "score_threshold" not in cfg:
cfg.score_threshold = 0.5
if "llm_weights" not in cfg:
cfg.llm_weights = "FP16"
if __name__ == "__main__":
main()