Remove deprecated folder. (#536)
Signed-off-by: zepan <ze.pan@intel.com>
@@ -1,272 +0,0 @@
# AudioQnA

In this example we will show you how to build an Audio Question and Answering application (AudioQnA). AudioQnA serves as a talking bot, enabling LLMs to talk with users: it accepts a user's audio input, converts it to text, feeds the text to an LLM, and converts the text answer back into audio output.

What AudioQnA delivers and why it stands out:

- Fast ASR/TTS inference as microservices on Intel Xeon CPUs with optimization
- Multilingual, zero-shot voice cloning across languages with customizable voices
- Fast LLM inference on Intel Gaudi through TGI, with RAG and other features supported

There are four folders under the current example.

- `front_end/`: the UI users interact with
- `serving/`: the TGI LLM service endpoint
- `langchain/`: pipelines the flow of text input -> RAG -> TGI LLM service -> text output
- `audio/`: pipelines the flow of audio-to-text service -> langchain -> text-to-audio service -> UI

## Start the Audio services

### Build ASR and TTS services

```shell
cd audio/docker

# Build ASR Docker service
docker build . --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} -f Dockerfile_asr -t intel/gen-ai-examples:audioqna-asr
# Build TTS Docker service
docker build . --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} -f Dockerfile_tts -t intel/gen-ai-examples:audioqna-tts
```

### Usage

```shell
# Start the ASR service
docker run -d -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 8008:8008 intel/gen-ai-examples:audioqna-asr

# Test ASR
wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav
http_proxy= curl -F 'file=@sample.wav' http://localhost:8008/v1/audio/transcriptions

# Start the TTS service
# Pre-download the local models and map them into the container
git clone https://huggingface.co/lj1995/GPT-SoVITS pretrained_tts_models
docker run -d -v ./pretrained_tts_models:/GPT-SoVITS/GPT_SoVITS/pretrained_models -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 9880:9880 intel/gen-ai-examples:audioqna-tts --default_refer_path /GPT-SoVITS/sample.wav --default_refer_text="Who is Pat Gelsinger?" --default_refer_language="en" --bf16 --return_text_stream

# Upload/change the reference audio
# http_proxy= curl --location 'localhost:9880/upload_as_default' \
#   --form 'default_refer_file=@"sample.wav"' \
#   --form 'default_refer_text="Who is Pat Gelsinger?"' \
#   --form 'default_refer_language="en"'

# Test TTS
http_proxy= curl --location 'localhost:9880/v1/audio/speech' \
  --header 'Content-Type: application/json' \
  --data '{
    "text": "You can have a look, but you should not touch this item.",
    "text_language": "en"
  }' \
  --output output.wav
```

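If you prefer to script these checks, the same ASR and TTS endpoints can be exercised from Python. The snippet below is a minimal sketch using `requests`; the host, ports, and payload fields mirror the curl commands above. Note that if the TTS container was started with `--return_text_stream` (as in the `docker run` above), the response body is a base64-encoded text event stream rather than raw wav bytes, so adjust the handling accordingly.

```python
import requests

ASR_URL = "http://localhost:8008/v1/audio/transcriptions"
TTS_URL = "http://localhost:9880/v1/audio/speech"

# Transcribe a local wav file with the ASR microservice
with open("sample.wav", "rb") as f:
    asr_resp = requests.post(ASR_URL, files={"file": ("sample.wav", f, "audio/wav")})
asr_resp.raise_for_status()
print("ASR result:", asr_resp.json())

# Synthesize speech with the TTS microservice and save the audio to a file
tts_payload = {
    "text": "You can have a look, but you should not touch this item.",
    "text_language": "en",
}
tts_resp = requests.post(TTS_URL, json=tts_payload)
tts_resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(tts_resp.content)
print("TTS audio written to output.wav")
```
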
## Prepare TGI Docker

Getting started is straightforward with the official Docker container. Simply pull the image using:

```bash
docker pull ghcr.io/huggingface/tgi-gaudi:1.2.1
```

Alternatively, you can build the Docker image yourself from the latest [TGI-Gaudi](https://github.com/huggingface/tgi-gaudi) code with the command below:

```bash
bash ./serving/tgi_gaudi/build_docker.sh
```

## Launch TGI Gaudi Service

### Launch a local server instance on 1 Gaudi card:

```bash
bash ./serving/tgi_gaudi/launch_tgi_service.sh
```

For gated models such as `LLAMA-2`, you will have to pass `-e HUGGING_FACE_HUB_TOKEN=<token>` to the `docker run` command above with a valid Hugging Face Hub read token.

Please follow this link [huggingface token](https://huggingface.co/docs/hub/security-tokens) to get an access token, and export the `HUGGINGFACEHUB_API_TOKEN` environment variable with the token:

```bash
export HUGGINGFACEHUB_API_TOKEN=<token>
```

### Launch a local server instance on 8 Gaudi cards:

```bash
bash ./serving/tgi_gaudi/launch_tgi_service.sh 8
```

And then you can make requests like below to check the service status:

```bash
curl 127.0.0.1:8080/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
  -H 'Content-Type: application/json'
```

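The same health check can be scripted from Python. This is a minimal sketch against TGI's `/generate` endpoint, with the address and payload taken from the curl command above.

```python
import requests

TGI_ENDPOINT = "http://127.0.0.1:8080"

payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {"max_new_tokens": 32},
}
# POST to TGI's /generate endpoint and print the generated text
resp = requests.post(f"{TGI_ENDPOINT}/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])
```
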
### Customize TGI Gaudi Service

The `./serving/tgi_gaudi/launch_tgi_service.sh` script accepts three parameters:

- num_cards: The number of Gaudi cards to be utilized, ranging from 1 to 8. The default is set to 1.
- port_number: The port number assigned to the TGI Gaudi endpoint, with the default being 8080.
- model_name: The model name utilized for LLM, with the default set to "Intel/neural-chat-7b-v3-3".

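For example, assuming the three parameters are positional in the order listed above, launching on 8 Gaudi cards on port 8080 with the default model would look like `bash ./serving/tgi_gaudi/launch_tgi_service.sh 8 8080 "Intel/neural-chat-7b-v3-3"`.
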
You can customize these parameters to suit your specific needs. Additionally, you can set the TGI Gaudi endpoint by exporting the environment variable `TGI_LLM_ENDPOINT`:

```bash
export TGI_LLM_ENDPOINT="http://xxx.xxx.xxx.xxx:8080"
```

## Enable TEI for embedding model

Text Embeddings Inference (TEI) is a toolkit designed for deploying and serving open-source text embeddings and sequence classification models efficiently. With TEI, users can extract high-performance features using various popular models. It supports token-based dynamic batching for enhanced performance.

To launch the TEI service, you can use the following commands:

```bash
model=BAAI/bge-large-en-v1.5
revision=refs/pr/5
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 9090:80 -v $volume:/data -e http_proxy=$http_proxy -e https_proxy=$https_proxy --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.2 --model-id $model --revision $revision
export TEI_ENDPOINT="http://xxx.xxx.xxx.xxx:9090"
```

And then you can make requests like below to check the service status:

```bash
curl 127.0.0.1:9090/embed \
  -X POST \
  -d '{"inputs":"What is Deep Learning?"}' \
  -H 'Content-Type: application/json'
```

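Equivalently, here is a minimal Python sketch that requests an embedding from the TEI `/embed` endpoint (address and payload taken from the curl command above); the response is a list of embedding vectors, one per input string.

```python
import requests

TEI_ENDPOINT = "http://127.0.0.1:9090"

# Request an embedding for a single input string
resp = requests.post(
    f"{TEI_ENDPOINT}/embed",
    json={"inputs": "What is Deep Learning?"},
    timeout=30,
)
resp.raise_for_status()
embeddings = resp.json()  # list of vectors, one per input
print(f"Got {len(embeddings)} embedding(s) of dimension {len(embeddings[0])}")
```
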
Note: If you want to integrate the TEI service into the LangChain application, you'll need to restart the LangChain backend service after launching the TEI service.

## Launch Redis and LangChain Backend Service

Update the `HUGGINGFACEHUB_API_TOKEN` environment variable with your Hugging Face token in `docker-compose.yml`, then bring up the services:

```bash
cd langchain/docker
docker compose -f docker-compose.yml up -d
cd ../../
```

> [!NOTE]
> If you have modified any files and want that change to be introduced in this step, add `--build` to the end of the command to build the container image instead of pulling it from Docker Hub.

## Ingest data into Redis (Optional)

Each time the Redis container is launched, data should be ingested into it using the following commands:

```bash
docker exec -it qna-rag-redis-server bash
cd /ws
python ingest.py
exit
```

Note: `ingest.py` will download the embedding model. Please set the proxy if necessary.

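Conceptually, the ingestion step embeds your documents and stores them in the Redis vector store that the backend queries at retrieval time. The sketch below illustrates the idea with the LangChain Redis vector store; it is not the actual `ingest.py`, and the example text, index name, Redis URL, and embedding model are illustrative assumptions.

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Redis

# Illustrative values only -- the real ingest.py defines its own documents and settings
texts = ["Example document text to index for retrieval."]
embedding = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# Embed the texts and index them in Redis for later retrieval by the RAG chain
Redis.from_texts(
    texts,
    embedding,
    redis_url="redis://localhost:6379",
    index_name="rag-redis",
)
```
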
# Start LangChain Server

## Enable GuardRails using Meta's Llama Guard model (Optional)

We offer content moderation support utilizing Meta's [Llama Guard](https://huggingface.co/meta-llama/LlamaGuard-7b) model. To activate GuardRails, follow the instructions below to deploy the Llama Guard model on TGI Gaudi.

```bash
volume=$PWD/data
model_id="meta-llama/LlamaGuard-7b"
docker run -p 8088:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HUGGING_FACE_HUB_TOKEN=<your HuggingFace token> -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$http_proxy tgi_gaudi --model-id $model_id
export SAFETY_GUARD_ENDPOINT="http://xxx.xxx.xxx.xxx:8088"
```

And then you can make requests like below to check the service status:

```bash
curl 127.0.0.1:8088/generate \
  -X POST \
  -d '{"inputs":"How do you buy a tiger in the US?","parameters":{"max_new_tokens":32}}' \
  -H 'Content-Type: application/json'
```

## Start the Backend Service

Make sure the TGI Gaudi service is running and that data has been populated into Redis, then launch the backend service:

```bash
docker exec -it qna-rag-redis-server bash
nohup python app/server.py &
```

The LangChain backend service listens on port 8000; you can customize it by changing the code in `docker/qna-app/app/server.py`.

And then you can make requests like below to check the LangChain backend service status:

```bash
# non-streaming endpoint
curl 127.0.0.1:8000/v1/rag/chat \
  -X POST \
  -d '{"query":"What is the total revenue of Nike in 2023?"}' \
  -H 'Content-Type: application/json'
```

```bash
# streaming endpoint
curl 127.0.0.1:8000/v1/rag/chat_stream \
  -X POST \
  -d '{"query":"What is the total revenue of Nike in 2023?"}' \
  -H 'Content-Type: application/json'
```

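From Python, the same two endpoints can be called with `requests`; this is a minimal sketch based on the curl commands above. The exact shape of the responses (JSON fields for `/chat`, streamed chunk format for `/chat_stream`) depends on `docker/qna-app/app/server.py`, so treat the response handling here as an assumption to adapt.

```python
import requests

BACKEND = "http://127.0.0.1:8000"
query = {"query": "What is the total revenue of Nike in 2023?"}

# Non-streaming endpoint: one request, one complete answer
resp = requests.post(f"{BACKEND}/v1/rag/chat", json=query, timeout=120)
resp.raise_for_status()
print(resp.text)

# Streaming endpoint: iterate over chunks as the answer is generated
with requests.post(f"{BACKEND}/v1/rag/chat_stream", json=query, stream=True, timeout=120) as stream:
    stream.raise_for_status()
    for line in stream.iter_lines(decode_unicode=True):
        if line:
            print(line)
```
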
## Start the Frontend Service

Please refer to the frontend [README](./front_end/README.md).

## Enable TGI Gaudi FP8 for higher throughput (Optional)

TGI Gaudi uses BFLOAT16 optimization as the default setting. If you aim to achieve higher throughput, you can enable FP8 quantization on TGI Gaudi. Note that currently only the Llama2 and Mistral series of models support FP8 quantization. Please follow the steps below to enable it.

### Prepare Metadata for FP8 Quantization

Enter the TGI Gaudi Docker container, then run the commands below:

```bash
pip install git+https://github.com/huggingface/optimum-habana.git
git clone https://github.com/huggingface/optimum-habana.git
cd optimum-habana/examples/text-generation
pip install -r requirements_lm_eval.txt
QUANT_CONFIG=./quantization_config/maxabs_measure.json python ../gaudi_spawn.py run_lm_eval.py -o acc_7b_bs1_measure.txt --model_name_or_path Intel/neural-chat-7b-v3-3 --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 1
QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py run_lm_eval.py -o acc_7b_bs1_quant.txt --model_name_or_path Intel/neural-chat-7b-v3-3 --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 1 --fp8
```

After finishing the above commands, the quantization metadata will be generated. Move the metadata directory `./hqt_output/` and copy the quantization JSON file to the host (under `…/data`). Please adapt the commands below to your container ID and directory path.

```bash
docker cp 262e04bbe466:/usr/src/optimum-habana/examples/text-generation/hqt_output data/
docker cp 262e04bbe466:/usr/src/optimum-habana/examples/text-generation/quantization_config/maxabs_quant.json data/
```

Then, in the `maxabs_quant.json` file, set `dump_stats_path` to "/data/hqt_output/measure" and `dump_stats_xlsx_path` to "/data/hqt_output/measure/fp8stats.xlsx".

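If you prefer to script this edit, a small Python sketch like the one below works, assuming the two keys sit at the top level of `maxabs_quant.json` (adjust the key lookup if they are nested).

```python
import json

CONFIG_PATH = "data/maxabs_quant.json"

with open(CONFIG_PATH) as f:
    cfg = json.load(f)

# Point the measurement stats paths at the mounted /data volume
cfg["dump_stats_path"] = "/data/hqt_output/measure"
cfg["dump_stats_xlsx_path"] = "/data/hqt_output/measure/fp8stats.xlsx"

with open(CONFIG_PATH, "w") as f:
    json.dump(cfg, f, indent=2)
```
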
### Restart the TGI Gaudi server with all the metadata mapped

```bash
docker run -p 8080:80 -e QUANT_CONFIG=/data/maxabs_quant.json -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id Intel/neural-chat-7b-v3-3
```

TGI Gaudi will now launch the FP8-quantized model by default, and you can make requests like below to check the service status:

```bash
curl 127.0.0.1:8080/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
  -H 'Content-Type: application/json'
```

SCRIPT USAGE NOTICE: By downloading and using any script file included with the associated software package (such as files with .bat, .cmd, or .JS extensions, Docker files, or any other type of file that, when executed, automatically downloads and/or installs files onto your system) (the “Script File”), it is your obligation to review the Script File to understand what files (e.g., other software, AI models, AI Datasets) the Script File will download to your system (“Downloaded Files”). Furthermore, by downloading and using the Downloaded Files, even if they are installed through a silent install, you agree to any and all terms and conditions associated with such files, including but not limited to, license terms, notices, or disclaimers.

@@ -1,15 +0,0 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
ENV LANG=C.UTF-8
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ffmpeg
|
||||
|
||||
COPY ./asr /asr
|
||||
RUN pip install --no-cache-dir -r /asr/requirements.txt
|
||||
|
||||
WORKDIR /asr
|
||||
|
||||
ENTRYPOINT ["python", "asr_server.py"]
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
FROM python:3.9-slim
|
||||
|
||||
ENV LANG=C.UTF-8
|
||||
ENV PYTHONPATH=/home/user:/GPT-SoVITS/GPT_SoVITS
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ffmpeg \
|
||||
&& apt-get install -y build-essential wget numactl git \
|
||||
&& apt-get install -y libomp-dev google-perftools
|
||||
|
||||
ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libiomp5.so:/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
|
||||
ENV MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000"
|
||||
ENV OMP_NUM_THREADS=56
|
||||
|
||||
|
||||
RUN git clone https://github.com/RVC-Boss/GPT-SoVITS.git /GPT-SoVITS -b main
|
||||
|
||||
RUN pip install --no-cache-dir -r /GPT-SoVITS/requirements.txt
|
||||
|
||||
COPY ./tts/tts_server.py /GPT-SoVITS/
|
||||
COPY ./tts/config.py /GPT-SoVITS/
|
||||
|
||||
# Download the sample ref wav
|
||||
RUN wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav -P /GPT-SoVITS
|
||||
RUN wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/welcome_cn.wav -P /GPT-SoVITS
|
||||
|
||||
|
||||
#RUN useradd -m -s /bin/bash user && \
|
||||
# mkdir -p /home/user && \
|
||||
# chown -R user /home/user/
|
||||
|
||||
#USER user
|
||||
|
||||
WORKDIR /GPT-SoVITS
|
||||
|
||||
ENTRYPOINT ["python", "tts_server.py"]
|
||||
@@ -1,124 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
import urllib.request
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio, Dataset
|
||||
from pydub import AudioSegment
|
||||
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
||||
|
||||
|
||||
class AudioSpeechRecognition:
|
||||
"""Convert audio to text."""
|
||||
|
||||
def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language="english", device="cpu"):
|
||||
if device == "hpu":
|
||||
# Explicitly link HPU with Torch
|
||||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
||||
|
||||
adapt_transformers_to_gaudi()
|
||||
|
||||
self.device = device
|
||||
asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
|
||||
print("Downloading model: {}".format(asr_model_name_or_path))
|
||||
self.model = WhisperForConditionalGeneration.from_pretrained(asr_model_name_or_path).to(self.device)
|
||||
self.processor = WhisperProcessor.from_pretrained(asr_model_name_or_path)
|
||||
self.model.eval()
|
||||
self.bf16 = bf16
|
||||
if self.bf16:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
self.model = ipex.optimize(self.model, dtype=torch.bfloat16)
|
||||
self.language = language
|
||||
|
||||
if device == "hpu":
|
||||
# do hpu graph warmup with a long enough input audio
|
||||
# whisper has a receptive field of 30 seconds
|
||||
# here we select a relatively long audio (~15 sec) to quickly warmup
|
||||
self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
|
||||
|
||||
def _audiosegment_to_librosawav(self, audiosegment):
|
||||
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
|
||||
# This way is faster than librosa.load or HuggingFace Dataset wrapper
|
||||
channel_sounds = audiosegment.split_to_mono()[:1] # only select the first channel
|
||||
samples = [s.get_array_of_samples() for s in channel_sounds]
|
||||
|
||||
fp_arr = np.array(samples).T.astype(np.float32)
|
||||
fp_arr /= np.iinfo(samples[0].typecode).max
|
||||
fp_arr = fp_arr.reshape(-1)
|
||||
|
||||
return fp_arr
|
||||
|
||||
def _warmup_whisper_hpu_graph(self, url):
|
||||
print("[ASR] fetch warmup audio...")
|
||||
urllib.request.urlretrieve(
|
||||
url,
|
||||
"warmup.wav",
|
||||
)
|
||||
print("[ASR] warmup...")
|
||||
waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
|
||||
waveform = self._audiosegment_to_librosawav(waveform)
|
||||
# pylint: disable=E1101
|
||||
inputs = self.processor.feature_extractor(
|
||||
waveform, return_tensors="pt", sampling_rate=16_000
|
||||
).input_features.to(self.device)
|
||||
_ = self.model.generate(inputs, language="chinese")
|
||||
|
||||
def audio2text(self, audio_path):
|
||||
"""Convert audio to text.
|
||||
|
||||
audio_path: the path to the input audio, e.g. ~/xxx.mp3
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
waveform = AudioSegment.from_file(audio_path).set_frame_rate(16000)
|
||||
waveform = self._audiosegment_to_librosawav(waveform)
|
||||
except Exception as e:
|
||||
print(f"[ASR] audiosegment to librosa wave fail: {e}")
|
||||
audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
|
||||
waveform = audio_dataset[0]["audio"]["array"]
|
||||
|
||||
# pylint: disable=E1101
|
||||
inputs = self.processor.feature_extractor(
|
||||
waveform, return_tensors="pt", sampling_rate=16_000
|
||||
).input_features.to(self.device)
|
||||
with torch.cpu.amp.autocast() if self.bf16 else contextlib.nullcontext():
|
||||
predicted_ids = self.model.generate(inputs, language=self.language)
|
||||
# pylint: disable=E1101
|
||||
result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
|
||||
if self.language in ["chinese", "mandarin"]:
|
||||
from zhconv import convert
|
||||
|
||||
result = convert(result, "zh-cn")
|
||||
print(f"generated text in {time.time() - start} seconds, and the result is: {result}")
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asr = AudioSpeechRecognition(language="english")
|
||||
|
||||
# Test multilanguage asr
|
||||
urllib.request.urlretrieve(
|
||||
"https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
|
||||
"sample.wav",
|
||||
)
|
||||
asr.language = "chinese"
|
||||
text = asr.audio2text("sample.wav")
|
||||
|
||||
urllib.request.urlretrieve(
|
||||
"https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
|
||||
"sample.wav",
|
||||
)
|
||||
text = asr.audio2text("sample.wav")
|
||||
|
||||
os.remove("sample.wav")
|
||||
@@ -1,69 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from asr import AudioSpeechRecognition
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydub import AudioSegment
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
asr = None
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
async def audio_to_text(file: UploadFile = File(...)):
|
||||
file_name = file.filename
|
||||
print(f"Received file: {file_name}")
|
||||
with open("tmp_audio_bytes", "wb") as fout:
|
||||
content = await file.read()
|
||||
fout.write(content)
|
||||
audio = AudioSegment.from_file("tmp_audio_bytes")
|
||||
audio = audio.set_frame_rate(16000)
|
||||
# bytes to wav
|
||||
file_name = file_name + ".wav"
|
||||
audio.export(f"{file_name}", format="wav")
|
||||
try:
|
||||
asr_result = asr.audio2text(file_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
        asr_result = str(e)
|
||||
finally:
|
||||
os.remove(file_name)
|
||||
os.remove("tmp_audio_bytes")
|
||||
return {"asr_result": asr_result}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=8008)
|
||||
parser.add_argument("--model_name_or_path", type=str, default="openai/whisper-tiny")
|
||||
parser.add_argument("--bf16", default=False, action="store_true")
|
||||
parser.add_argument("--language", type=str, default="english")
|
||||
parser.add_argument("--device", type=str, default="cpu")
|
||||
|
||||
args = parser.parse_args()
|
||||
asr = AudioSpeechRecognition(
|
||||
model_name_or_path=args.model_name_or_path, bf16=args.bf16, language=args.language, device=args.device
|
||||
)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
@@ -1,11 +0,0 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
datasets
|
||||
fastapi
|
||||
ffmpeg-python
|
||||
numpy
|
||||
pydub
|
||||
python-multipart
|
||||
torch==2.2.0
|
||||
transformers
|
||||
uvicorn
|
||||
zhconv
|
||||
@@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
#
|
||||
# This script is adapted from
|
||||
# https://github.com/RVC-Boss/GPT-SoVITS/blob/main/api.py
|
||||
# which is under the MIT license
|
||||
#
|
||||
# Copyright (c) 2024 RVC-Boss
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
sovits_path = ""
|
||||
gpt_path = ""
|
||||
is_half_str = os.environ.get("is_half", "True")
|
||||
is_half = True if is_half_str.lower() == "true" else False
|
||||
is_share_str = os.environ.get("is_share", "False")
|
||||
is_share = True if is_share_str.lower() == "true" else False
|
||||
|
||||
cnhubert_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||
pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||
|
||||
exp_root = "logs"
|
||||
python_exec = sys.executable or "python"
|
||||
if torch.cuda.is_available():
|
||||
infer_device = "cuda"
|
||||
else:
|
||||
infer_device = "cpu"
|
||||
|
||||
webui_port_main = 9874
|
||||
webui_port_uvr5 = 9873
|
||||
webui_port_infer_tts = 9872
|
||||
webui_port_subfix = 9871
|
||||
|
||||
api_port = 9880
|
||||
|
||||
if infer_device == "cuda":
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
if (
|
||||
("16" in gpu_name and "V100" not in gpu_name.upper())
|
||||
or "P40" in gpu_name.upper()
|
||||
or "P10" in gpu_name.upper()
|
||||
or "1060" in gpu_name
|
||||
or "1070" in gpu_name
|
||||
or "1080" in gpu_name
|
||||
):
|
||||
is_half = False
|
||||
|
||||
if infer_device == "cpu":
|
||||
is_half = False
|
||||
use_bf16 = False
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
self.sovits_path = sovits_path
|
||||
self.gpt_path = gpt_path
|
||||
self.is_half = is_half
|
||||
self.use_bf16 = use_bf16
|
||||
|
||||
self.cnhubert_path = cnhubert_path
|
||||
self.bert_path = bert_path
|
||||
self.pretrained_sovits_path = pretrained_sovits_path
|
||||
self.pretrained_gpt_path = pretrained_gpt_path
|
||||
|
||||
self.exp_root = exp_root
|
||||
self.python_exec = python_exec
|
||||
self.infer_device = infer_device
|
||||
|
||||
self.webui_port_main = webui_port_main
|
||||
self.webui_port_uvr5 = webui_port_uvr5
|
||||
self.webui_port_infer_tts = webui_port_infer_tts
|
||||
self.webui_port_subfix = webui_port_subfix
|
||||
|
||||
self.api_port = api_port
|
||||
@@ -1,28 +0,0 @@
|
||||
chardet
|
||||
# funasr==1.0.0
|
||||
cn2an
|
||||
# gradio==3.38.0
|
||||
# gradio_client==0.8.1
|
||||
ffmpeg-python
|
||||
g2p_en
|
||||
jieba
|
||||
jieba_fast
|
||||
LangSegment>=0.2.0
|
||||
# tensorboard
|
||||
librosa==0.9.2
|
||||
numba==0.56.4
|
||||
numpy
|
||||
psutil
|
||||
pyopenjtalk
|
||||
pypinyin
|
||||
pytorch-lightning
|
||||
PyYAML
|
||||
scipy
|
||||
# modelscope==1.10.0
|
||||
sentencepiece
|
||||
torchaudio
|
||||
# onnxruntime
|
||||
tqdm
|
||||
transformers
|
||||
# Faster_Whisper
|
||||
wordsegment
|
||||
@@ -1,741 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
#
|
||||
# This script is adapted from
|
||||
# https://github.com/RVC-Boss/GPT-SoVITS/blob/main/api.py
|
||||
# which is under the MIT license
|
||||
#
|
||||
# Copyright (c) 2024 RVC-Boss
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import contextlib
|
||||
import logging.config
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from time import time as ttime
|
||||
|
||||
import config as global_config
|
||||
import LangSegment
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import uvicorn
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from feature_extractor import cnhubert
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from module.models import SynthesizerTrn
|
||||
from my_utils import load_audio
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import clean_text
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
|
||||
class DefaultRefer:
|
||||
def __init__(self, path, text, language):
|
||||
        self.path = path
        self.text = text
        self.language = language
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return is_full(self.path, self.text, self.language)
|
||||
|
||||
|
||||
def is_empty(*items):
|
||||
for item in items:
|
||||
if item is not None and item != "":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_full(*items):
|
||||
for item in items:
|
||||
if item is None or item == "":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path):
|
||||
global vq_model, hps
|
||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||
hps = dict_s2["config"]
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
)
|
||||
if "pretrained" not in sovits_path:
|
||||
del vq_model.enc_q
|
||||
if is_half:
|
||||
vq_model = vq_model.half().to(device)
|
||||
else:
|
||||
vq_model = vq_model.to(device)
|
||||
vq_model.eval()
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
global hz, max_sec, t2s_model, config
|
||||
hz = 50
|
||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||
config = dict_s1["config"]
|
||||
max_sec = config["data"]["max_sec"]
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
if is_half:
|
||||
t2s_model = t2s_model.half()
|
||||
t2s_model = t2s_model.to(device)
|
||||
t2s_model.eval()
|
||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
||||
|
||||
|
||||
def get_bert_feature(text, word2ph):
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(device)
|
||||
res = bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
repeat_feature = res[i].repeat(word2ph[i], 1)
|
||||
phone_level_feature.append(repeat_feature)
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
return phone_level_feature.T
|
||||
|
||||
|
||||
def clean_text_inf(text, language):
|
||||
phones, word2ph, norm_text = clean_text(text, language)
|
||||
phones = cleaned_text_to_sequence(phones)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
else:
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half else torch.float32,
|
||||
).to(device)
|
||||
|
||||
return bert
|
||||
|
||||
|
||||
def get_phones_and_bert(text, language):
|
||||
if language in {"en", "all_zh", "all_ja"}:
|
||||
language = language.replace("all_", "")
|
||||
if language == "en":
|
||||
LangSegment.setfilters(["en"])
|
||||
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||
else:
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||
if language == "zh":
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
else:
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half else torch.float32,
|
||||
).to(device)
|
||||
elif language in {"zh", "ja", "auto"}:
|
||||
textlist = []
|
||||
langlist = []
|
||||
LangSegment.setfilters(["zh", "ja", "en", "ko"])
|
||||
if language == "auto":
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
if tmp["lang"] == "ko":
|
||||
langlist.append("zh")
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
return phones, bert.to(torch.float16 if is_half else torch.float32), norm_text
|
||||
|
||||
|
||||
class DictToAttrRecursive:
|
||||
def __init__(self, input_dict):
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
setattr(self, key, DictToAttrRecursive(value))
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
audio = load_audio(filename, int(hps.data.sampling_rate))
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
hps.data.filter_length,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
return spec
|
||||
|
||||
|
||||
def pack_audio(audio_bytes, data, rate):
|
||||
if media_type == "ogg":
|
||||
audio_bytes = pack_ogg(audio_bytes, data, rate)
|
||||
elif media_type == "aac":
|
||||
audio_bytes = pack_aac(audio_bytes, data, rate)
|
||||
else:
|
||||
audio_bytes = pack_raw(audio_bytes, data, rate)
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
def pack_ogg(audio_bytes, data, rate):
|
||||
with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
def pack_raw(audio_bytes, data, rate):
|
||||
audio_bytes.write(data.tobytes())
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
def pack_wav(audio_bytes, rate):
|
||||
data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16)
|
||||
wav_bytes = BytesIO()
|
||||
sf.write(wav_bytes, data, rate, format="wav")
|
||||
|
||||
return wav_bytes
|
||||
|
||||
|
||||
def pack_aac(audio_bytes, data, rate):
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ar",
|
||||
str(rate),
|
||||
"-ac",
|
||||
"1",
|
||||
"-i",
|
||||
"pipe:0",
|
||||
"-c:a",
|
||||
"aac",
|
||||
"-b:a",
|
||||
"192k",
|
||||
"-vn",
|
||||
"-f",
|
||||
"adts",
|
||||
"pipe:1",
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
audio_bytes.write(out)
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
def read_clean_buffer(audio_bytes):
|
||||
audio_chunk = audio_bytes.getvalue()
|
||||
audio_bytes.truncate(0)
|
||||
audio_bytes.seek(0)
|
||||
|
||||
return audio_bytes, audio_chunk
|
||||
|
||||
|
||||
def cut_text(text, punc):
|
||||
text = re.escape(text)
|
||||
punc_list = [",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"]
|
||||
if len(punc_list) > 0:
|
||||
punds = r"[" + "".join(punc_list) + r"]"
|
||||
text = text.strip("\n")
|
||||
items = re.split(f"({punds})", text)
|
||||
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
|
||||
if len(items) % 2 == 1:
|
||||
mergeitems.append(items[-1])
|
||||
text = "\n".join(mergeitems)
|
||||
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def only_punc(text):
|
||||
    return not any(t.isalnum() for t in text)
|
||||
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
prompt_language, text = prompt_language, text.strip("\n")
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if is_half:
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
codes = vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
t1 = ttime()
|
||||
prompt_language = dict_language[prompt_language.lower()]
|
||||
text_language = dict_language[text_language.lower()]
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
|
||||
texts = text.split("\n")
|
||||
audio_bytes = BytesIO()
|
||||
|
||||
for text in texts:
|
||||
if only_punc(text):
|
||||
continue
|
||||
|
||||
audio_opt = []
|
||||
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
# import intel_extension_for_pytorch as ipex
|
||||
# ipex.optimize(t2s_model.model)
|
||||
# from torch import profiler
|
||||
t2 = ttime()
|
||||
with torch.no_grad():
|
||||
# with profiler.profile(record_shapes=True) as prof:
|
||||
# with profiler.record_function("model_inference"):
|
||||
with (
|
||||
torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True)
|
||||
if use_bf16
|
||||
else contextlib.nullcontext()
|
||||
):
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
prompt,
|
||||
bert,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=config["inference"]["top_k"],
|
||||
early_stop_num=hz * max_sec,
|
||||
)
|
||||
# print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
|
||||
t3 = ttime()
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
refer = get_spepc(hps, ref_wav_path)
|
||||
if is_half:
|
||||
refer = refer.half().to(device)
|
||||
else:
|
||||
refer = refer.to(device)
|
||||
audio = (
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
)
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
t4 = ttime()
|
||||
audio_bytes = pack_audio(
|
||||
audio_bytes, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate
|
||||
)
|
||||
logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
if stream_mode == "normal":
|
||||
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
||||
yield audio_chunk
|
||||
|
||||
if not stream_mode == "normal":
|
||||
if media_type == "wav":
|
||||
audio_bytes = pack_wav(audio_bytes, hps.data.sampling_rate)
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
|
||||
def handle_control(command):
|
||||
if command == "restart":
|
||||
os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
|
||||
elif command == "exit":
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
||||
|
||||
|
||||
def handle_change(path, text, language):
|
||||
if is_empty(path, text, language):
|
||||
return JSONResponse(
|
||||
{"code": 400, "message": 'missing any of the following parameters: "path", "text", "language"'},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if path != "" or path is not None:
|
||||
default_refer.path = path
|
||||
if text != "" or text is not None:
|
||||
default_refer.text = text
|
||||
if language != "" or language is not None:
|
||||
default_refer.language = language
|
||||
|
||||
logger.info(f"current default reference audio path: {default_refer.path}")
|
||||
logger.info(f"current default reference audio text: {default_refer.text}")
|
||||
logger.info(f"current default reference audio language: {default_refer.language}")
|
||||
logger.info(f"is_ready: {default_refer.is_ready()}")
|
||||
|
||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||
|
||||
|
||||
def text_stream_generator(result):
|
||||
"""Embed the unicode byte values to base64 and yield the text stream with data prefix.
|
||||
|
||||
Accepts a generator of bytes
|
||||
Returns a generator of string
|
||||
"""
|
||||
for bytes in result:
|
||||
data = base64.b64encode(bytes)
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc):
|
||||
if (
|
||||
refer_wav_path == ""
|
||||
or refer_wav_path is None
|
||||
or prompt_text == ""
|
||||
or prompt_text is None
|
||||
or prompt_language == ""
|
||||
or prompt_language is None
|
||||
):
|
||||
refer_wav_path, prompt_text, prompt_language = (
|
||||
default_refer.path,
|
||||
default_refer.text,
|
||||
default_refer.language,
|
||||
)
|
||||
if not default_refer.is_ready():
|
||||
return JSONResponse({"code": 400, "message": "unspecified refer audio!"}, status_code=400)
|
||||
|
||||
if cut_punc is None:
|
||||
text = cut_text(text, default_cut_punc)
|
||||
else:
|
||||
text = cut_text(text, cut_punc)
|
||||
|
||||
if not return_text_stream:
|
||||
return StreamingResponse(
|
||||
get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language),
|
||||
media_type="audio/" + media_type,
|
||||
)
|
||||
else:
|
||||
result = get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
|
||||
return StreamingResponse(text_stream_generator(result), media_type="text/event-stream")
|
||||
|
||||
|
||||
# --------------------------------
|
||||
# Initialization part
|
||||
# --------------------------------
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
dict_language = {
|
||||
"中文": "all_zh",
|
||||
"英文": "en",
|
||||
"日文": "all_ja",
|
||||
"中英混合": "zh",
|
||||
"日英混合": "ja",
|
||||
"多语种混合": "auto",
|
||||
"all_zh": "all_zh",
|
||||
"en": "en",
|
||||
"all_ja": "all_ja",
|
||||
"zh": "zh",
|
||||
"ja": "ja",
|
||||
"auto": "auto",
|
||||
}
|
||||
|
||||
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
g_config = global_config.Config()
|
||||
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||
|
||||
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS model path")
|
||||
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT model path")
|
||||
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="default reference audio path")
|
||||
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="default reference audio text")
|
||||
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="default reference audio language")
|
||||
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
|
||||
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
|
||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||
parser.add_argument(
|
||||
"-fp", "--full_precision", action="store_true", default=False, help="overwrite config.is_half, use fp32"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-hp", "--half_precision", action="store_true", default=False, help="overwrite config.is_half, use fp16"
|
||||
)
|
||||
# Here add an argument for specifying torch.bfloat16 inference on Xeon CPU
|
||||
parser.add_argument("-bf16", "--bf16", action="store_true", default=False, help="use bfloat16")
|
||||
parser.add_argument(
|
||||
"-sm", "--stream_mode", type=str, default="close", help="streaming response, close / normal / keepalive"
|
||||
)
|
||||
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="media type, wav / ogg / aac")
|
||||
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="text splitter, among ,.;?!、,。?!;:…")
|
||||
parser.add_argument(
|
||||
"-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="overwrite config.cnhubert_path"
|
||||
)
|
||||
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="overwrite config.bert_path")
|
||||
# Here add an argument to decide whether to return text/event-stream base64 encoded bytes to frontend
|
||||
# rather than audio bytes
|
||||
parser.add_argument(
|
||||
"-rts",
|
||||
"--return_text_stream",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="whether to return text/event-stream base64 encoded bytes to frontend",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
sovits_path = args.sovits_path
|
||||
gpt_path = args.gpt_path
|
||||
device = args.device
|
||||
port = args.port
|
||||
host = args.bind_addr
|
||||
cnhubert_base_path = args.hubert_path
|
||||
bert_path = args.bert_path
|
||||
default_cut_punc = args.cut_punc
|
||||
return_text_stream = args.return_text_stream
|
||||
|
||||
# Set default reference configuration
|
||||
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
|
||||
|
||||
# Check model paths
|
||||
if sovits_path == "":
|
||||
sovits_path = g_config.pretrained_sovits_path
|
||||
logger.warn(f"Unspecified SOVITS model path, fallback to current path: {sovits_path}")
|
||||
if gpt_path == "":
|
||||
gpt_path = g_config.pretrained_gpt_path
|
||||
logger.warn(f"Unspecified GPT model path, fallback to current path: {gpt_path}")
|
||||
|
||||
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
||||
default_refer.path, default_refer.text, default_refer.language = "", "", ""
|
||||
logger.info("Unspecified default refer audio")
|
||||
else:
|
||||
logger.info(f"default refer audio path: {default_refer.path}")
|
||||
logger.info(f"default refer audio text: {default_refer.text}")
|
||||
logger.info(f"default refer audio language: {default_refer.language}")
|
||||
|
||||
# deal with half precision
|
||||
if device == "cuda":
|
||||
is_half = g_config.is_half
|
||||
use_bf16 = False
|
||||
if args.full_precision:
|
||||
is_half = False
|
||||
if args.half_precision:
|
||||
is_half = True
|
||||
if args.full_precision and args.half_precision:
|
||||
is_half = g_config.is_half # fallback to fp32
|
||||
logger.info(f"fp16 half: {is_half}")
|
||||
else:
|
||||
is_half = False
|
||||
use_bf16 = g_config.use_bf16
|
||||
if args.full_precision:
|
||||
use_bf16 = False
|
||||
elif args.bf16:
|
||||
use_bf16 = True
|
||||
|
||||
logger.info(f"bf16 half: {use_bf16}")
|
||||
|
||||
# stream response mode
|
||||
if args.stream_mode.lower() in ["normal", "n"]:
|
||||
stream_mode = "normal"
|
||||
logger.info("stream response mode enabled")
|
||||
else:
|
||||
stream_mode = "close"
|
||||
|
||||
# media type
|
||||
if args.media_type.lower() in ["aac", "ogg"]:
|
||||
media_type = args.media_type.lower()
|
||||
elif stream_mode == "close":
|
||||
media_type = "wav"
|
||||
else:
|
||||
media_type = "ogg"
|
||||
logger.info(f"media type: {media_type}")
|
||||
|
||||
# Initialize the model
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||
ssl_model = cnhubert.get_model()
|
||||
if is_half:
|
||||
bert_model = bert_model.half().to(device)
|
||||
ssl_model = ssl_model.half().to(device)
|
||||
else:
|
||||
bert_model = bert_model.to(device)
|
||||
ssl_model = ssl_model.to(device)
|
||||
change_sovits_weights(sovits_path)
|
||||
change_gpt_weights(gpt_path)
|
||||
|
||||
|
||||
# --------------------------------
|
||||
# APIs
|
||||
# --------------------------------
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
|
||||
)
|
||||
|
||||
|
||||
@app.post("/set_model")
|
||||
async def set_model(request: Request):
|
||||
json_post_raw = await request.json()
|
||||
global gpt_path
|
||||
gpt_path = json_post_raw.get("gpt_model_path")
|
||||
global sovits_path
|
||||
sovits_path = json_post_raw.get("sovits_model_path")
|
||||
logger.info("gptpath" + gpt_path + ";vitspath" + sovits_path)
|
||||
change_sovits_weights(sovits_path)
|
||||
change_gpt_weights(gpt_path)
|
||||
return "ok"
|
||||
|
||||
|
||||
@app.post("/control")
|
||||
async def control_req(request: Request):
|
||||
json_post_raw = await request.json()
|
||||
return handle_control(json_post_raw.get("command"))
|
||||
|
||||
|
||||
@app.get("/control")
|
||||
async def control(command: str = None):
|
||||
return handle_control(command)
|
||||
|
||||
|
||||
@app.post("/change_refer")
|
||||
async def change_refer_req(request: Request):
|
||||
json_post_raw = await request.json()
|
||||
return handle_change(
|
||||
json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language")
|
||||
)
|
||||
|
||||
|
||||
@app.get("/change_refer")
|
||||
async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None):
|
||||
return handle_change(refer_wav_path, prompt_text, prompt_language)
|
||||
|
||||
|
||||
@app.post("/v1/audio/speech")
|
||||
async def tts_endpoint_req(request: Request):
|
||||
json_post_raw = await request.json()
|
||||
return handle(
|
||||
json_post_raw.get("refer_wav_path"),
|
||||
json_post_raw.get("prompt_text"),
|
||||
json_post_raw.get("prompt_language"),
|
||||
json_post_raw.get("text"),
|
||||
json_post_raw.get("text_language"),
|
||||
json_post_raw.get("cut_punc"),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/audio/speech")
|
||||
async def tts_endpoint(
|
||||
refer_wav_path: str = None,
|
||||
prompt_text: str = None,
|
||||
prompt_language: str = None,
|
||||
text: str = None,
|
||||
text_language: str = None,
|
||||
cut_punc: str = None,
|
||||
):
|
||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
|
||||
|
||||
|
||||
@app.post("/upload_as_default")
|
||||
async def upload_audio(
|
||||
default_refer_file: UploadFile = File(...),
|
||||
default_refer_text: str = Form(...),
|
||||
default_refer_language: str = Form(...),
|
||||
):
|
||||
    if not default_refer_file or not default_refer_text or not default_refer_language:
|
||||
return JSONResponse(
|
||||
{"code": 400, "message": "reference audio, text and language must be provided!"}, status_code=400
|
||||
)
|
||||
name = default_refer_file.filename
|
||||
|
||||
if name.endswith(".mp3") or name.endswith(".wav"):
|
||||
# temp file location
|
||||
tmp_file_location = f"/tmp/{name}"
|
||||
with open(tmp_file_location, "wb+") as f:
|
||||
f.write(default_refer_file.file.read())
|
||||
logger.info(f"reference audio saved at {tmp_file_location}!")
|
||||
return handle_change(path=tmp_file_location, text=default_refer_text, language=default_refer_language)
|
||||
else:
|
||||
return JSONResponse({"code": 400, "message": "audio name invalid!"}, status_code=400)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host=host, port=port, workers=1)
|
||||
@@ -1,38 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# SCRIPT USAGE NOTICE: By downloading and using any script file included
|
||||
# with the associated software package (such as files with .bat, .cmd, or
|
||||
# .JS extensions, Docker files, or any other type of file that, when executed,
|
||||
# automatically downloads and/or installs files onto your system) (the “Script File”),
|
||||
# it is your obligation to review the Script File to understand what files (e.g.,
|
||||
# other software, AI models, AI Datasets) the Script File will download to your system
|
||||
# (“Downloaded Files”). Furthermore, by downloading and using the Downloaded Files,
|
||||
# even if they are installed through a silent install, you agree to any and all
|
||||
# terms and conditions associated with such files, including but not limited to,
|
||||
# license terms, notices, or disclaimers.
|
||||
|
||||
FROM langchain/langchain:latest
|
||||
|
||||
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
|
||||
libgl1-mesa-glx \
|
||||
libjemalloc-dev
|
||||
|
||||
# RUN useradd -m -s /bin/bash user && \
|
||||
# mkdir -p /home/user && \
|
||||
# chown -R user /home/user/
|
||||
|
||||
# USER user
|
||||
|
||||
COPY requirements.txt /tmp/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir -r /tmp/requirements.txt
|
||||
|
||||
ENV PYTHONPATH=$PYTHONPATH:/ws:/home/user:/home/user/qna-app/app
|
||||
|
||||
WORKDIR /home/user/qna-app
|
||||
COPY qna-app /home/user/qna-app
|
||||
|
||||
ENTRYPOINT ["/usr/bin/sleep", "infinity"]
|
||||
@@ -1,32 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
services:
|
||||
redis-vector-db:
|
||||
image: redis/redis-stack:7.2.0-v9
|
||||
container_name: redis-vector-db
|
||||
ports:
|
||||
- "6379:6379"
|
||||
- "8001:8001"
|
||||
qna-rag-redis-server:
|
||||
build:
|
||||
args:
|
||||
https_proxy: ${https_proxy}
|
||||
dockerfile: Dockerfile
|
||||
image: intel/gen-ai-examples:qna-rag-redis-server
|
||||
container_name: qna-rag-redis-server
|
||||
environment:
|
||||
- https_proxy=${https_proxy}
|
||||
- HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
|
||||
- "REDIS_PORT=6379"
|
||||
- "EMBED_MODEL=BAAI/bge-base-en-v1.5"
|
||||
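# Note: the schema's vector dims must match the embedding size of EMBED_MODEL (BAAI/bge-base-en-v1.5 produces 768-dim vectors).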
- "REDIS_SCHEMA=schema_dim_768.yml"
|
||||
ulimits:
|
||||
memlock:
|
||||
soft: -1 # Set memlock to unlimited (no soft or hard limit)
|
||||
hard: -1
|
||||
volumes:
|
||||
- ../redis:/ws
|
||||
- ../test:/test
|
||||
network_mode: "host"
|
||||
@@ -1,25 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
RUN pip install --no-cache-dir poetry==1.6.1
|
||||
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
COPY ./pyproject.toml ./README.md ./poetry.lock* ./
|
||||
|
||||
COPY ./package[s] ./packages
|
||||
|
||||
RUN poetry install --no-interaction --no-ansi --no-root
|
||||
|
||||
COPY ./app ./app
|
||||
|
||||
RUN poetry install --no-interaction --no-ansi
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["uvicorn", "app.server:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
@@ -1,79 +0,0 @@
|
||||
# my-app
|
||||
|
||||
## Installation
|
||||
|
||||
Install the LangChain CLI if you haven't yet:
|
||||
|
||||
```bash
|
||||
pip install -U langchain-cli
|
||||
```
|
||||
|
||||
## Adding packages
|
||||
|
||||
```bash
|
||||
# adding packages from
|
||||
# https://github.com/langchain-ai/langchain/tree/master/templates
|
||||
langchain app add $PROJECT_NAME
|
||||
|
||||
# adding custom GitHub repo packages
|
||||
langchain app add --repo $OWNER/$REPO
|
||||
# or with whole git string (supports other git providers):
|
||||
# langchain app add git+https://github.com/hwchase17/chain-of-verification
|
||||
|
||||
# with a custom api mount point (defaults to `/{package_name}`)
|
||||
langchain app add $PROJECT_NAME --api_path=/my/custom/path/rag
|
||||
```
|
||||
|
||||
Note: you remove packages by their API path:
|
||||
|
||||
```bash
|
||||
langchain app remove my/custom/path/rag
|
||||
```
|
||||
|
||||
## Setup LangSmith (Optional)
|
||||
|
||||
LangSmith will help us trace, monitor and debug LangChain applications.
|
||||
LangSmith is currently in private beta; you can sign up [here](https://smith.langchain.com/).
|
||||
If you don't have access, you can skip this section.
|
||||
|
||||
```shell
|
||||
export LANGCHAIN_TRACING_V2=true
|
||||
export LANGCHAIN_API_KEY=<your-api-key>
|
||||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
|
||||
```
|
||||
|
||||
## Launch LangServe
|
||||
|
||||
```bash
|
||||
langchain serve
|
||||
```
|
||||
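Once the server is running (port 8000 by default), you can call the mounted chain from Python. Below is a minimal sketch, assuming a package mounted at `/rag-redis` as in this example; adjust the path to whatever `--api_path` you used.

```python
# Query a LangServe-mounted chain. The URL path is an assumption and should
# match the api_path the package was added under.
from langserve.client import RemoteRunnable

rag_redis = RemoteRunnable("http://localhost:8000/rag-redis")
print(rag_redis.invoke("What was Nike's revenue in 2023?"))
```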
|
||||
## Running in Docker
|
||||
|
||||
This project folder includes a Dockerfile that allows you to easily build and host your LangServe app.
|
||||
|
||||
### Building the Image
|
||||
|
||||
To build the image, simply run:
|
||||
|
||||
```shell
|
||||
docker build . -t my-langserve-app
|
||||
```
|
||||
|
||||
If you tag your image with something other than `my-langserve-app`,
|
||||
note it for use in the next step.
|
||||
|
||||
### Running the Image Locally
|
||||
|
||||
To run the image, you'll need to include any environment variables
|
||||
necessary for your application.
|
||||
|
||||
In the below example, we inject the `OPENAI_API_KEY` environment
|
||||
variable with the value set in your local environment
|
||||
(`$OPENAI_API_KEY`).
|
||||
|
||||
We also expose port 8080 with the `-p 8080:8080` option.
|
||||
|
||||
```shell
|
||||
docker run -e OPENAI_API_KEY=$OPENAI_API_KEY -p 8080:8080 my-langserve-app
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
@@ -1,90 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
unsafe_categories = """O1: Violence and Hate.
|
||||
Should not
|
||||
- Help people plan or engage in violence.
|
||||
- Encourage people to engage in violence.
|
||||
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
|
||||
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
|
||||
Can
|
||||
- Provide information on violence and discrimination.
|
||||
- Discuss the topics of hate and violence, and can discuss historical events involving violence.
|
||||
O2: Sexual Content.
|
||||
Should not
|
||||
- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
|
||||
Can
|
||||
- Discuss topics related to sexuality and sexual education.
|
||||
- Provide general information regarding sexual activities.
|
||||
- Provide guidance on topics related to sexual health.
|
||||
O3: Criminal Planning.
|
||||
Should not
|
||||
- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
|
||||
- Encourage people to engage in criminal activities or condone criminal behavior.
|
||||
Can
|
||||
- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
|
||||
- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
|
||||
O4: Guns and Illegal Weapons.
|
||||
Should not
|
||||
- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
|
||||
- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
|
||||
Can
|
||||
- Discuss firearms and the arguments for and against firearm ownership.
|
||||
O5: Regulated or Controlled Substances.
|
||||
Should not
|
||||
- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
|
||||
- Assist or encourage people to create such substances when it is illegal to do so.
|
||||
Can
|
||||
- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
|
||||
- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
|
||||
O6: Self-Harm.
|
||||
Should not
|
||||
- Encourage people to harm themselves.
|
||||
- Romanticize or condone self-harm or suicide.
|
||||
- Provide information regarding the methods of suicide or self-harm.
|
||||
- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
|
||||
Can
|
||||
- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
|
||||
Should
|
||||
- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""
|
||||
|
||||
unsafe_dict = {
|
||||
"O1": "Violence and Hate",
|
||||
"O2": "Sexual Content",
|
||||
"O3": "Criminal Planning",
|
||||
"O4": "Guns and Illegal Weapons",
|
||||
"O5": "Regulated or Controlled Substances",
|
||||
"O6": "Self-Harm",
|
||||
}
|
||||
|
||||
|
||||
def moderation_prompt_for_chat(chat):
|
||||
# For simplicity, we assume the chat dictionary correctly alternates "user" and "assistant" roles
|
||||
# Feel free to include error checking if needed
|
||||
conversation = [turn["content"] for turn in chat]
|
||||
is_agent = len(conversation) % 2 == 0
|
||||
role = "Agent" if len(conversation) % 2 == 0 else "User"
|
||||
prompt = f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories.
|
||||
|
||||
<BEGIN UNSAFE CONTENT CATEGORIES>
|
||||
{unsafe_categories}
|
||||
<END UNSAFE CONTENT CATEGORIES>
|
||||
|
||||
<BEGIN CONVERSATION>
|
||||
|
||||
"""
|
||||
|
||||
# Alternate User/Agent turns, inserting 2 newlines between each
|
||||
for i, m in enumerate(conversation):
|
||||
role = "User" if i % 2 == 0 else "Agent"
|
||||
prompt += f"{role}: {m}\n\n"
|
||||
prompt += f"""<END CONVERSATION>
|
||||
|
||||
Provide your safety assessment for {role} in the above conversation:
|
||||
- First line must read 'safe' or 'unsafe'.
|
||||
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
|
||||
return prompt
|
||||
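# --- Hypothetical usage sketch (illustration only, not part of the original module) ---
# Shows how the moderation prompt might be combined with a guard LLM and how its
# verdict can be parsed. `llm_guard` is assumed to be any callable returning the
# raw completion text (e.g. the HuggingFaceEndpoint configured in the RAG router).
def moderate_user_query(llm_guard, query):
    prompt = moderation_prompt_for_chat([{"role": "user", "content": query}])
    verdict = llm_guard(prompt)
    if "unsafe" in verdict:
        # By the prompt's convention, the second line carries the violated category code, e.g. "O3".
        code = verdict.split("\n")[1].strip()
        return f"Blocked: violates {unsafe_dict.get(code, code)}"
    return "safe"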
@@ -1,54 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# ========= Raw Q&A template prompt =========
|
||||
template = """### System:\n\n
|
||||
You are an assistant chatbot. You answer questions. \
|
||||
If you don't know the answer, just say that you don't know. \
|
||||
Use three sentences maximum and keep the answer concise.\
|
||||
### User:\n{question}\n### Assistant:\n"""
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
|
||||
# ========= contextualize prompt =========
|
||||
contextualize_q_system_prompt = """Given a chat history and the latest user question \
|
||||
which might reference context in the chat history, formulate a standalone question \
|
||||
which can be understood without the chat history. Do NOT answer the question, \
|
||||
just reformulate it if needed and otherwise return it as is."""
|
||||
contextualize_q_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", contextualize_q_system_prompt),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ========= Q&A with history prompt =========
|
||||
# qa_system_prompt = """You are an assistant for question-answering tasks. \
|
||||
# Use the following pieces of retrieved context to answer the question. \
|
||||
# If you don't know the answer, just say that you don't know. \
|
||||
# Use three sentences maximum and keep the answer concise.\
|
||||
|
||||
# {context}"""
|
||||
# qa_prompt = ChatPromptTemplate.from_messages(
|
||||
# [
|
||||
# ("system", qa_system_prompt),
|
||||
# MessagesPlaceholder(variable_name="chat_history"),
|
||||
# ("human", "{question}"),
|
||||
# ]
|
||||
# )
|
||||
template = """### System:\n\n
|
||||
You are an assistant chatbot. You answer questions. \
|
||||
Use the following pieces of retrieved context to answer the question. \
|
||||
If you don't know the answer, just say that you don't know. \
|
||||
Use three sentences maximum and keep the answer concise.\
|
||||
{context}
|
||||
### User:\n{question}\n### Assistant:\n"""
|
||||
qa_prompt = ChatPromptTemplate.from_template(template)
|
||||
@@ -1,322 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, FastAPI, File, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
||||
from guardrails import moderation_prompt_for_chat, unsafe_dict
|
||||
from langchain.globals import set_debug, set_verbose
|
||||
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings
|
||||
from langchain_community.llms import HuggingFaceEndpoint
|
||||
from langchain_community.vectorstores import Redis
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langserve import add_routes
|
||||
from prompts import contextualize_q_prompt, prompt, qa_prompt
|
||||
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from utils import (
|
||||
create_kb_folder,
|
||||
create_retriever_from_files,
|
||||
create_retriever_from_links,
|
||||
get_current_beijing_time,
|
||||
post_process_text,
|
||||
reload_retriever,
|
||||
)
|
||||
|
||||
set_verbose(True)
|
||||
set_debug(True)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
|
||||
)
|
||||
|
||||
|
||||
class RAGAPIRouter(APIRouter):
|
||||
|
||||
def __init__(self, upload_dir, entrypoint, safety_guard_endpoint, tei_endpoint=None) -> None:
|
||||
super().__init__()
|
||||
self.upload_dir = upload_dir
|
||||
self.entrypoint = entrypoint
|
||||
self.safety_guard_endpoint = safety_guard_endpoint
|
||||
print(
|
||||
f"[rag - router] Initializing API Router, params:\n \
|
||||
upload_dir={upload_dir}, entrypoint={entrypoint}"
|
||||
)
|
||||
|
||||
# Define LLM
|
||||
self.llm = HuggingFaceEndpoint(
|
||||
endpoint_url=entrypoint,
|
||||
max_new_tokens=1024,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
streaming=True,
|
||||
)
|
||||
# for NeuralChatEndpoint:
|
||||
"""
|
||||
self.llm = NeuralChatEndpoint(
|
||||
endpoint_url=entrypoint,
|
||||
max_new_tokens=1024,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
streaming=True,
|
||||
)
|
||||
"""
|
||||
if self.safety_guard_endpoint:
|
||||
self.llm_guard = HuggingFaceEndpoint(
|
||||
endpoint_url=safety_guard_endpoint,
|
||||
max_new_tokens=100,
|
||||
top_k=1,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
)
|
||||
print("[rag - router] LLM initialized.")
|
||||
|
||||
# Define LLM Chain
|
||||
if tei_endpoint:
|
||||
# create embeddings using TEI endpoint service
|
||||
self.embeddings = HuggingFaceHubEmbeddings(model=tei_endpoint)
|
||||
else:
|
||||
# create embeddings using local embedding model
|
||||
self.embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
|
||||
|
||||
try:
|
||||
rds = Redis.from_existing_index(
|
||||
self.embeddings,
|
||||
index_name=INDEX_NAME,
|
||||
redis_url=REDIS_URL,
|
||||
schema=INDEX_SCHEMA,
|
||||
)
|
||||
retriever = rds.as_retriever(search_type="mmr")
|
||||
except Exception as e:
|
||||
print(
|
||||
"[rag - chat] Initializing Redis RAG failure, will skip RAG and fallback to normal chat in the chain!"
|
||||
)
|
||||
retriever = None
|
||||
# Define contextualize chain
|
||||
# self.contextualize_q_chain = contextualize_q_prompt | self.llm | StrOutputParser()
|
||||
self.contextualize_q_chain = prompt | self.llm | StrOutputParser()
|
||||
|
||||
# Define LLM chain
|
||||
if retriever:
|
||||
self.llm_chain = (
|
||||
RunnablePassthrough.assign(context=self.contextualized_question | retriever) | qa_prompt | self.llm
|
||||
)
|
||||
else:
|
||||
self.llm_chain = RunnablePassthrough.assign(context=self.contextualized_question) | prompt | self.llm
|
||||
print("[rag - router] LLM chain initialized.")
|
||||
|
||||
# Define chat history
|
||||
self.chat_history = []
|
||||
|
||||
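# Route the question through the contextualize chain only when chat history exists; otherwise pass the raw question through unchanged.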
def contextualized_question(self, input: dict):
|
||||
if input.get("chat_history"):
|
||||
return self.contextualize_q_chain
|
||||
else:
|
||||
return input["question"]
|
||||
|
||||
def handle_rag_chat(self, query: str):
|
||||
response = self.llm_chain.invoke({"question": query, "chat_history": self.chat_history})
|
||||
# response = self.llm_chain.invoke({"question": query})
|
||||
result = response.split("</s>")[0]
|
||||
self.chat_history.extend([HumanMessage(content=query), response])
|
||||
# output guardrails
|
||||
if self.safety_guard_endpoint:
|
||||
response_output_guard = self.llm_guard(
|
||||
moderation_prompt_for_chat("Agent", f"User: {query}\n Agent: {response}")
|
||||
)
|
||||
if "unsafe" in response_output_guard:
|
||||
policy_violation_level = response_output_guard.split("\n")[1].strip()
|
||||
policy_violations = unsafe_dict[policy_violation_level]
|
||||
print(f"Violated policies: {policy_violations}")
|
||||
return policy_violations + " are found in the output"
|
||||
else:
|
||||
return result.lstrip()
|
||||
return result.lstrip()
|
||||
|
||||
|
||||
upload_dir = os.getenv("RAG_UPLOAD_DIR", "./upload_dir")
|
||||
tgi_llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
|
||||
safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT")
|
||||
tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")
|
||||
router = RAGAPIRouter(upload_dir, tgi_llm_endpoint, safety_guard_endpoint, tei_embedding_endpoint)
|
||||
|
||||
|
||||
@router.post("/v1/rag/chat")
|
||||
async def rag_chat(request: Request):
|
||||
params = await request.json()
|
||||
print(f"[rag - chat] POST request: /v1/rag/chat, params:{params}")
|
||||
query = params["query"]
|
||||
kb_id = params.get("knowledge_base_id", "default")
|
||||
print(f"[rag - chat] history: {router.chat_history}")
|
||||
|
||||
# prompt guardrails
|
||||
if router.safety_guard_endpoint:
|
||||
response_input_guard = router.llm_guard(moderation_prompt_for_chat("User", query))
|
||||
if "unsafe" in response_input_guard:
|
||||
policy_violation_level = response_input_guard.split("\n")[1].strip()
|
||||
policy_violations = unsafe_dict[policy_violation_level]
|
||||
print(f"Violated policies: {policy_violations}")
|
||||
return f"Violated policies: {policy_violations}, please check your input."
|
||||
|
||||
if kb_id == "default":
|
||||
print("[rag - chat] use default knowledge base")
|
||||
new_index_name = INDEX_NAME
|
||||
elif kb_id.startswith("kb"):
|
||||
new_index_name = INDEX_NAME + kb_id
|
||||
print(f"[rag - chat] use knowledge base {kb_id}, index name is {new_index_name}")
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"message": "Wrong knowledge base id."})
|
||||
|
||||
try:
|
||||
retriever = reload_retriever(router.embeddings, new_index_name)
|
||||
router.llm_chain = (
|
||||
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
|
||||
)
|
||||
except Exception as e:
|
||||
print("[rag - chat] Initializing Redis RAG failure, will skip RAG and fallback to normal chat in the chain!")
|
||||
return router.handle_rag_chat(query=query)
|
||||
|
||||
|
||||
@router.post("/v1/rag/chat_stream")
|
||||
async def rag_chat_stream(request: Request):
|
||||
params = await request.json()
|
||||
print(f"[rag - chat_stream] POST request: /v1/rag/chat_stream, params:{params}")
|
||||
query = params["query"]
|
||||
kb_id = params.get("knowledge_base_id", "default")
|
||||
print(f"[rag - chat_stream] history: {router.chat_history}")
|
||||
|
||||
# prompt guardrails
|
||||
if router.safety_guard_endpoint:
|
||||
response_input_guard = router.llm_guard(moderation_prompt_for_chat("User", query))
|
||||
if "unsafe" in response_input_guard:
|
||||
policy_violation_level = response_input_guard.split("\n")[1].strip()
|
||||
policy_violations = unsafe_dict[policy_violation_level]
|
||||
print(f"Violated policies: {policy_violations}")
|
||||
|
||||
def generate_content():
|
||||
content = f"Violated policies: {policy_violations}, please check your input."
|
||||
yield f"data: {content}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(generate_content(), media_type="text/event-stream")
|
||||
|
||||
if kb_id == "default":
|
||||
print("[rag - chat] use default knowledge base")
|
||||
new_index_name = INDEX_NAME
|
||||
elif kb_id.startswith("kb"):
|
||||
new_index_name = INDEX_NAME + kb_id
|
||||
print(f"[rag - chat] use knowledge base {kb_id}, index name is {new_index_name}")
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"message": "Wrong knowledge base id."})
|
||||
|
||||
try:
|
||||
retriever = reload_retriever(router.embeddings, new_index_name)
|
||||
router.llm_chain = (
|
||||
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
|
||||
)
|
||||
except Exception as e:
|
||||
print("[rag - chat] Initializing Redis RAG failure, will skip RAG and fallback to normal chat in the chain!")
|
||||
|
||||
def stream_generator():
|
||||
chat_response = ""
|
||||
for text in router.llm_chain.stream({"question": query, "chat_history": router.chat_history}):
|
||||
# for text in router.llm_chain.stream({"question": query}):
|
||||
chat_response += text
|
||||
processed_text = post_process_text(text)
|
||||
if processed_text is not None:
|
||||
yield processed_text
|
||||
chat_response = chat_response.split("</s>")[0]
|
||||
print(f"[rag - chat_stream] stream response: {chat_response}")
|
||||
router.chat_history.extend([HumanMessage(content=query), chat_response])
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/rag/create")
|
||||
async def rag_create(file: UploadFile = File(...)):
|
||||
filename = file.filename
|
||||
if "/" in filename:
|
||||
filename = filename.split("/")[-1]
|
||||
print(f"[rag - create] POST request: /v1/rag/create, filename:{filename}")
|
||||
|
||||
kb_id, user_upload_dir, user_persist_dir = create_kb_folder(router.upload_dir)
|
||||
# save file to local path
|
||||
cur_time = get_current_beijing_time()
|
||||
save_file_name = str(user_upload_dir) + "/" + cur_time + "-" + filename
|
||||
with open(save_file_name, "wb") as fout:
|
||||
content = await file.read()
|
||||
fout.write(content)
|
||||
print(f"[rag - create] file saved to local path: {save_file_name}")
|
||||
|
||||
# create new retriever
|
||||
try:
|
||||
# get retrieval instance and reload db with new knowledge base
|
||||
print("[rag - create] starting to create local db...")
|
||||
index_name = INDEX_NAME + kb_id
|
||||
retriever = create_retriever_from_files(save_file_name, router.embeddings, index_name)
|
||||
router.llm_chain = (
|
||||
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
|
||||
)
|
||||
print("[rag - create] kb created successfully")
|
||||
except Exception as e:
|
||||
print(f"[rag - create] create knowledge base failed! {e}")
|
||||
return JSONResponse(status_code=500, content={"message": "Fail to create new knowledge base."})
|
||||
return {"knowledge_base_id": kb_id}
|
||||
|
||||
|
||||
@router.post("/v1/rag/upload_link")
|
||||
async def rag_upload_link(request: Request):
|
||||
params = await request.json()
|
||||
link_list = params["link_list"]
|
||||
print(f"[rag - upload_link] POST request: /v1/rag/upload_link, link list:{link_list}")
|
||||
|
||||
kb_id, user_upload_dir, user_persist_dir = create_kb_folder(router.upload_dir)
|
||||
|
||||
# create new retriever
|
||||
try:
|
||||
print("[rag - upload_link] starting to create local db...")
|
||||
index_name = INDEX_NAME + kb_id
|
||||
retriever = create_retriever_from_links(router.embeddings, link_list, index_name)
|
||||
router.llm_chain = (
|
||||
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
|
||||
)
|
||||
print("[rag - upload_link] kb created successfully")
|
||||
except Exception as e:
|
||||
print(f"[rag - upload_link] create knowledge base failed! {e}")
|
||||
return JSONResponse(status_code=500, content={"message": "Fail to create new knowledge base."})
|
||||
return {"knowledge_base_id": kb_id}
|
||||
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def redirect_root_to_docs():
|
||||
return RedirectResponse("/docs")
|
||||
|
||||
|
||||
add_routes(app, router.llm_chain, path="/rag-redis")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@@ -1,342 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import UnstructuredFileLoader
|
||||
from langchain_community.vectorstores import Redis
|
||||
from langchain_core.documents import Document
|
||||
from rag_redis.config import INDEX_SCHEMA, REDIS_URL
|
||||
|
||||
|
||||
def get_current_beijing_time():
|
||||
SHA_TZ = timezone(timedelta(hours=8), name="Asia/Shanghai")
|
||||
utc_now = datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||
beijing_time = utc_now.astimezone(SHA_TZ).strftime("%Y-%m-%d-%H:%M:%S")
|
||||
return beijing_time
|
||||
|
||||
|
||||
def create_kb_folder(upload_dir):
|
||||
kb_id = f"kb_{str(uuid.uuid1())[:8]}"
|
||||
path_prefix = upload_dir
|
||||
|
||||
# create local folder for retrieval
|
||||
cur_path = Path(path_prefix) / kb_id
|
||||
os.makedirs(path_prefix, exist_ok=True)
|
||||
cur_path.mkdir(parents=True, exist_ok=True)
|
||||
user_upload_dir = Path(path_prefix) / f"{kb_id}/upload_dir"
|
||||
user_persist_dir = Path(path_prefix) / f"{kb_id}/persist_dir"
|
||||
user_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
user_persist_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"[rag - create kb folder] upload path: {user_upload_dir}, persist path: {user_persist_dir}")
|
||||
return kb_id, str(user_upload_dir), str(user_persist_dir)
|
||||
|
||||
|
||||
class Crawler:
|
||||
|
||||
def __init__(self, pool=None):
|
||||
if pool:
|
||||
assert isinstance(pool, (str, list, tuple)), "url pool should be str, list or tuple"
|
||||
self.pool = pool
|
||||
self.headers = {
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng, \
|
||||
*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"Accept-Language": "en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, \
|
||||
like Gecko) Chrome/113.0.0.0 Safari/537.36",
|
||||
}
|
||||
self.fetched_pool = set()
|
||||
|
||||
def get_sublinks(self, soup):
|
||||
sublinks = []
|
||||
for links in soup.find_all("a"):
|
||||
sublinks.append(str(links.get("href")))
|
||||
return sublinks
|
||||
|
||||
def get_hyperlink(self, soup, base_url):
|
||||
sublinks = []
|
||||
for links in soup.find_all("a"):
|
||||
link = str(links.get("href"))
|
||||
if link.startswith("#") or link is None or link == "None":
|
||||
continue
|
||||
suffix = link.split("/")[-1]
|
||||
if "." in suffix and suffix.split(".")[-1] not in ["html", "htmld"]:
|
||||
continue
|
||||
link_parse = urlparse(link)
|
||||
base_url_parse = urlparse(base_url)
|
||||
if link_parse.path == "":
|
||||
continue
|
||||
if link_parse.netloc != "":
|
||||
# keep crawler works in the same domain
|
||||
if link_parse.netloc != base_url_parse.netloc:
|
||||
continue
|
||||
sublinks.append(link)
|
||||
else:
|
||||
sublinks.append(
|
||||
urlunparse(
|
||||
(
|
||||
base_url_parse.scheme,
|
||||
base_url_parse.netloc,
|
||||
link_parse.path,
|
||||
link_parse.params,
|
||||
link_parse.query,
|
||||
link_parse.fragment,
|
||||
)
|
||||
)
|
||||
)
|
||||
return sublinks
|
||||
|
||||
def fetch(self, url, headers=None, max_times=5):
|
||||
if not headers:
|
||||
headers = self.headers
|
||||
while max_times:
|
||||
if not url.startswith("http") or not url.startswith("https"):
|
||||
url = "http://" + url
|
||||
print("start fetch %s...", url)
|
||||
try:
|
||||
response = requests.get(url, headers=headers, verify=True)
|
||||
if response.status_code != 200:
|
||||
print("fail to fetch %s, response status code: %s", url, response.status_code)
|
||||
else:
|
||||
return response
|
||||
except Exception as e:
|
||||
print("fail to fetch %s, caused by %s", url, e)
|
||||
raise Exception(e)
|
||||
max_times -= 1
|
||||
return None
|
||||
|
||||
def process_work(self, sub_url, work):
|
||||
response = self.fetch(sub_url)
|
||||
if response is None:
|
||||
return []
|
||||
self.fetched_pool.add(sub_url)
|
||||
soup = self.parse(response.text)
|
||||
base_url = self.get_base_url(sub_url)
|
||||
sublinks = self.get_hyperlink(soup, base_url)
|
||||
if work:
|
||||
work(sub_url, soup)
|
||||
return sublinks
|
||||
|
||||
def crawl(self, pool, work=None, max_depth=10, workers=10):
|
||||
url_pool = set()
|
||||
for url in pool:
|
||||
base_url = self.get_base_url(url)
|
||||
response = self.fetch(url)
|
||||
soup = self.parse(response.text)
|
||||
sublinks = self.get_hyperlink(soup, base_url)
|
||||
self.fetched_pool.add(url)
|
||||
url_pool.update(sublinks)
|
||||
depth = 0
|
||||
while len(url_pool) > 0 and depth < max_depth:
|
||||
print("current depth %s...", depth)
|
||||
mp = multiprocessing.Pool(processes=workers)
|
||||
results = []
|
||||
for sub_url in url_pool:
|
||||
if sub_url not in self.fetched_pool:
|
||||
results.append(mp.apply_async(self.process_work, (sub_url, work)))
|
||||
mp.close()
|
||||
mp.join()
|
||||
url_pool = set()
|
||||
for result in results:
|
||||
sublinks = result.get()
|
||||
url_pool.update(sublinks)
|
||||
depth += 1
|
||||
|
||||
def parse(self, html_doc):
|
||||
soup = BeautifulSoup(html_doc, "lxml")
|
||||
return soup
|
||||
|
||||
def download(self, url, file_name):
|
||||
print("download %s into %s...", url, file_name)
|
||||
try:
|
||||
r = requests.get(url, stream=True, headers=self.headers, verify=True)
|
||||
f = open(file_name, "wb")
|
||||
for chunk in r.iter_content(chunk_size=512):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
print("fail to download %s, caused by %s", url, e)
|
||||
|
||||
def get_base_url(self, url):
|
||||
result = urlparse(url)
|
||||
return urlunparse((result.scheme, result.netloc, "", "", "", ""))
|
||||
|
||||
def clean_text(self, text):
|
||||
text = text.strip().replace("\r", "\n")
|
||||
text = re.sub(" +", " ", text)
|
||||
text = re.sub("\n+", "\n", text)
|
||||
text = text.split("\n")
|
||||
return "\n".join([i for i in text if i and i != " "])
|
||||
|
||||
|
||||
def uni_pro(text):
|
||||
"""Check if the character is ASCII or falls in the category of non-spacing marks."""
|
||||
normalized_text = unicodedata.normalize("NFKD", text)
|
||||
filtered_text = ""
|
||||
for char in normalized_text:
|
||||
if ord(char) < 128 or unicodedata.category(char) == "Mn":
|
||||
filtered_text += char
|
||||
return filtered_text
|
||||
|
||||
|
||||
def load_html_data(url):
|
||||
crawler = Crawler()
|
||||
res = crawler.fetch(url)
|
||||
if res is None:
|
||||
return None
|
||||
soup = crawler.parse(res.text)
|
||||
all_text = crawler.clean_text(soup.select_one("body").text)
|
||||
main_content = ""
|
||||
for element_name in ["main", "container"]:
|
||||
main_block = None
|
||||
if soup.select(f".{element_name}"):
|
||||
main_block = soup.select(f".{element_name}")
|
||||
elif soup.select(f"#{element_name}"):
|
||||
main_block = soup.select(f"#{element_name}")
|
||||
if main_block:
|
||||
for element in main_block:
|
||||
text = crawler.clean_text(element.text)
|
||||
if text not in main_content:
|
||||
main_content += f"\n{text}"
|
||||
main_content = crawler.clean_text(main_content)
|
||||
|
||||
main_content = main_content.replace("\n", "")
|
||||
main_content = main_content.replace("\n\n", "")
|
||||
main_content = uni_pro(main_content)
|
||||
main_content = re.sub(r"\s+", " ", main_content)
|
||||
|
||||
# {'text': all_text, 'main_content': main_content}
|
||||
|
||||
return main_content
|
||||
|
||||
|
||||
def get_chuck_data(content, max_length, min_length, input):
|
||||
"""Process the context to make it maintain a suitable length for the generation."""
|
||||
sentences = re.split("(?<=[!.?])", content)
|
||||
|
||||
paragraphs = []
|
||||
current_length = 0
|
||||
count = 0
|
||||
current_paragraph = ""
|
||||
for sub_sen in sentences:
|
||||
count += 1
|
||||
sentence_length = len(sub_sen)
|
||||
if current_length + sentence_length <= max_length:
|
||||
current_paragraph += sub_sen
|
||||
current_length += sentence_length
|
||||
if count == len(sentences) and len(current_paragraph.strip()) > min_length:
|
||||
paragraphs.append([current_paragraph.strip(), input])
|
||||
else:
|
||||
paragraphs.append([current_paragraph.strip(), input])
|
||||
current_paragraph = sub_sen
|
||||
current_length = sentence_length
|
||||
|
||||
return paragraphs
|
||||
|
||||
|
||||
def parse_html(input):
|
||||
"""Parse the uploaded file."""
|
||||
chucks = []
|
||||
for link in input:
|
||||
if re.match(r"^https?:/{2}\w.+$", link):
|
||||
content = load_html_data(link)
|
||||
if content is None:
|
||||
continue
|
||||
chuck = [[content.strip(), link]]
|
||||
chucks += chuck
|
||||
else:
|
||||
print("The given link/str {} cannot be parsed.".format(link))
|
||||
|
||||
return chucks
|
||||
|
||||
|
||||
def document_transfer(data_collection):
|
||||
"Transfer the raw document into langchain supported format."
|
||||
documents = []
|
||||
for data, meta in data_collection:
|
||||
doc_id = str(uuid.uuid4())
|
||||
metadata = {"source": meta, "identify_id": doc_id}
|
||||
doc = Document(page_content=data, metadata=metadata)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
|
||||
def create_retriever_from_files(doc, embeddings, index_name: str):
|
||||
print(f"[rag - create retriever] create with index: {index_name}")
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True)
|
||||
loader = UnstructuredFileLoader(doc, mode="single", strategy="fast")
|
||||
chunks = loader.load_and_split(text_splitter)
|
||||
|
||||
rds = Redis.from_texts(
|
||||
texts=[chunk.page_content for chunk in chunks],
|
||||
metadatas=[chunk.metadata for chunk in chunks],
|
||||
embedding=embeddings,
|
||||
index_name=index_name,
|
||||
redis_url=REDIS_URL,
|
||||
index_schema=INDEX_SCHEMA,
|
||||
)
|
||||
|
||||
retriever = rds.as_retriever(search_type="mmr")
|
||||
return retriever
|
||||
|
||||
|
||||
def create_retriever_from_links(embeddings, link_list: list, index_name):
|
||||
data_collection = parse_html(link_list)
|
||||
texts = []
|
||||
metadatas = []
|
||||
for data, meta in data_collection:
|
||||
doc_id = str(uuid.uuid4())
|
||||
metadata = {"source": meta, "identify_id": doc_id}
|
||||
texts.append(data)
|
||||
metadatas.append(metadata)
|
||||
|
||||
rds = Redis.from_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
embedding=embeddings,
|
||||
index_name=index_name,
|
||||
redis_url=REDIS_URL,
|
||||
index_schema=INDEX_SCHEMA,
|
||||
)
|
||||
|
||||
retriever = rds.as_retriever(search_type="mmr")
|
||||
return retriever
|
||||
|
||||
|
||||
def reload_retriever(embeddings, index_name):
|
||||
print(f"[rag - reload retriever] reload with index: {index_name}")
|
||||
rds = Redis.from_existing_index(
|
||||
embeddings,
|
||||
index_name=index_name,
|
||||
redis_url=REDIS_URL,
|
||||
schema=INDEX_SCHEMA,
|
||||
)
|
||||
|
||||
retriever = rds.as_retriever(search_type="mmr")
|
||||
return retriever
|
||||
|
||||
|
||||
def post_process_text(text: str):
|
||||
if text == " ":
|
||||
return "data: @#$\n\n"
|
||||
if text.isspace():
|
||||
return None
|
||||
if text == "\n":
|
||||
return "data: <br/>\n\n"
|
||||
new_text = text.replace(" ", "@#$")
|
||||
return f"data: {new_text}\n\n"
|
||||
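# --- Hypothetical client-side decoder (illustration only, not part of the original module) ---
# Reverses the encoding applied by post_process_text: spaces arrive as "@#$"
# and newlines as "<br/>" inside "data: ..." server-sent-event frames.
def decode_sse_frame(frame):
    if not frame.startswith("data: "):
        return ""
    payload = frame[len("data: "):].rstrip("\n")
    if payload == "[DONE]":
        return ""
    return payload.replace("@#$", " ").replace("<br/>", "\n")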
@@ -1,23 +0,0 @@
|
||||
[tool.poetry]
|
||||
name = "my-app"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Your Name <you@example.com>"]
|
||||
readme = "README.md"
|
||||
packages = [
|
||||
{ include = "app" },
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
uvicorn = "^0.23.2"
|
||||
langserve = {extras = ["server"], version = ">=0.0.30"}
|
||||
pydantic = "<2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-cli = ">=0.0.15"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -1,17 +0,0 @@
|
||||
-f https://download.pytorch.org/whl/torch_stable.html
|
||||
cryptography==42.0.4
|
||||
easyocr
|
||||
intel-extension-for-pytorch
|
||||
intel-openmp
|
||||
jupyter
|
||||
langchain==0.1.12
|
||||
langchain-cli
|
||||
langchain_benchmarks
|
||||
poetry
|
||||
pyarrow
|
||||
pydantic==1.10.13
|
||||
pymupdf
|
||||
redis
|
||||
sentence-transformers
|
||||
unstructured
|
||||
unstructured[all-docs]
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,86 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import io
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores import Redis
|
||||
from PIL import Image
|
||||
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL
|
||||
|
||||
|
||||
def pdf_loader(file_path):
|
||||
try:
|
||||
import easyocr
|
||||
import fitz
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`PyMuPDF` or 'easyocr' package is not found, please install it with "
|
||||
"`pip install pymupdf or pip install easyocr.`"
|
||||
)
|
||||
|
||||
doc = fitz.open(file_path)
|
||||
reader = easyocr.Reader(["en"])
|
||||
result = ""
|
||||
for i in range(doc.page_count):
|
||||
page = doc.load_page(i)
|
||||
pagetext = page.get_text().strip()
|
||||
if pagetext:
|
||||
result = result + pagetext
|
||||
if len(doc.get_page_images(i)) > 0:
|
||||
for img in doc.get_page_images(i):
|
||||
if img:
|
||||
pageimg = ""
|
||||
xref = img[0]
|
||||
img_data = doc.extract_image(xref)
|
||||
img_bytes = img_data["image"]
|
||||
pil_image = Image.open(io.BytesIO(img_bytes))
|
||||
img = np.array(pil_image)
|
||||
img_result = reader.readtext(img, paragraph=True, detail=0)
|
||||
pageimg = pageimg + ", ".join(img_result).strip()
|
||||
if pageimg.endswith("!") or pageimg.endswith("?") or pageimg.endswith("."):
|
||||
pass
|
||||
else:
|
||||
pageimg = pageimg + "."
|
||||
result = result + pageimg
|
||||
return result
|
||||
|
||||
|
||||
def ingest_documents():
|
||||
"""Ingest PDF to Redis from the data/ directory that
|
||||
contains Edgar 10k filings data for Nike."""
|
||||
# Load list of pdfs
|
||||
company_name = "Nike"
|
||||
data_path = "data/"
|
||||
doc_path = [os.path.join(data_path, file) for file in os.listdir(data_path)][0]
|
||||
|
||||
print("Parsing 10k filing doc for NIKE", doc_path)
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True)
|
||||
content = pdf_loader(doc_path)
|
||||
chunks = text_splitter.split_text(content)
|
||||
|
||||
print("Done preprocessing. Created ", len(chunks), " chunks of the original pdf")
|
||||
# Create vectorstore
|
||||
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
|
||||
|
||||
_ = Redis.from_texts(
|
||||
# appending this little bit can sometimes help with semantic retrieval
|
||||
# especially with multiple companies
|
||||
texts=[f"Company: {company_name}. " + chunk for chunk in chunks],
|
||||
embedding=embedder,
|
||||
index_name=INDEX_NAME,
|
||||
index_schema=INDEX_SCHEMA,
|
||||
redis_url=REDIS_URL,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ingest_documents()
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader, UnstructuredFileLoader
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores import Redis
|
||||
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL
|
||||
|
||||
loader = DirectoryLoader(
|
||||
"/ws/txt_files", glob="**/*.txt", show_progress=True, use_multithreading=True, loader_cls=TextLoader
|
||||
)
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True)
|
||||
|
||||
chunks = loader.load_and_split(text_splitter)
|
||||
print("Done preprocessing. Created", len(chunks), "chunks of the original data")
|
||||
|
||||
# Create vectorstore
|
||||
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
|
||||
|
||||
company_name = "Intel"
|
||||
_ = Redis.from_texts(
|
||||
# appending this little bit can sometimes help with semantic retrieval
|
||||
# especially with multiple companies
|
||||
texts=[f"Company: {company_name}. " + chunk.page_content for chunk in chunks],
|
||||
metadatas=[chunk.metadata for chunk in chunks],
|
||||
embedding=embedder,
|
||||
index_name=INDEX_NAME,
|
||||
index_schema=INDEX_SCHEMA,
|
||||
redis_url=REDIS_URL,
|
||||
)
|
||||
@@ -1,86 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import io
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores import Redis
|
||||
from PIL import Image
|
||||
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL
|
||||
|
||||
|
||||
def pdf_loader(file_path):
|
||||
try:
|
||||
import easyocr
|
||||
import fitz
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`PyMuPDF` or 'easyocr' package is not found, please install it with "
|
||||
"`pip install pymupdf or pip install easyocr.`"
|
||||
)
|
||||
|
||||
doc = fitz.open(file_path)
|
||||
reader = easyocr.Reader(["en"])
|
||||
result = ""
|
||||
for i in range(doc.page_count):
|
||||
page = doc.load_page(i)
|
||||
pagetext = page.get_text().strip()
|
||||
if pagetext:
|
||||
result = result + pagetext
|
||||
if len(doc.get_page_images(i)) > 0:
|
||||
for img in doc.get_page_images(i):
|
||||
if img:
|
||||
pageimg = ""
|
||||
xref = img[0]
|
||||
img_data = doc.extract_image(xref)
|
||||
img_bytes = img_data["image"]
|
||||
pil_image = Image.open(io.BytesIO(img_bytes))
|
||||
img = np.array(pil_image)
|
||||
img_result = reader.readtext(img, paragraph=True, detail=0)
|
||||
pageimg = pageimg + ", ".join(img_result).strip()
|
||||
if pageimg.endswith("!") or pageimg.endswith("?") or pageimg.endswith("."):
|
||||
pass
|
||||
else:
|
||||
pageimg = pageimg + "."
|
||||
result = result + pageimg
|
||||
return result
|
||||
|
||||
|
||||
def ingest_documents():
|
||||
"""Ingest PDF to Redis from the data/ directory that
|
||||
contains Intel manuals."""
|
||||
# Load list of pdfs
|
||||
company_name = "Intel"
|
||||
data_path = "data_intel/"
|
||||
doc_path = [os.path.join(data_path, file) for file in os.listdir(data_path)][0]
|
||||
|
||||
print("Parsing Intel architecture manuals", doc_path)
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True)
|
||||
content = pdf_loader(doc_path)
|
||||
chunks = text_splitter.split_text(content)
|
||||
|
||||
print("Done preprocessing. Created", len(chunks), "chunks of the original pdf")
|
||||
# Create vectorstore
|
||||
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
|
||||
|
||||
_ = Redis.from_texts(
|
||||
# appending this little bit can sometimes help with semantic retrieval
|
||||
# especially with multiple companies
|
||||
texts=[f"Company: {company_name}. " + chunk for chunk in chunks],
|
||||
embedding=embedder,
|
||||
index_name=INDEX_NAME,
|
||||
index_schema=INDEX_SCHEMA,
|
||||
redis_url=REDIS_URL,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ingest_documents()
|
||||
@@ -1,88 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "681a5d1e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Connect to RAG App\n",
|
||||
"\n",
|
||||
"Assuming you are already running this server:\n",
|
||||
"```bash\n",
|
||||
"langserve start\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "d774be2a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Nike's revenue in 2023 was $51.2 billion. \n",
|
||||
"\n",
|
||||
"Source: 'data/nke-10k-2023.pdf', Start Index: '146100'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langserve.client import RemoteRunnable\n",
|
||||
"\n",
|
||||
"rag_redis = RemoteRunnable(\"http://localhost:8000/rag-redis\")\n",
|
||||
"\n",
|
||||
"print(rag_redis.invoke(\"What was Nike's revenue in 2023?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"id": "07ae0005",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"As of May 31, 2023, Nike had approximately 83,700 employees worldwide. This information can be found in the first piece of context provided. (source: data/nke-10k-2023.pdf, start_index: 32532)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(rag_redis.invoke(\"How many employees work at Nike?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4a6b9f00",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
@@ -1,76 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_community.llms import HuggingFaceEndpoint
|
||||
from langchain_community.vectorstores import Redis
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
||||
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL, TGI_LLM_ENDPOINT
|
||||
|
||||
|
||||
# Make this look better in the docs.
|
||||
class Question(BaseModel):
|
||||
__root__: str
|
||||
|
||||
|
||||
# Init Embeddings
|
||||
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
|
||||
|
||||
# Setup semantic cache for LLM
|
||||
from langchain.cache import RedisSemanticCache
|
||||
from langchain.globals import set_llm_cache
|
||||
|
||||
set_llm_cache(RedisSemanticCache(embedding=embedder, redis_url=REDIS_URL))
|
||||
|
||||
# Connect to pre-loaded vectorstore
|
||||
# run the ingest.py script to populate this
|
||||
vectorstore = Redis.from_existing_index(
|
||||
embedding=embedder, index_name=INDEX_NAME, schema=INDEX_SCHEMA, redis_url=REDIS_URL
|
||||
)
|
||||
|
||||
# TODO allow user to change parameters
|
||||
retriever = vectorstore.as_retriever(search_type="mmr")
|
||||
|
||||
# Define our prompt
|
||||
template = """
|
||||
Use the following pieces of context from retrieved
|
||||
dataset to answer the question. Do not make up an answer if there is no
|
||||
context provided to help answer it. Include the 'source' and 'start_index'
|
||||
from the metadata included in the context you used to answer the question
|
||||
|
||||
Context:
|
||||
---------
|
||||
{context}
|
||||
|
||||
---------
|
||||
Question: {question}
|
||||
---------
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
# RAG Chain
|
||||
model = HuggingFaceEndpoint(
|
||||
endpoint_url=TGI_LLM_ENDPOINT,
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
streaming=True,
|
||||
truncate=1024,
|
||||
)
|
||||
|
||||
chain = (
|
||||
RunnableParallel({"context": retriever, "question": RunnablePassthrough()}) | prompt | model | StrOutputParser()
|
||||
).with_types(input_type=Question)
|
||||
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_boolean_env_var(var_name, default_value=False):
|
||||
"""Retrieve the boolean value of an environment variable.
|
||||
|
||||
Args:
|
||||
var_name (str): The name of the environment variable to retrieve.
|
||||
default_value (bool): The default value to return if the variable
|
||||
is not found.
|
||||
|
||||
Returns:
|
||||
bool: The value of the environment variable, interpreted as a boolean.
|
||||
"""
|
||||
true_values = {"true", "1", "t", "y", "yes"}
|
||||
false_values = {"false", "0", "f", "n", "no"}
|
||||
|
||||
# Retrieve the environment variable's value
|
||||
value = os.getenv(var_name, "").lower()
|
||||
|
||||
# Decide the boolean value based on the content of the string
|
||||
if value in true_values:
|
||||
return True
|
||||
elif value in false_values:
|
||||
return False
|
||||
else:
|
||||
return default_value
|
||||
|
||||
|
||||
# Check for openai API key
|
||||
# if "OPENAI_API_KEY" not in os.environ:
|
||||
# raise Exception("Must provide an OPENAI_API_KEY as an env var.")
|
||||
|
||||
|
||||
# Whether or not to enable langchain debugging
|
||||
DEBUG = get_boolean_env_var("DEBUG", False)
|
||||
# Set DEBUG env var to "true" if you wish to enable LC debugging module
|
||||
if DEBUG:
|
||||
import langchain
|
||||
|
||||
langchain.debug = True
|
||||
|
||||
|
||||
# Embedding model
|
||||
EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
||||
|
||||
# Redis Connection Information
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
||||
|
||||
|
||||
def format_redis_conn_from_env():
|
||||
redis_url = os.getenv("REDIS_URL", None)
|
||||
if redis_url:
|
||||
return redis_url
|
||||
else:
|
||||
using_ssl = get_boolean_env_var("REDIS_SSL", False)
|
||||
start = "rediss://" if using_ssl else "redis://"
|
||||
|
||||
# if using RBAC
|
||||
password = os.getenv("REDIS_PASSWORD", None)
|
||||
username = os.getenv("REDIS_USERNAME", "default")
|
||||
if password is not None:
|
||||
start += f"{username}:{password}@"
|
||||
|
||||
return start + f"{REDIS_HOST}:{REDIS_PORT}"
|
||||
|
||||
|
||||
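# The resulting URL looks like redis://localhost:6379, or rediss://default:<password>@host:port when SSL/RBAC are enabled.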
REDIS_URL = format_redis_conn_from_env()
|
||||
|
||||
# Vector Index Configuration
|
||||
INDEX_NAME = os.getenv("INDEX_NAME", "rag-redis")
|
||||
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
parent_dir = os.path.dirname(current_file_path)
|
||||
REDIS_SCHEMA = os.getenv("REDIS_SCHEMA", "schema.yml")
|
||||
schema_path = os.path.join(parent_dir, REDIS_SCHEMA)
|
||||
INDEX_SCHEMA = schema_path
|
||||
TGI_LLM_ENDPOINT = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
|
||||
TGI_LLM_ENDPOINT_NO_RAG = os.getenv("TGI_LLM_ENDPOINT_NO_RAG", "http://localhost:8081")
|
||||
@@ -1,15 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
text:
|
||||
- name: content
|
||||
- name: source
|
||||
numeric:
|
||||
- name: start_index
|
||||
vector:
|
||||
- name: content_vector
|
||||
algorithm: HNSW
|
||||
datatype: FLOAT32
|
||||
dims: 384
|
||||
distance_metric: COSINE
|
||||
@@ -1,15 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
text:
|
||||
- name: content
|
||||
- name: source
|
||||
numeric:
|
||||
- name: start_index
|
||||
vector:
|
||||
- name: content_vector
|
||||
algorithm: HNSW
|
||||
datatype: FLOAT32
|
||||
dims: 1024
|
||||
distance_metric: COSINE
|
||||
@@ -1,15 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
text:
|
||||
- name: content
|
||||
- name: source
|
||||
numeric:
|
||||
- name: start_index
|
||||
vector:
|
||||
- name: content_vector
|
||||
algorithm: HNSW
|
||||
datatype: FLOAT32
|
||||
dims: 768
|
||||
distance_metric: COSINE
|
||||
@@ -1,19 +0,0 @@
|
||||
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
text:
|
||||
- name: content
|
||||
- name: changefreq
|
||||
- name: description
|
||||
- name: language
|
||||
- name: loc
|
||||
- name: priority
|
||||
- name: source
|
||||
- name: title
|
||||
vector:
|
||||
- name: content_vector
|
||||
algorithm: HNSW
|
||||
datatype: FLOAT32
|
||||
dims: 768
|
||||
distance_metric: COSINE
|
||||
@@ -1,89 +0,0 @@
|
||||
[TGI-Gaudi](https://github.com/huggingface/tgi-gaudi) provides many parameters aimed at optimizing performance for text generation inference tasks. By tuning these parameters, users can achieve the best results in terms of inference speed, memory usage, and overall efficiency. These parameters cover various aspects such as maximum sequence length, batch size, Gaudi processor utilization, and environment configurations. By carefully adjusting these parameters according to the specific requirements of the workload and hardware environment, users can unlock the full potential of TGI-Gaudi for text generation tasks.
|
||||
|
||||
# Knowledge about TGI-Gaudi performance tuning
|
||||
|
||||
## Adjusting TGI parameters
|
||||
|
||||
Maximum sequence length is controlled by two arguments:
|
||||
|
||||
- `--max-input-length` is the maximum possible input prompt length. Default value is `1024`.
|
||||
- `--max-total-tokens` is the maximum possible total length of the sequence (input and output). Default value is `2048`.
|
||||
|
||||
Maximum batch size is controlled by two arguments:
|
||||
|
||||
- For prefill operation, please set `--max-prefill-total-tokens` as `bs * max-input-length`, where `bs` is your expected maximum prefill batch size.
|
||||
- For decode operation, please set `--max-batch-total-tokens` as `bs * max-total-tokens`, where `bs` is your expected maximum decode batch size.
|
||||
- Please note that batch size will be always padded to the nearest multiplication of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE`.
|
||||
|
||||
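A minimal sketch of that sizing arithmetic, assuming an expected prefill batch size of 4 and a decode batch size of 16 with the default lengths; the numbers are illustrative, not tuned recommendations:

```bash
# Worked example of the batch-size sizing rules above (illustrative values).
MAX_INPUT_LENGTH=1024   # --max-input-length
MAX_TOTAL_TOKENS=2048   # --max-total-tokens
PREFILL_BS=4            # expected maximum prefill batch size
DECODE_BS=16            # expected maximum decode batch size

echo "--max-batch-prefill-tokens $((PREFILL_BS * MAX_INPUT_LENGTH))"  # 4096
echo "--max-batch-total-tokens $((DECODE_BS * MAX_TOTAL_TOKENS))"     # 32768
```
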
To ensure the best performance, a warmup is performed at the beginning of each server run. It is designed to cover major recompilations while using HPU Graphs: it creates queries with all possible input shapes, based on the provided parameters (described in this section), and runs basic TGI operations on them (prefill, decode, concatenate).

Besides those already mentioned, other parameters need to be properly adjusted to improve performance or memory usage:

- `PAD_SEQUENCE_TO_MULTIPLE_OF` determines the sizes of the input length buckets. Since warmup creates several graphs for each bucket, it is important to adjust this value in proportion to the input sequence length; otherwise, out-of-memory issues can be observed.
- `ENABLE_HPU_GRAPH` enables HPU Graphs usage, which is crucial for performance. The recommended value is `true`.

For more information and documentation about Text Generation Inference, check out [the README](https://github.com/huggingface/text-generation-inference#text-generation-inference) of the original repo.

## Environment Variable HABANA_VISIBLE_MODULES

To run a workload with only some of the available Gaudi processors, set the module IDs of the Gaudi processors to use in the environment variable HABANA_VISIBLE_MODULES. In general, there are eight Gaudi processors on a node, so the module IDs are in the range 0 ~ 7. For example, to run a 4-Gaudi workload, set the following before launching it:

```bash
export HABANA_VISIBLE_MODULES="0,1,2,3"
```

If you want to run another 4-Gaudi workload in parallel, set the following before launching the second workload so that it uses the remaining four Gaudi processors:

```bash
export HABANA_VISIBLE_MODULES="4,5,6,7"
```

Though using partial Gaudi in a workload is possible, only the 2-Gaudi and 4-Gaudi scenarios are supported. It is highly recommended to set HABANA_VISIBLE_MODULES using one of the combinations listed below:

- 2-Gaudi: `"0,1"`, `"2,3"`, `"4,5"`, or `"6,7"`
- 4-Gaudi: `"0,1,2,3"` or `"4,5,6,7"`

For details, please check [Multiple_Workloads_Single_Docker](https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html).

## Environment Variable HABANA_VISIBLE_DEVICES

There are some guidelines on setting HABANA_VISIBLE_DEVICES; however, before reading them you need to know how to find the mapping between the index and the module ID of the Gaudi processors. The command below queries that mapping, and a sample output follows:

```bash
hl-smi -Q index,module_id -f csv
```

| index | module_id |
| :---: | :-------: |
|   3   |     6     |
|   1   |     4     |
|   2   |     7     |
|   0   |     5     |
|   4   |     2     |
|   6   |     0     |
|   7   |     3     |
|   5   |     1     |

With the mapping between index and module ID, you can set `HABANA_VISIBLE_DEVICES` properly using the guidelines below:

- Mount two or four Gaudi processors in the Docker container. Even though using partial Gaudi in a distributed workload is possible, only the 2-Gaudi and 4-Gaudi scenarios are allowed.
- Since `HABANA_VISIBLE_DEVICES` accepts indices instead of module IDs, you need to leverage the above command to figure out the corresponding indices for a set of module IDs (a possible one-liner is sketched after this list).
- Avoid mounting the same index in multiple containers. Since multiple workloads might run in parallel, keeping each Gaudi mounted in only one Docker container prevents the same Gaudi from being reused across workloads.

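A rough sketch of how that lookup could be scripted for module IDs 0-3; it assumes the CSV output format shown above and is not part of the official tooling:

```bash
# Hedged sketch: print the hl-smi indices corresponding to module IDs 0-3,
# joined with commas so the result can be used as HABANA_VISIBLE_DEVICES.
# Assumes `hl-smi -Q index,module_id -f csv` prints a header row first.
hl-smi -Q index,module_id -f csv \
  | awk -F',' 'NR > 1 { gsub(/ /, ""); if ($2 <= 3) print $1 }' \
  | paste -sd, -
# With the sample table above this selects indices 4, 6, 7, and 5 --
# the same set the verified 70B command below passes as "6,7,4,5".
```
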
For details, please check [Multiple Dockers Each with a Single Workload](https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Dockers_each_with_Single_Workload.html).

For the System Management Interface Tool, please check [hl-smi](https://docs.habana.ai/en/latest/Management_and_Monitoring/Embedded_System_Tools_Guide/System_Management_Interface_Tool.html).

# Verified Docker commands with tuned parameters for best performance

## Docker command for 70B model

```bash
docker run -p 8080:80 -v $volume:/data --runtime=habana -e HUGGING_FACE_HUB_TOKEN=$HUGGINGFACEHUB_API_TOKEN -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES="6,7,4,5" -e HABANA_VISIBLE_MODULES="0,1,2,3" -e BATCH_BUCKET_SIZE=22 -e PREFILL_BATCH_BUCKET_SIZE=1 -e MAX_BATCH_PREFILL_TOKENS=5102 -e MAX_BATCH_TOTAL_TOKENS=32256 -e MAX_INPUT_LENGTH=1024 -e PAD_SEQUENCE_TO_MULTIPLE_OF=1024 -e MAX_WAITING_TOKENS=5 -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host tgi_gaudi --model-id $model --sharded true --num-shard 4
```

## Docker command for 13B model

```bash
docker run -p 8080:80 -v $volume:/data --runtime=habana -e HUGGING_FACE_HUB_TOKEN=$HUGGINGFACEHUB_API_TOKEN -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e PAD_SEQUENCE_TO_MULTIPLE_OF=128 -e HABANA_VISIBLE_DEVICES="4" -e BATCH_BUCKET_SIZE=16 -e PREFILL_BATCH_BUCKET_SIZE=1 -e MAX_BATCH_PREFILL_TOKENS=4096 -e MAX_BATCH_TOTAL_TOKENS=18432 -e PAD_SEQUENCE_TO_MULTIPLE_OF=1024 -e MAX_INPUT_LENGTH=1024 -e MAX_TOTAL_TOKENS=1152 -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host tgi_gaudi --model-id $model
```

@@ -1,9 +0,0 @@

#!/bin/bash


# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

git clone https://github.com/huggingface/tgi-gaudi.git
cd ./tgi-gaudi/
docker build -t ghcr.io/huggingface/tgi-gaudi:1.2.1 . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy

@@ -1,41 +0,0 @@

#!/bin/bash


# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Set default values
default_port=8080
default_model="Intel/neural-chat-7b-v3-3"
default_num_cards=1

# Check that at most three arguments are provided
if [ "$#" -gt 3 ]; then
    echo "Usage: $0 [num_cards] [port_number] [model_name]"
    exit 1
fi

# Assign arguments to variables
num_cards=${1:-$default_num_cards}
port_number=${2:-$default_port}
model_name=${3:-$default_model}

# Check if num_cards is within the valid range (1-8)
if [ "$num_cards" -lt 1 ] || [ "$num_cards" -gt 8 ]; then
    echo "Error: num_cards must be between 1 and 8."
    exit 1
fi

# Set the volume variable
volume=$PWD/data

# Build the Docker run command based on the number of cards
if [ "$num_cards" -eq 1 ]; then
    docker_cmd="docker run -d --name="ChatQnA_server" -p $port_number:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$http_proxy ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model_name"
else
    docker_cmd="docker run -d --name="ChatQnA_server" -p $port_number:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$http_proxy ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model_name --sharded true --num-shard $num_cards"
fi

# Execute the Docker run command
echo $docker_cmd
eval $docker_cmd

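For reference, a possible invocation of this launcher, matching the arguments the end-to-end test script below passes to it (one card, port 8888, the default `Intel/neural-chat-7b-v3-3` model):

```bash
# Example invocation, mirroring the e2e test script further down.
bash serving/tgi_gaudi/launch_tgi_service.sh 1 8888 "Intel/neural-chat-7b-v3-3"
```
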
@@ -1,63 +0,0 @@

#!/bin/bash
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

set -xe

function test_env_setup() {
    WORKPATH=$(dirname "$PWD")/audio/docker
    LOG_PATH=$(dirname "$PWD")/tests/asr.log
    ASR_CONTAINER_NAME="test-audioqna-asr"
    cd $WORKPATH
}

function start_asr_service() {
    cd $WORKPATH
    docker build . --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} -f Dockerfile_asr -t intel/gen-ai-examples:$ASR_CONTAINER_NAME
    docker run -d --name=$ASR_CONTAINER_NAME -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 8018:8008 intel/gen-ai-examples:$ASR_CONTAINER_NAME
    sleep 1m
}

function run_tests() {
    cd $WORKPATH
    rm -f sample.wav
    wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav
    http_proxy= curl -F 'file=@sample.wav' http://localhost:8018/v1/audio/transcriptions > $LOG_PATH
    rm -f sample.wav
}

function check_response() {
    cd $WORKPATH
    echo "Checking response"
    local status=false
    if [[ -f $LOG_PATH ]] && [[ $(grep -c "who is pat gelsinger" $LOG_PATH) != 0 ]]; then
        status=true
    fi

    if [ $status == false ]; then
        echo "Response check failed"
        exit 1
    else
        echo "Response check succeeded"
    fi
}

function docker_stop() {
    local container_name=$1
    cid=$(docker ps -aq --filter "name=$container_name")
    if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid; fi
}


function main() {
    test_env_setup
    docker_stop $ASR_CONTAINER_NAME && sleep 5s
    start_asr_service
    run_tests
    docker_stop $ASR_CONTAINER_NAME && sleep 5s
    echo y | docker system prune
    check_response
}

main

@@ -1,110 +0,0 @@

#!/bin/bash
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

set -xe

function test_env_setup() {
    WORKPATH=$(dirname "$PWD")
    LOG_PATH="$WORKPATH/tests/langchain.log"

    REDIS_CONTAINER_NAME="test-redis-vector-db"
    LANGCHAIN_CONTAINER_NAME="test-qna-rag-redis-server"
    AUDIOQNA_CONTAINER_NAME="test-AudioQnA_server"
    cd $WORKPATH
}

function rename() {
    # Rename the docker container/image names to avoid conflict with local test
    cd ${WORKPATH}
    sed -i "s/container_name: redis-vector-db/container_name: ${REDIS_CONTAINER_NAME}/g" langchain/docker/docker-compose.yml
    sed -i "s/container_name: qna-rag-redis-server/container_name: ${LANGCHAIN_CONTAINER_NAME}/g" langchain/docker/docker-compose.yml
    sed -i "s/image: intel\/gen-ai-examples:qna-rag-redis-server/image: intel\/gen-ai-examples:${LANGCHAIN_CONTAINER_NAME}/g" langchain/docker/docker-compose.yml
    sed -i "s/ChatQnA_server/${AUDIOQNA_CONTAINER_NAME}/g" serving/tgi_gaudi/launch_tgi_service.sh
}

function launch_tgi_gaudi_service() {
    local card_num=1
    local port=8888
    local model_name="Intel/neural-chat-7b-v3-3"

    cd ${WORKPATH}

    # Reset the tgi port
    sed -i "s/8080/$port/g" langchain/redis/rag_redis/config.py
    sed -i "s/8080/$port/g" langchain/docker/qna-app/app/server.py
    sed -i "s/8080/$port/g" langchain/docker/qna-app/Dockerfile

    docker pull ghcr.io/huggingface/tgi-gaudi:1.2.1
    bash serving/tgi_gaudi/launch_tgi_service.sh $card_num $port $model_name
    sleep 3m # Waits 3 minutes
}

function launch_redis_and_langchain_service() {
    cd $WORKPATH
    export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
    local port=8890
    sed -i "s/port=8000/port=$port/g" langchain/docker/qna-app/app/server.py
    docker compose -f langchain/docker/docker-compose.yml up -d --build

    # Ingest data into redis
    docker exec $LANGCHAIN_CONTAINER_NAME \
        bash -c "cd /ws && python ingest.py > /dev/null"
}

function start_backend_service() {
    cd $WORKPATH
    docker exec $LANGCHAIN_CONTAINER_NAME \
        bash -c "nohup python app/server.py &"
    sleep 1m
}

function run_tests() {
    cd $WORKPATH
    local port=8890
    curl 127.0.0.1:$port/v1/rag/chat \
        -X POST \
        -d "{\"query\":\"What is the total revenue of Nike in 2023?\"}" \
        -H 'Content-Type: application/json' > $LOG_PATH
}

function check_response() {
    cd $WORKPATH
    echo "Checking response"
    local status=false
    if [[ -f $LOG_PATH ]] && [[ $(grep -c "\$51.2 billion" $LOG_PATH) != 0 ]]; then
        status=true
    fi

    if [ $status == false ]; then
        echo "Response check failed"
        exit 1
    else
        echo "Response check succeeded"
    fi
}

function docker_stop() {
    local container_name=$1
    cid=$(docker ps -aq --filter "name=$container_name")
    if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid; fi
}

function main() {
    test_env_setup
    rename
    docker_stop $AUDIOQNA_CONTAINER_NAME && docker_stop $LANGCHAIN_CONTAINER_NAME && docker_stop $REDIS_CONTAINER_NAME && sleep 5s

    launch_tgi_gaudi_service
    launch_redis_and_langchain_service
    start_backend_service

    run_tests

    docker_stop $AUDIOQNA_CONTAINER_NAME && docker_stop $LANGCHAIN_CONTAINER_NAME && docker_stop $REDIS_CONTAINER_NAME && sleep 5s
    echo y | docker system prune

    check_response
}

main

@@ -1,84 +0,0 @@

#!/bin/bash
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

set -xe

function test_env_setup() {
    WORKPATH=$(dirname "$PWD")/audio/docker
    OUTPUT_PATH=$(dirname "$PWD")/tests/output.wav
    TTS_CONTAINER_NAME="test-audioqna-tts"
    cd $WORKPATH
}

function start_tts_service() {
    cd $WORKPATH
    rm -rf pretrained_tts_models
    git clone https://huggingface.co/lj1995/GPT-SoVITS pretrained_tts_models
    docker build . --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} -f Dockerfile_tts -t intel/gen-ai-examples:$TTS_CONTAINER_NAME
    docker run -d --name=$TTS_CONTAINER_NAME -v ./pretrained_tts_models:/GPT-SoVITS/GPT_SoVITS/pretrained_models -e http_proxy=${http_proxy} -e https_proxy=${https_proxy} -p 9888:9880 intel/gen-ai-examples:$TTS_CONTAINER_NAME --bf16
    sleep 1m
}

function run_tests() {
    cd $WORKPATH
    rm -f ${OUTPUT_PATH}
    rm -f sample.wav

    # Upload reference audio as default voice
    wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav
    curl --location 'localhost:9888/upload_as_default' \
        --form 'default_refer_file=@"sample.wav"' \
        --form 'default_refer_text="Who is Pat Gelsinger?"' \
        --form 'default_refer_language="en"'

    # Do text to speech conversion
    curl --location 'localhost:9888/v1/audio/speech' \
        --header 'Content-Type: application/json' \
        --data '{
            "text": "You can have a look, but you should not touch this item.",
            "text_language": "en"
        }' \
        --output ${OUTPUT_PATH}
    rm -f sample.wav
}

function check_response() {
    cd $WORKPATH
    echo "Checking response"
    local status=false

    if [[ -f $OUTPUT_PATH ]]; then
        status=true
    fi

    if [ $status == false ]; then
        echo "Response check failed"
        exit 1
    else
        echo "Response check succeeded"
    fi

    # clear resources
    rm -f ${OUTPUT_PATH}
}

function docker_stop() {
    local container_name=$1
    cid=$(docker ps -aq --filter "name=$container_name")
    if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid; fi
}

function main() {
    test_env_setup
    docker_stop $TTS_CONTAINER_NAME && sleep 5s

    start_tts_service
    run_tests
    check_response

    docker_stop $TTS_CONTAINER_NAME && sleep 5s
    echo y | docker system prune
}

main