Add opensearch integration for OPEA (#1024)

* Add opensearch integration for OPEA

Signed-off-by: Cameron Morin <cammorin@amazon.com>

* Update docker compose yaml workflows files

Signed-off-by: Cameron Morin <cammorin@amazon.com>

* Fix empty files

Signed-off-by: Cameron Morin <cammorin@amazon.com>

* Address PR comments

Signed-off-by: Cameron Morin <cammorin@amazon.com>

---------

Signed-off-by: Cameron Morin <cammorin@amazon.com>
Cameron Morin
2024-12-25 19:09:59 -08:00
committed by GitHub
parent 45d0002057
commit 8d6b4b0ac7
22 changed files with 2002 additions and 0 deletions


@@ -0,0 +1,253 @@
# Dataprep Microservice with OpenSearch
For the dataprep microservice for text input, we provide a `Langchain`-based implementation here.
## 🚀1. Start Microservice with Python (Option 1)
### 1.1 Install Requirements
- Option 1: Install the single-process version (for processing up to 10 files)
```bash
apt update
apt install default-jre tesseract-ocr libtesseract-dev poppler-utils -y
# for langchain
cd langchain
pip install -r requirements.txt
```
### 1.2 Start OpenSearch Stack Server
Please refer to this [readme](../../vectorstores/opensearch/README.md).
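Alternatively, for a quick local test, a minimal single-node OpenSearch container can be started directly (a sketch assuming `OPENSEARCH_INITIAL_ADMIN_PASSWORD` is already exported; the linked readme covers the full setup):
```bash
# Single-node OpenSearch for local testing; the demo security config requires
# a custom admin password for OpenSearch 2.12 and later.
docker run -d --name opensearch-vector-db \
  -p 9200:9200 -p 9600:9600 \
  -e "discovery.type=single-node" \
  -e "OPENSEARCH_INITIAL_ADMIN_PASSWORD=$OPENSEARCH_INITIAL_ADMIN_PASSWORD" \
  opensearchproject/opensearch:latest
```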
### 1.3 Setup Environment Variables
```bash
export your_ip=$(hostname -I | awk '{print $1}')
export OPENSEARCH_URL="http://${your_ip}:9200"
export INDEX_NAME=${your_index_name}
export OPENSEARCH_INITIAL_ADMIN_PASSWORD=${your_opensearch_initial_admin_password}
export PYTHONPATH=${path_to_comps}
```
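Note: `PYTHONPATH` should point at the root of the GenAIComps repository (the directory that contains `comps/`) so that the `comps` package can be imported.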
### 1.4 Start Embedding Service
First, you need to start a TEI service.
```bash
your_port=6006
model="BAAI/bge-base-en-v1.5"
docker run -p $your_port:80 -v ./data:/data --name tei_server -e http_proxy=$http_proxy -e https_proxy=$https_proxy --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 --model-id $model
```
Then test your TEI service with the following command:
```bash
curl localhost:$your_port/embed \
-X POST \
-d '{"inputs":"What is Deep Learning?"}' \
-H 'Content-Type: application/json'
```
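A successful response is a JSON array containing one embedding vector (a list of floats) per input.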
After checking that it works, set up environment variables.
```bash
export TEI_ENDPOINT="http://localhost:$your_port"
```
### 1.5 Start Document Preparation Microservice for OpenSearch with Python Script
Start the document preparation microservice for OpenSearch with the command below.
- Option 1: Start the single-process version (for processing up to 10 files)
```bash
cd langchain
python prepare_doc_opensearch.py
```
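The service listens on port 6007 and registers the `/v1/dataprep`, `/v1/dataprep/get_file`, and `/v1/dataprep/delete_file` endpoints described below.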
## 🚀2. Start Microservice with Docker (Option 2)
### 2.1 Start OpenSearch Stack Server
Please refer to this [readme](../../vectorstores/opensearch/README.md).
### 2.2 Setup Environment Variables
```bash
export EMBEDDING_MODEL_ID="BAAI/bge-base-en-v1.5"
export TEI_ENDPOINT="http://${your_ip}:6006"
export OPENSEARCH_URL="http://${your_ip}:9200"
export INDEX_NAME=${your_index_name}
export HUGGINGFACEHUB_API_TOKEN=${your_hf_api_token}
export OPENSEARCH_INITIAL_ADMIN_PASSWORD=${your_opensearch_initial_admin_password}
```
### 2.3 Build Docker Image
- Build the docker image with langchain
- Option 1: Start the single-process version (for processing up to 10 files)
```bash
cd ../../
docker build -t opea/dataprep-opensearch:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/dataprep/opensearch/langchain/Dockerfile .
```
### 2.4 Run Docker with CLI (Option A)
- Option 1: Start the single-process version (for processing up to 10 files)
```bash
docker run -d --name="dataprep-opensearch-server" -p 6007:6007 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e OPENSEARCH_URL=$OPENSEARCH_URL -e INDEX_NAME=$INDEX_NAME -e TEI_ENDPOINT=$TEI_ENDPOINT -e HUGGINGFACEHUB_API_TOKEN=$HUGGINGFACEHUB_API_TOKEN opea/dataprep-opensearch:latest
```
### 2.5 Run with Docker Compose (Option B - deprecated, will be moved to GenAIExamples in the future)
```bash
# for langchain
cd comps/dataprep/opensearch/langchain
# common command
docker compose -f docker-compose-dataprep-opensearch.yaml up -d
```
## 🚀3. Check Microservice Status
```bash
docker container logs -f dataprep-opensearch-server
```
## 🚀4. Consume Microservice
### 4.1 Consume Upload API
Once the document preparation microservice for OpenSearch is started, you can use the commands below to invoke the microservice, which converts documents to embeddings and saves them to the database.
Make sure the file path after `files=@` is correct.
- Single file upload
```bash
curl -X POST \
-H "Content-Type: multipart/form-data" \
-F "files=@./file1.txt" \
http://localhost:6007/v1/dataprep
```
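On success, the service responds with:
```json
{ "status": 200, "message": "Data preparation succeeded" }
```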
You can specify `chunk_size` and `chunk_overlap` with the following command.
```bash
curl -X POST \
-H "Content-Type: multipart/form-data" \
-F "files=@./file1.txt" \
-F "chunk_size=1500" \
-F "chunk_overlap=100" \
http://localhost:6007/v1/dataprep
```
We support table extraction from PDF documents. You can specify `process_table` and `table_strategy` with the following command. `table_strategy` selects the strategy used to understand tables for table retrieval: as the setting progresses from `fast` to `hq` to `llm`, the focus shifts toward deeper table understanding at the expense of processing speed. The default strategy is `fast`.
Note: If you specify `table_strategy=llm`, you should first start a TGI service (refer to sections 1.2.1 and 1.3.1 in https://github.com/opea-project/GenAIComps/tree/main/comps/llms/README.md) and then `export TGI_LLM_ENDPOINT="http://${your_ip}:8008"`.
```bash
curl -X POST \
-H "Content-Type: multipart/form-data" \
-F "files=@./your_file.pdf" \
-F "process_table=true" \
-F "table_strategy=hq" \
http://localhost:6007/v1/dataprep
```
- Multiple file upload
```bash
curl -X POST \
-H "Content-Type: multipart/form-data" \
-F "files=@./file1.txt" \
-F "files=@./file2.txt" \
-F "files=@./file3.txt" \
http://localhost:6007/v1/dataprep
```
- Links upload (not currently supported for llama_index)
```bash
curl -X POST \
-F 'link_list=["https://www.ces.tech/"]' \
http://localhost:6007/v1/dataprep
```
or
```python
import requests
import json
proxies = {"http": ""}
url = "http://localhost:6007/v1/dataprep"
urls = [
"https://towardsdatascience.com/no-gpu-no-party-fine-tune-bert-for-sentiment-analysis-with-vertex-ai-custom-jobs-d8fc410e908b?source=rss----7f60cf5620c9---4"
]
payload = {"link_list": json.dumps(urls)}
try:
resp = requests.post(url=url, data=payload, proxies=proxies)
print(resp.text)
resp.raise_for_status() # Raise an exception for unsuccessful HTTP status codes
print("Request successful!")
except requests.exceptions.RequestException as e:
print("An error occurred:", e)
```
### 4.2 Consume get_file API
To get uploaded file structures, use the following command:
```bash
curl -X POST \
-H "Content-Type: application/json" \
http://localhost:6007/v1/dataprep/get_file
```
Then you will get the response JSON like this:
```json
[
{
"name": "uploaded_file_1.txt",
"id": "uploaded_file_1.txt",
"type": "File",
"parent": ""
},
{
"name": "uploaded_file_2.txt",
"id": "uploaded_file_2.txt",
"type": "File",
"parent": ""
}
]
```
### 4.3 Consume delete_file API
To delete an uploaded file or link, use the following command.
The `file_path` here should be the `id` returned by the `/v1/dataprep/get_file` API.
```bash
# delete link
curl -X POST \
-H "Content-Type: application/json" \
-d '{"file_path": "https://www.ces.tech/.txt"}' \
http://localhost:6007/v1/dataprep/delete_file
# delete file
curl -X POST \
-H "Content-Type: application/json" \
-d '{"file_path": "uploaded_file_1.txt"}' \
http://localhost:6007/v1/dataprep/delete_file
# delete all files and links
curl -X POST \
-H "Content-Type: application/json" \
-d '{"file_path": "all"}' \
http://localhost:6007/v1/dataprep/delete_file
```
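Note: in this version of the service only `{"file_path": "all"}` is implemented; deleting a single file or link returns HTTP 404 (`Single file deletion is not implemented yet`), and a successful delete-all request returns `{"status": true}`.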


@@ -0,0 +1,42 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
FROM python:3.11-slim
ENV LANG=C.UTF-8
ARG ARCH="cpu"
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
build-essential \
default-jre \
libgl1-mesa-glx \
libjemalloc-dev \
libreoffice \
poppler-utils \
tesseract-ocr
RUN useradd -m -s /bin/bash user && \
mkdir -p /home/user && \
chown -R user /home/user/
USER user
COPY comps /home/user/comps
RUN pip install --no-cache-dir --upgrade pip setuptools && \
if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \
pip install --no-cache-dir -r /home/user/comps/dataprep/opensearch/langchain/requirements.txt
ENV PYTHONPATH=$PYTHONPATH:/home/user
USER root
RUN mkdir -p /home/user/comps/dataprep/opensearch/langchain/uploaded_files && chown -R user /home/user/comps/dataprep/opensearch/langchain/uploaded_files
USER user
WORKDIR /home/user/comps/dataprep/opensearch/langchain
ENTRYPOINT ["python", "prepare_doc_opensearch.py"]


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,60 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
# Embedding model
EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-base-en-v1.5")
# OpenSearch Connection Information
OPENSEARCH_HOST = os.getenv("OPENSEARCH_HOST", "localhost")
OPENSEARCH_PORT = int(os.getenv("OPENSEARCH_PORT", 9200))
OPENSEARCH_INITIAL_ADMIN_PASSWORD = os.getenv("OPENSEARCH_INITIAL_ADMIN_PASSWORD", "")
def get_boolean_env_var(var_name, default_value=False):
"""Retrieve the boolean value of an environment variable.
Args:
var_name (str): The name of the environment variable to retrieve.
default_value (bool): The default value to return if the variable
is not found.
Returns:
bool: The value of the environment variable, interpreted as a boolean.
"""
true_values = {"true", "1", "t", "y", "yes"}
false_values = {"false", "0", "f", "n", "no"}
# Retrieve the environment variable's value
value = os.getenv(var_name, "").lower()
# Decide the boolean value based on the content of the string
if value in true_values:
return True
elif value in false_values:
return False
else:
return default_value
def format_opensearch_conn_from_env():
opensearch_url = os.getenv("OPENSEARCH_URL", None)
if opensearch_url:
return opensearch_url
else:
using_ssl = get_boolean_env_var("OPENSEARCH_SSL", False)
start = "https://" if using_ssl else "http://"
return start + f"{OPENSEARCH_HOST}:{OPENSEARCH_PORT}"
OPENSEARCH_URL = format_opensearch_conn_from_env()
# Vector Index Configuration
INDEX_NAME = os.getenv("INDEX_NAME", "rag-opensearch")
KEY_INDEX_NAME = os.getenv("KEY_INDEX_NAME", "file-keys")
TIMEOUT_SECONDS = int(os.getenv("TIMEOUT_SECONDS", 600))
SEARCH_BATCH_SIZE = int(os.getenv("SEARCH_BATCH_SIZE", 10))


@@ -0,0 +1,65 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
version: "3"
services:
opensearch-vector-db:
image: opensearchproject/opensearch:latest
container_name: opensearch-vector-db
environment:
- cluster.name=opensearch-cluster
- node.name=opensearch-vector-db
- discovery.seed_hosts=opensearch-vector-db
- cluster.initial_master_nodes=opensearch-vector-db
- bootstrap.memory_lock=true # along with the memlock settings below, disables swapping
- "OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m" # minimum and maximum Java heap size, recommend setting both to 50% of system RAM
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_INITIAL_ADMIN_PASSWORD} # Sets the demo admin user password when using demo configuration, required for OpenSearch 2.12 and later
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536 # maximum number of open files for the OpenSearch user, set to at least 65536 on modern systems
hard: 65536
ports:
- 9200:9200
- 9600:9600 # required for Performance Analyzer
networks:
- opensearch-net
security_opt:
- no-new-privileges:true
tei-embedding-service:
image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.5
container_name: tei-embedding-server
ports:
- "6060:80"
volumes:
- "./data:/data"
shm_size: 1g
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
command: --model-id ${EMBEDDING_MODEL_ID} --auto-truncate
dataprep-opensearch:
image: opea/dataprep-opensearch:latest
container_name: dataprep-opensearch-server
ports:
- 6007:6007
ipc: host
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
OPENSEARCH_URL: ${OPENSEARCH_URL}
INDEX_NAME: ${INDEX_NAME}
TEI_ENDPOINT: ${TEI_ENDPOINT}
      HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
      OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD}
restart: unless-stopped
security_opt:
- no-new-privileges:true
networks:
default:
driver: bridge
opensearch-net:


@@ -0,0 +1,471 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import json
import os
from pathlib import Path
from typing import List, Optional, Union
from config import (
EMBED_MODEL,
INDEX_NAME,
KEY_INDEX_NAME,
OPENSEARCH_INITIAL_ADMIN_PASSWORD,
OPENSEARCH_URL,
SEARCH_BATCH_SIZE,
)
from fastapi import Body, File, Form, HTTPException, UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_text_splitters import HTMLHeaderTextSplitter
# from pyspark import SparkConf, SparkContext
from opensearchpy import OpenSearch, helpers
from comps import CustomLogger, DocPath, opea_microservices, register_microservice
from comps.dataprep.utils import (
create_upload_folder,
document_loader,
encode_filename,
format_search_results,
get_separators,
get_tables_result,
parse_html,
remove_folder_with_ignore,
save_content_to_local_disk,
)
logger = CustomLogger("prepare_doc_opensearch")
logflag = os.getenv("LOGFLAG", False)
upload_folder = "./uploaded_files/"
tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")
if tei_embedding_endpoint:
# create embeddings using TEI endpoint service
embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint)
else:
# create embeddings using local embedding model
embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
auth = ("admin", OPENSEARCH_INITIAL_ADMIN_PASSWORD)
opensearch_client = OpenSearchVectorSearch(
opensearch_url=OPENSEARCH_URL,
index_name=INDEX_NAME,
embedding_function=embeddings,
http_auth=auth,
use_ssl=True,
verify_certs=False,
ssl_assert_hostname=False,
ssl_show_warn=False,
)
def check_index_existence(client, index_name):
if logflag:
logger.info(f"[ check index existence ] checking {client}")
try:
exists = client.index_exists(index_name)
exists = False if exists is None else exists
if exists:
if logflag:
logger.info(f"[ check index existence ] index of client exists: {client}")
else:
if logflag:
logger.info("[ check index existence ] index does not exist")
return exists
except Exception as e:
if logflag:
logger.info(f"[ check index existence ] error checking index for client: {e}")
return False
def create_index(client, index_name: str = KEY_INDEX_NAME):
if logflag:
logger.info(f"[ create index ] creating index {index_name}")
try:
index_body = {
"mappings": {
"properties": {
"file_name": {"type": "text"},
"key_ids": {"type": "text"},
}
}
}
# Create the index
client.client.indices.create(index_name, body=index_body)
if logflag:
logger.info(f"[ create index ] index {index_name} successfully created")
return True
except Exception as e:
if logflag:
logger.info(f"[ create index ] fail to create index {index_name}: {e}")
return False
def store_by_id(client, key, value):
if logflag:
logger.info(f"[ store by id ] storing ids of {key}")
try:
client.client.index(
index=KEY_INDEX_NAME, body={"file_name": f"file:${key}", "key_ids:": value}, id="file:" + key, refresh=True
)
if logflag:
logger.info(f"[ store by id ] store document success. id: file:{key}")
except Exception as e:
if logflag:
logger.info(f"[ store by id ] fail to store document file:{key}: {e}")
return False
return True
def search_by_id(client, doc_id):
if logflag:
logger.info(f"[ search by id ] searching docs of {doc_id}")
try:
result = client.client.get(index=KEY_INDEX_NAME, id=doc_id)
if result["found"]:
if logflag:
logger.info(f"[ search by id ] search success of {doc_id}: {result}")
return result
return None
except Exception as e:
if logflag:
logger.info(f"[ search by id ] fail to search docs of {doc_id}: {e}")
return None
def drop_index(client, index_name):
if logflag:
logger.info(f"[ drop index ] dropping index {index_name}")
try:
client.client.indices.delete(index=index_name)
if logflag:
logger.info(f"[ drop index ] index {index_name} deleted")
except Exception as e:
if logflag:
logger.info(f"[ drop index ] index {index_name} delete failed: {e}")
return False
return True
def delete_by_id(client, doc_id):
try:
response = client.client.delete(index=KEY_INDEX_NAME, id=doc_id)
if response["result"] == "deleted":
if logflag:
logger.info(f"[ delete by id ] delete id success: {doc_id}")
return True
else:
if logflag:
logger.info(f"[ delete by id ] delete id failed: {doc_id}")
return False
except Exception as e:
if logflag:
logger.info(f"[ delete by id ] fail to delete ids {doc_id}: {e}")
return False
def ingest_chunks_to_opensearch(file_name: str, chunks: List):
if logflag:
logger.info(f"[ ingest chunks ] file name: {file_name}")
# Batch size
batch_size = 32
num_chunks = len(chunks)
file_ids = []
for i in range(0, num_chunks, batch_size):
if logflag:
logger.info(f"[ ingest chunks ] Current batch: {i}")
batch_chunks = chunks[i : i + batch_size]
keys = opensearch_client.add_texts(texts=batch_chunks, metadatas=[{"source": file_name} for _ in batch_chunks])
if logflag:
logger.info(f"[ ingest chunks ] keys: {keys}")
file_ids.extend(keys)
if logflag:
logger.info(f"[ ingest chunks ] Processed batch {i//batch_size + 1}/{(num_chunks-1)//batch_size + 1}")
# store file_ids into index file-keys
if not check_index_existence(opensearch_client, KEY_INDEX_NAME):
assert create_index(opensearch_client)
try:
assert store_by_id(opensearch_client, key=file_name, value="#".join(file_ids))
except Exception as e:
if logflag:
logger.info(f"[ ingest chunks ] {e}. Fail to store chunks of file {file_name}.")
raise HTTPException(status_code=500, detail=f"Fail to store chunks of file {file_name}.")
return True
def ingest_data_to_opensearch(doc_path: DocPath):
"""Ingest document to OpenSearch."""
path = doc_path.path
if logflag:
logger.info(f"[ ingest data ] Parsing document {path}.")
if path.endswith(".html"):
headers_to_split_on = [
("h1", "Header 1"),
("h2", "Header 2"),
("h3", "Header 3"),
]
text_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
else:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=doc_path.chunk_size,
chunk_overlap=doc_path.chunk_overlap,
add_start_index=True,
separators=get_separators(),
)
content = document_loader(path)
if logflag:
logger.info("[ ingest data ] file content loaded")
structured_types = [".xlsx", ".csv", ".json", "jsonl"]
_, ext = os.path.splitext(path)
if ext in structured_types:
chunks = content
else:
chunks = text_splitter.split_text(content)
### Specially processing for the table content in PDFs
if doc_path.process_table and path.endswith(".pdf"):
table_chunks = get_tables_result(path, doc_path.table_strategy)
chunks = chunks + table_chunks
if logflag:
logger.info(f"[ ingest data ] Done preprocessing. Created {len(chunks)} chunks of the given file.")
file_name = doc_path.path.split("/")[-1]
return ingest_chunks_to_opensearch(file_name, chunks)
def search_all_documents(index_name, offset, search_batch_size):
try:
response = opensearch_client.client.search(
index=index_name,
body={
"query": {"match_all": {}},
"from": offset, # Starting position
"size": search_batch_size, # Number of results to return
},
)
# Get total number of matching documents
total_hits = response["hits"]["total"]["value"]
# Get the documents from the current batch
documents = response["hits"]["hits"]
return {"total_hits": total_hits, "documents": documents}
except Exception as e:
print(f"Error performing search: {e}")
return None
@register_microservice(name="opea_service@prepare_doc_opensearch", endpoint="/v1/dataprep", host="0.0.0.0", port=6007)
async def ingest_documents(
files: Optional[Union[UploadFile, List[UploadFile]]] = File(None),
link_list: Optional[str] = Form(None),
chunk_size: int = Form(1500),
chunk_overlap: int = Form(100),
process_table: bool = Form(False),
table_strategy: str = Form("fast"),
):
if logflag:
logger.info(f"[ upload ] files:{files}")
logger.info(f"[ upload ] link_list:{link_list}")
if files:
if not isinstance(files, list):
files = [files]
uploaded_files = []
for file in files:
encode_file = encode_filename(file.filename)
doc_id = "file:" + encode_file
if logflag:
logger.info(f"[ upload ] processing file {doc_id}")
# check whether the file already exists
key_ids = None
try:
document = search_by_id(opensearch_client, doc_id)
if document:
if logflag:
logger.info(f"[ upload ] File {file.filename} already exists.")
key_ids = document["_id"]
except Exception as e:
logger.info(f"[ upload ] File {file.filename} does not exist.")
if key_ids:
raise HTTPException(
status_code=400, detail=f"Uploaded file {file.filename} already exists. Please change file name."
)
save_path = upload_folder + encode_file
await save_content_to_local_disk(save_path, file)
ingest_data_to_opensearch(
DocPath(
path=save_path,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
process_table=process_table,
table_strategy=table_strategy,
)
)
uploaded_files.append(save_path)
if logflag:
logger.info(f"[ upload ] Successfully saved file {save_path}")
result = {"status": 200, "message": "Data preparation succeeded"}
if logflag:
logger.info(result)
return result
if link_list:
link_list = json.loads(link_list) # Parse JSON string to list
if not isinstance(link_list, list):
raise HTTPException(status_code=400, detail=f"Link_list {link_list} should be a list.")
for link in link_list:
encoded_link = encode_filename(link)
doc_id = "file:" + encoded_link + ".txt"
if logflag:
logger.info(f"[ upload ] processing link {doc_id}")
# check whether the link file already exists
key_ids = None
try:
document = search_by_id(opensearch_client, doc_id)
if document:
if logflag:
logger.info(f"[ upload ] Link {link} already exists.")
key_ids = document["_id"]
except Exception as e:
logger.info(f"[ upload ] Link {link} does not exist. Keep storing.")
if key_ids:
raise HTTPException(
status_code=400, detail=f"Uploaded link {link} already exists. Please change another link."
)
save_path = upload_folder + encoded_link + ".txt"
content = parse_html([link])[0][0]
await save_content_to_local_disk(save_path, content)
ingest_data_to_opensearch(
DocPath(
path=save_path,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
process_table=process_table,
table_strategy=table_strategy,
)
)
if logflag:
logger.info(f"[ upload ] Successfully saved link list {link_list}")
return {"status": 200, "message": "Data preparation succeeded"}
raise HTTPException(status_code=400, detail="Must provide either a file or a string list.")
@register_microservice(
name="opea_service@prepare_doc_opensearch", endpoint="/v1/dataprep/get_file", host="0.0.0.0", port=6007
)
async def rag_get_file_structure():
if logflag:
logger.info("[ get ] start to get file structure")
offset = 0
file_list = []
# check index existence
res = check_index_existence(opensearch_client, KEY_INDEX_NAME)
if not res:
if logflag:
logger.info(f"[ get ] index {KEY_INDEX_NAME} does not exist")
return file_list
    while True:
        response = search_all_documents(KEY_INDEX_NAME, offset, SEARCH_BATCH_SIZE)
        # stop when the search failed or returned no documents
        if response is None or not response["documents"]:
            break

        def format_opensearch_results(response, file_list):
            for document in response["documents"]:
                file_id = document["_id"]
                file_list.append({"name": file_id, "id": file_id, "type": "File", "parent": ""})
            return file_list

        file_list = format_opensearch_results(response, file_list)
        offset += SEARCH_BATCH_SIZE
        # last batch: fewer documents than the batch size means no more results
        if len(response["documents"]) < SEARCH_BATCH_SIZE:
            break
if logflag:
logger.info(f"[get] final file_list: {file_list}")
return file_list
@register_microservice(
name="opea_service@prepare_doc_opensearch", endpoint="/v1/dataprep/delete_file", host="0.0.0.0", port=6007
)
async def delete_single_file(file_path: str = Body(..., embed=True)):
"""Delete file according to `file_path`.
`file_path`:
- specific file path (e.g. /path/to/file.txt)
- "all": delete all files uploaded
"""
# delete all uploaded files
if file_path == "all":
if logflag:
logger.info("[ delete ] delete all files")
# drop index KEY_INDEX_NAME
if check_index_existence(opensearch_client, KEY_INDEX_NAME):
try:
                assert drop_index(opensearch_client, index_name=KEY_INDEX_NAME)
except Exception as e:
if logflag:
logger.info(f"[ delete ] {e}. Fail to drop index {KEY_INDEX_NAME}.")
raise HTTPException(status_code=500, detail=f"Fail to drop index {KEY_INDEX_NAME}.")
        else:
            if logflag:
                logger.info(f"[ delete ] Index {KEY_INDEX_NAME} does not exist.")
# drop index INDEX_NAME
if check_index_existence(opensearch_client, INDEX_NAME):
try:
                assert drop_index(opensearch_client, index_name=INDEX_NAME)
except Exception as e:
if logflag:
logger.info(f"[ delete ] {e}. Fail to drop index {INDEX_NAME}.")
raise HTTPException(status_code=500, detail=f"Fail to drop index {INDEX_NAME}.")
else:
if logflag:
logger.info(f"[ delete ] Index {INDEX_NAME} does not exits.")
# delete files on local disk
try:
remove_folder_with_ignore(upload_folder)
except Exception as e:
if logflag:
logger.info(f"[ delete ] {e}. Fail to delete {upload_folder}.")
raise HTTPException(status_code=500, detail=f"Fail to delete {upload_folder}.")
if logflag:
logger.info("[ delete ] successfully delete all files.")
create_upload_folder(upload_folder)
if logflag:
logger.info({"status": True})
return {"status": True}
else:
raise HTTPException(status_code=404, detail="Single file deletion is not implemented yet")
if __name__ == "__main__":
create_upload_folder(upload_folder)
opea_microservices["opea_service@prepare_doc_opensearch"].start()


@@ -0,0 +1,30 @@
beautifulsoup4
cairosvg
docarray[full]
docx2txt
easyocr
fastapi
huggingface_hub
langchain
langchain-community
langchain-text-splitters
langchain_huggingface
markdown
numpy
opensearch-py
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
pandas
Pillow
prometheus-fastapi-instrumentator
pymupdf
pyspark
pytesseract
python-bidi
python-docx
python-pptx
sentence_transformers
shortuuid
unstructured[all-docs]
uvicorn


@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
FROM python:3.11-slim
ARG ARCH="cpu"
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
libgl1-mesa-glx \
libjemalloc-dev
RUN useradd -m -s /bin/bash user && \
mkdir -p /home/user && \
chown -R user /home/user/
COPY comps /home/user/comps
USER user
RUN pip install --no-cache-dir --upgrade pip && \
if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \
pip install --no-cache-dir -r /home/user/comps/retrievers/opensearch/langchain/requirements.txt
ENV PYTHONPATH=$PYTHONPATH:/home/user
WORKDIR /home/user/comps/retrievers/opensearch/langchain
ENTRYPOINT ["python", "retriever_opensearch.py"]


@@ -0,0 +1,144 @@
# Retriever Microservice
This retriever microservice is a highly efficient search service designed for handling and retrieving embedding vectors. It operates by receiving an embedding vector as input and conducting a similarity search against vectors stored in a VectorDB database. Users must specify the VectorDB's URL and the index name, and the service searches within that index to find documents with the highest similarity to the input vector.
The service primarily utilizes similarity measures in vector space to rapidly retrieve contextually similar documents. The vector-based retrieval approach is particularly suited for handling large datasets, offering fast and accurate search results that significantly enhance the efficiency and quality of information retrieval.
Overall, this microservice provides robust backend support for applications requiring efficient similarity searches, playing a vital role in scenarios such as recommendation systems, information retrieval, or any other context where precise measurement of document similarity is crucial.
## 🚀1. Start Microservice with Python (Option 1)
To start the retriever microservice, you must first install the required python packages.
### 1.1 Install Requirements
```bash
pip install -r requirements.txt
```
### 1.2 Start TEI Service
```bash
model=BAAI/bge-base-en-v1.5
volume=$PWD/data
docker run -d -p 6060:80 -v $volume:/data -e http_proxy=$http_proxy -e https_proxy=$https_proxy --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 --model-id $model
```
### 1.3 Verify the TEI Service
Health check the embedding service with:
```bash
curl 127.0.0.1:6060/embed \
-X POST \
-d '{"inputs":"What is Deep Learning?"}' \
-H 'Content-Type: application/json'
```
### 1.4 Setup VectorDB Service
You need to set up your own VectorDB service (OpenSearch in this example) and ingest your knowledge documents into the vector database.
As for OpenSearch, you can start a docker container by following the instructions in the OpenSearch vectorstores [README.md](../../../vectorstores/opensearch/README.md).
### 1.5 Start Retriever Service
```bash
export TEI_EMBEDDING_ENDPOINT="http://${your_ip}:6060"
python retriever_opensearch.py
```
## 🚀2. Start Microservice with Docker (Option 2)
### 2.1 Setup Environment Variables
```bash
export RETRIEVE_MODEL_ID="BAAI/bge-base-en-v1.5"
export OPENSEARCH_URL="http://${your_ip}:9200"
export INDEX_NAME=${your_index_name}
export TEI_EMBEDDING_ENDPOINT="http://${your_ip}:6060"
export HUGGINGFACEHUB_API_TOKEN=${your_hf_token}
export OPENSEARCH_INITIAL_ADMIN_PASSWORD=${your_opensearch_initial_admin_password}
```
### 2.2 Build Docker Image
```bash
cd ../../../../
docker build -t opea/retriever-opensearch-server:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/opensearch/langchain/Dockerfile .
```
To start a docker container, you have two options:
- A. Run Docker with CLI
- B. Run Docker with Docker Compose
You can choose one as needed.
### 2.3 Run Docker with CLI (Option A)
```bash
docker run -d --name="retriever-opensearch-server" -p 7000:7000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e OPENSEARCH_URL=$OPENSEARCH_URL -e INDEX_NAME=$INDEX_NAME -e TEI_EMBEDDING_ENDPOINT=$TEI_EMBEDDING_ENDPOINT -e HUGGINGFACEHUB_API_TOKEN=$HUGGINGFACEHUB_API_TOKEN opea/retriever-opensearch:latest
```
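The retriever listens on port 7000 and exposes the `/v1/retrieval` endpoint used in the examples below.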
### 2.4 Run Docker with Docker Compose (Option B)
```bash
docker compose -f docker_compose_retriever.yaml up -d
```
## 🚀3. Consume Retriever Service
### 3.1 Check Service Status
```bash
curl http://localhost:7000/v1/health_check \
-X GET \
-H 'Content-Type: application/json'
```
### 3.2 Consume Embedding Service
To consume the Retriever Microservice, you can generate a mock embedding vector of length 768 with Python.
```bash
export your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
curl http://${your_ip}:7000/v1/retrieval \
-X POST \
-d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding}}" \
-H 'Content-Type: application/json'
```
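The response is a `SearchedDoc` JSON object whose `retrieved_docs` field holds the matched text chunks and whose `initial_query` echoes the query text.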
You can set the parameters for the retriever. The supported `search_type` values are `similarity` (the default), `similarity_distance_threshold`, `similarity_score_threshold`, and `mmr`, as shown in the examples below.
```bash
export your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
curl http://localhost:7000/v1/retrieval \
-X POST \
-d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity\", \"k\":4}" \
-H 'Content-Type: application/json'
```
```bash
export your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
curl http://localhost:7000/v1/retrieval \
-X POST \
-d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity_distance_threshold\", \"k\":4, \"distance_threshold\":1.0}" \
-H 'Content-Type: application/json'
```
```bash
export your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
curl http://localhost:7000/v1/retrieval \
-X POST \
-d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity_score_threshold\", \"k\":4, \"score_threshold\":0.2}" \
-H 'Content-Type: application/json'
```
```bash
export your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
curl http://localhost:7000/v1/retrieval \
-X POST \
-d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"mmr\", \"k\":4, \"fetch_k\":20, \"lambda_mult\":0.5}" \
-H 'Content-Type: application/json'
```
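For the `mmr` (maximal marginal relevance) search type, `fetch_k` sets how many candidate documents are fetched before re-ranking, and `lambda_mult` (between 0 and 1) trades off relevance against diversity in the returned results.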


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,36 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
version: "3.8"
services:
tei_xeon_service:
image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.2
container_name: tei-xeon_server
ports:
- "6060:80"
volumes:
- "./data:/data"
shm_size: 1g
command: --model-id ${RETRIEVE_MODEL_ID}
retriever:
image: opea/retriever-opensearch-server
container_name: retriever-opensearch-server
ports:
- "7000:7000"
ipc: host
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
OPENSEARCH_URL: ${OPENSEARCH_URL}
INDEX_NAME: ${INDEX_NAME}
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
      HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
      OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD}
restart: unless-stopped
security_opt:
- no-new-privileges:true
networks:
default:
driver: bridge


@@ -0,0 +1,70 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
def get_boolean_env_var(var_name, default_value=False):
"""Retrieve the boolean value of an environment variable.
Args:
var_name (str): The name of the environment variable to retrieve.
default_value (bool): The default value to return if the variable
is not found.
Returns:
bool: The value of the environment variable, interpreted as a boolean.
"""
true_values = {"true", "1", "t", "y", "yes"}
false_values = {"false", "0", "f", "n", "no"}
# Retrieve the environment variable's value
value = os.getenv(var_name, "").lower()
# Decide the boolean value based on the content of the string
if value in true_values:
return True
elif value in false_values:
return False
else:
return default_value
# Whether or not to enable langchain debugging
DEBUG = get_boolean_env_var("DEBUG", False)
# Set DEBUG env var to "true" if you wish to enable LC debugging module
if DEBUG:
import langchain
langchain.debug = True
# Embedding model
EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-base-en-v1.5")
# OpenSearch Connection Information
OPENSEARCH_HOST = os.getenv("OPENSEARCH_HOST", "localhost")
OPENSEARCH_PORT = int(os.getenv("OPENSEARCH_PORT", 9200))
OPENSEARCH_INITIAL_ADMIN_PASSWORD = os.getenv("OPENSEARCH_INITIAL_ADMIN_PASSWORD", "")
def format_opensearch_conn_from_env():
opensearch_url = os.getenv("OPENSEARCH_URL", None)
if opensearch_url:
return opensearch_url
else:
using_ssl = get_boolean_env_var("OPENSEARCH_SSL", False)
start = "https://" if using_ssl else "http://"
return start + f"{OPENSEARCH_HOST}:{OPENSEARCH_PORT}"
OPENSEARCH_URL = format_opensearch_conn_from_env()
# Vector Index Configuration
INDEX_NAME = os.getenv("INDEX_NAME", "rag-opensearch")
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)


@@ -0,0 +1,16 @@
docarray[full]
easyocr
fastapi
langchain_community
langchain_huggingface
numpy
opensearch-py
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
prometheus-fastapi-instrumentator
pydantic
pymupdf
sentence_transformers
shortuuid
uvicorn


@@ -0,0 +1,162 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import time
from typing import Callable, List, Union
import numpy as np
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from opensearch_config import EMBED_MODEL, INDEX_NAME, OPENSEARCH_INITIAL_ADMIN_PASSWORD, OPENSEARCH_URL
from pydantic import conlist
from comps import (
CustomLogger,
EmbedDoc,
SearchedDoc,
ServiceType,
TextDoc,
opea_microservices,
register_microservice,
register_statistics,
statistics_dict,
)
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
RetrievalRequest,
RetrievalResponse,
RetrievalResponseData,
)
logger = CustomLogger("retriever_opensearch")
logflag = os.getenv("LOGFLAG", False)
tei_embedding_endpoint = os.getenv("TEI_EMBEDDING_ENDPOINT", None)
async def search_all_embeddings_vectors(
embeddings: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]], func: Callable, *args, **kwargs
):
try:
if not isinstance(embeddings, np.ndarray):
embeddings = np.array(embeddings)
if not np.issubdtype(embeddings.dtype, np.floating):
raise ValueError("All embeddings values must be floating point numbers")
if embeddings.ndim == 1:
return await func(embedding=embeddings, *args, **kwargs)
elif embeddings.ndim == 2:
responses = []
for emb in embeddings:
response = await func(embedding=emb, *args, **kwargs)
responses.extend(response)
return responses
else:
raise ValueError("Embeddings must be one or two dimensional")
except Exception as e:
raise ValueError(f"Embedding data is not valid: {e}")
@register_microservice(
name="opea_service@retriever_opensearch",
service_type=ServiceType.RETRIEVER,
endpoint="/v1/retrieval",
host="0.0.0.0",
port=7000,
)
@register_statistics(names=["opea_service@retriever_opensearch"])
async def retrieve(
input: Union[EmbedDoc, RetrievalRequest, ChatCompletionRequest]
) -> Union[SearchedDoc, RetrievalResponse, ChatCompletionRequest]:
if logflag:
logger.info(input)
start = time.time()
# Check if the index exists and has documents
doc_count = 0
index_exists = vector_db.client.indices.exists(index=INDEX_NAME)
if index_exists:
doc_count = vector_db.client.count(index=INDEX_NAME)["count"]
if (not index_exists) or doc_count == 0:
search_res = []
else:
if isinstance(input, EmbedDoc):
query = input.text
else:
# for RetrievalRequest, ChatCompletionRequest
query = input.input
# if the OpenSearch index has data, perform the search
if input.search_type == "similarity":
search_res = await search_all_embeddings_vectors(
embeddings=input.embedding,
func=vector_db.asimilarity_search_by_vector,
k=input.k,
)
elif input.search_type == "similarity_distance_threshold":
if input.distance_threshold is None:
raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
search_res = await search_all_embeddings_vectors(
embeddings=input.embedding,
func=vector_db.asimilarity_search_by_vector,
k=input.k,
distance_threshold=input.distance_threshold,
)
elif input.search_type == "similarity_score_threshold":
doc_and_similarities = await vector_db.asimilarity_search_with_relevance_scores(
                query=query, k=input.k, score_threshold=input.score_threshold
)
search_res = [doc for doc, _ in doc_and_similarities]
elif input.search_type == "mmr":
search_res = await vector_db.amax_marginal_relevance_search(
                query=query, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
)
else:
raise ValueError(f"{input.search_type} not valid")
# return different response format
retrieved_docs = []
if isinstance(input, EmbedDoc):
for r in search_res:
retrieved_docs.append(TextDoc(text=r.page_content))
result = SearchedDoc(retrieved_docs=retrieved_docs, initial_query=input.text)
else:
for r in search_res:
retrieved_docs.append(RetrievalResponseData(text=r.page_content, metadata=r.metadata))
if isinstance(input, RetrievalRequest):
result = RetrievalResponse(retrieved_docs=retrieved_docs)
elif isinstance(input, ChatCompletionRequest):
input.retrieved_docs = retrieved_docs
input.documents = [doc.text for doc in retrieved_docs]
result = input
statistics_dict["opea_service@retriever_opensearch"].append_latency(time.time() - start, None)
if logflag:
logger.info(result)
return result
if __name__ == "__main__":
# Create vectorstore
if tei_embedding_endpoint:
# create embeddings using TEI endpoint service
embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint)
else:
# create embeddings using local embedding model
embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
auth = ("admin", OPENSEARCH_INITIAL_ADMIN_PASSWORD)
vector_db = OpenSearchVectorSearch(
opensearch_url=OPENSEARCH_URL,
index_name=INDEX_NAME,
embedding_function=embeddings,
http_auth=auth,
use_ssl=True,
verify_certs=False,
ssl_assert_hostname=False,
ssl_show_warn=False,
)
opea_microservices["opea_service@retriever_opensearch"].start()


@@ -0,0 +1,35 @@
# Start OpenSearch Server
## Prerequisites
1. Install Docker
1. Install Docker Compose (if not already installed)
1. `sudo curl -L https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m) -o /usr/local/bin/docker-compose`
2. `sudo chmod +x /usr/local/bin/docker-compose`
## Instructions
### 1. Set admin password as environment variable
OpenSearch versions 2.12 and later require a custom admin password to be set. Following [these guidelines](https://opensearch.org/docs/latest/security/configuration/demo-configuration/#setting-up-a-custom-admin-password), set the admin password as an environment variable in the terminal before starting the docker containers, e.g. `export OPENSEARCH_INITIAL_ADMIN_PASSWORD=_some_admin_password`; the `docker-compose-opensearch.yml` file reads it from the environment.
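For example: `export OPENSEARCH_INITIAL_ADMIN_PASSWORD='StrongOpea0!'` (an illustrative value). Note that the demo security configuration enforces password strength, so a weak password such as `admin` will cause the containers to exit during setup.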
### 2. Start the cluster
`docker-compose -f docker-compose-opensearch.yml up`
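Once the containers are healthy, you can verify that the cluster responds with `curl -k -u admin:$OPENSEARCH_INITIAL_ADMIN_PASSWORD https://localhost:9200` (the demo configuration uses self-signed certificates, hence `-k`).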
## Troubleshooting
### "java.nio.file.FileSystemNotFoundException: null" error
1. Make sure to grant read permissions to your local data volume folders
1. `sudo chown -R instance_user:instance_user ./opensearch-data1`
2. `sudo chown -R instance_user:instance_user ./opensearch-data2`
1. Replace `instance_user` with the login user (e.g. ec2-user, ssm-user, or your local user name)
2. Try increasing the virtual max memory map count
1. `sudo sysctl -w vm.max_map_count=262144`
### OpenSearch Dashboards container errors
1. Make sure to grant read permission to the `opensearch_dashboards.yml` file
1. `sudo chown -R instance_user:instance_user ./opensearch_dashboards.yml`
1. Replace `instance_user` with the login user (e.g. ec2-user, ssm-user, or your local user name)


@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


@@ -0,0 +1,81 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
version: '3'
services:
opensearch-node1:
image: opensearchproject/opensearch:latest
container_name: opensearch-node1
environment:
- cluster.name=opensearch-cluster
- node.name=opensearch-node1
- discovery.seed_hosts=opensearch-node1,opensearch-node2
- cluster.initial_master_nodes=opensearch-node1,opensearch-node2
- bootstrap.memory_lock=true # along with the memlock settings below, disables swapping
- "OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m" # minimum and maximum Java heap size, recommend setting both to 50% of system RAM
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_INITIAL_ADMIN_PASSWORD} # Sets the demo admin user password when using demo configuration, required for OpenSearch 2.12 and later
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536 # maximum number of open files for the OpenSearch user, set to at least 65536 on modern systems
hard: 65536
volumes:
- ./opensearch-data1:/var/lib/opensearch/data
ports:
- 9200:9200
- 9600:9600 # required for Performance Analyzer
networks:
- opensearch-net
security_opt:
- no-new-privileges:true
opensearch-node2:
image: opensearchproject/opensearch:latest
container_name: opensearch-node2
environment:
- cluster.name=opensearch-cluster
- node.name=opensearch-node2
- discovery.seed_hosts=opensearch-node1,opensearch-node2
- cluster.initial_master_nodes=opensearch-node1,opensearch-node2
- bootstrap.memory_lock=true
- "OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m"
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_INITIAL_ADMIN_PASSWORD} # Sets the demo admin user password when using demo configuration, required for OpenSearch 2.12 and later
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
volumes:
- ./opensearch-data2:/var/lib/opensearch/data
networks:
- opensearch-net
security_opt:
- no-new-privileges:true
opensearch-dashboards:
image: opensearchproject/opensearch-dashboards:latest
volumes:
- ./opensearch_dashboards.yml:/usr/share/opensearch-dashboards/config/opensearch_dashboards.yml
container_name: opensearch-dashboards
ports:
- 5601:5601
expose:
- "5601"
environment:
OPENSEARCH_HOSTS: '["https://opensearch-node1:9200","https://opensearch-node2:9200"]' # must be a string with no spaces when specified as an environment variable
networks:
- opensearch-net
security_opt:
- no-new-privileges:true
depends_on:
- opensearch-node1
- opensearch-node2
volumes:
opensearch-data1:
opensearch-data2:
networks:
opensearch-net:


@@ -0,0 +1,210 @@
---
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
# Description:
# Default configuration for OpenSearch Dashboards
# OpenSearch Dashboards is served by a back end server. This setting specifies the port to use.
# server.port: 5601
# Specifies the address to which the OpenSearch Dashboards server will bind. IP addresses and host names are both valid values.
# The default is 'localhost', which usually means remote machines will not be able to connect.
# To allow connections from remote users, set this parameter to a non-loopback address.
# server.host: "localhost"
# Enables you to specify a path to mount OpenSearch Dashboards at if you are running behind a proxy.
# Use the `server.rewriteBasePath` setting to tell OpenSearch Dashboards if it should remove the basePath
# from requests it receives, and to prevent a deprecation warning at startup.
# This setting cannot end in a slash.
# server.basePath: ""
# Specifies whether OpenSearch Dashboards should rewrite requests that are prefixed with
# `server.basePath` or require that they are rewritten by your reverse proxy.
# server.rewriteBasePath: false
# The maximum payload size in bytes for incoming server requests.
# server.maxPayloadBytes: 1048576
# The OpenSearch Dashboards server's name. This is used for display purposes.
# server.name: "your-hostname"
# The URLs of the OpenSearch instances to use for all your queries.
# opensearch.hosts: ["http://localhost:9200"]
# OpenSearch Dashboards uses an index in OpenSearch to store saved searches, visualizations and
# dashboards. OpenSearch Dashboards creates a new index if the index doesn't already exist.
# opensearchDashboards.index: ".opensearch_dashboards"
# The default application to load.
# opensearchDashboards.defaultAppId: "home"
# Setting for an optimized healthcheck that only uses the local OpenSearch node to do Dashboards healthcheck.
# This settings should be used for large clusters or for clusters with ingest heavy nodes.
# It allows Dashboards to only healthcheck using the local OpenSearch node rather than fan out requests across all nodes.
#
# It requires the user to create an OpenSearch node attribute with the same name as the value used in the setting
# This node attribute should assign all nodes of the same cluster an integer value that increments with each new cluster that is spun up
# e.g. in opensearch.yml file you would set the value to a setting using node.attr.cluster_id:
# Should only be enabled if there is a corresponding node attribute created in your OpenSearch config that matches the value here
# opensearch.optimizedHealthcheckId: "cluster_id"
# If your OpenSearch is protected with basic authentication, these settings provide
# the username and password that the OpenSearch Dashboards server uses to perform maintenance on the OpenSearch Dashboards
# index at startup. Your OpenSearch Dashboards users still need to authenticate with OpenSearch, which
# is proxied through the OpenSearch Dashboards server.
# opensearch.username: "opensearch_dashboards_system"
# opensearch.password: "pass"
# Enables SSL and paths to the PEM-format SSL certificate and SSL key files, respectively.
# These settings enable SSL for outgoing requests from the OpenSearch Dashboards server to the browser.
# server.ssl.enabled: false
# server.ssl.certificate: /path/to/your/server.crt
# server.ssl.key: /path/to/your/server.key
# Optional settings that provide the paths to the PEM-format SSL certificate and key files.
# These files are used to verify the identity of OpenSearch Dashboards to OpenSearch and are required when
# xpack.security.http.ssl.client_authentication in OpenSearch is set to required.
# opensearch.ssl.certificate: /path/to/your/client.crt
# opensearch.ssl.key: /path/to/your/client.key
# Optional setting that enables you to specify a path to the PEM file for the certificate
# authority for your OpenSearch instance.
# opensearch.ssl.certificateAuthorities: [ "/path/to/your/CA.pem" ]
# To disregard the validity of SSL certificates, change this setting's value to 'none'.
# opensearch.ssl.verificationMode: full
# Time in milliseconds to wait for OpenSearch to respond to pings. Defaults to the value of
# the opensearch.requestTimeout setting.
# opensearch.pingTimeout: 1500
# Time in milliseconds to wait for responses from the back end or OpenSearch. This value
# must be a positive integer.
# opensearch.requestTimeout: 30000
# List of OpenSearch Dashboards client-side headers to send to OpenSearch. To send *no* client-side
# headers, set this value to [] (an empty list).
# opensearch.requestHeadersWhitelist: [ authorization ]
# Header names and values that are sent to OpenSearch. Any custom headers cannot be overwritten
# by client-side headers, regardless of the opensearch.requestHeadersWhitelist configuration.
# opensearch.customHeaders: {}
# Time in milliseconds for OpenSearch to wait for responses from shards. Set to 0 to disable.
# opensearch.shardTimeout: 30000
# Logs queries sent to OpenSearch. Requires logging.verbose set to true.
# opensearch.logQueries: false
# Specifies the path where OpenSearch Dashboards creates the process ID file.
# pid.file: /var/run/opensearchDashboards.pid
# Enables you to specify a file where OpenSearch Dashboards stores log output.
# logging.dest: stdout
# Set the value of this setting to true to suppress all logging output.
# logging.silent: false
# Set the value of this setting to true to suppress all logging output other than error messages.
# logging.quiet: false
# Set the value of this setting to true to log all events, including system usage information
# and all requests.
# logging.verbose: false
# Set the interval in milliseconds to sample system and process performance
# metrics. Minimum is 100ms. Defaults to 5000.
# ops.interval: 5000
# Specifies locale to be used for all localizable strings, dates and number formats.
# Supported languages are the following: English - en , by default , Chinese - zh-CN .
# i18n.locale: "en"
# Set the allowlist to check input graphite Url. Allowlist is the default check list.
# vis_type_timeline.graphiteAllowedUrls: ['https://www.hostedgraphite.com/UID/ACCESS_KEY/graphite']
# Set the blocklist to check input graphite Url. Blocklist is an IP list.
# Below is an example for reference
# vis_type_timeline.graphiteBlockedIPs: [
# //Loopback
# '127.0.0.0/8',
# '::1/128',
# //Link-local Address for IPv6
# 'fe80::/10',
# //Private IP address for IPv4
# '10.0.0.0/8',
# '172.16.0.0/12',
# '192.168.0.0/16',
# //Unique local address (ULA)
# 'fc00::/7',
# //Reserved IP address
# '0.0.0.0/8',
# '100.64.0.0/10',
# '192.0.0.0/24',
# '192.0.2.0/24',
# '198.18.0.0/15',
# '192.88.99.0/24',
# '198.51.100.0/24',
# '203.0.113.0/24',
# '224.0.0.0/4',
# '240.0.0.0/4',
# '255.255.255.255/32',
# '::/128',
# '2001:db8::/32',
# 'ff00::/8',
# ]
# vis_type_timeline.graphiteBlockedIPs: []
# opensearchDashboards.branding:
# logo:
# defaultUrl: ""
# darkModeUrl: ""
# mark:
# defaultUrl: ""
# darkModeUrl: ""
# loadingLogo:
# defaultUrl: ""
# darkModeUrl: ""
# faviconUrl: ""
# applicationTitle: ""
# Set the value of this setting to true to capture region blocked warnings and errors
# for your map rendering services.
# map.showRegionBlockedWarning: false
# Set the value of this setting to false to suppress search usage telemetry
# for reducing the load of OpenSearch cluster.
# data.search.usageTelemetry.enabled: false
# 2.4 renames 'wizard.enabled: false' to 'vis_builder.enabled: false'
# Set the value of this setting to false to disable VisBuilder
# functionality in Visualization.
# vis_builder.enabled: false
# 2.4 New Experimental Feature
# Set the value of this setting to true to enable the experimental multiple data source
# support feature. Use with caution.
# data_source.enabled: false
# Set the value of these settings to customize crypto materials to encryption saved credentials
# in data sources.
# data_source.encryption.wrappingKeyName: 'changeme'
# data_source.encryption.wrappingKeyNamespace: 'changeme'
# data_source.encryption.wrappingKey: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
# 2.6 New ML Commons Dashboards Experimental Feature
# Set the value of this setting to true to enable the experimental ml commons dashboards
ml_commons_dashboards.enabled: true
opensearch.hosts: ["https://localhost:9200"]
opensearch.ssl.verificationMode: none
opensearch.username: kibanaserver
opensearch.password: kibanaserver
opensearch.requestHeadersWhitelist: [authorization, securitytenant]
opensearch_security.multitenancy.enabled: true
opensearch_security.multitenancy.tenants.preferred: [Private, Global]
opensearch_security.readonly_mode.roles: [kibana_read_only]
# Use this setting if you are running opensearch-dashboards without https
opensearch_security.cookie.secure: false
server.host: '0.0.0.0'