GenAIExamples/comps/llms/text-generation/vllm/langchain/llm.py
Yao Qing 2159f9ad00 Fix vllm microservice performance issue. (#731)
* Fix vllm microservice performance issue.
* Refine llm generate parameters
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Signed-off-by: Yao, Qing <qing.yao@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-09-25 21:37:38 +08:00


# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Union

from fastapi.responses import StreamingResponse
from langchain_community.llms import VLLMOpenAI
from langchain_core.prompts import PromptTemplate
from template import ChatTemplate

from comps import (
    CustomLogger,
    GeneratedDoc,
    LLMParamsDoc,
    SearchedDoc,
    ServiceType,
    opea_microservices,
    opea_telemetry,
    register_microservice,
)
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm_vllm")
logflag = os.getenv("LOGFLAG", False)

llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
llm = VLLMOpenAI(openai_api_key="EMPTY", openai_api_base=llm_endpoint + "/v1", model_name=model_name)


# Escape whitespace in a streamed text chunk so it survives server-sent-event framing.
@opea_telemetry
def post_process_text(text: str):
    if text == " ":
        return "data: @#$\n\n"
    if text == "\n":
        return "data: <br/>\n\n"
    if text.isspace():
        return None
    new_text = text.replace(" ", "@#$")
    return f"data: {new_text}\n\n"
@register_microservice(
    name="opea_service@llm_vllm",
    service_type=ServiceType.LLM,
    endpoint="/v1/chat/completions",
    host="0.0.0.0",
    port=9000,
)
def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
    if logflag:
        logger.info(input)

    # Build a prompt template only when the caller supplied a custom chat_template.
    prompt_template = None
    if not isinstance(input, SearchedDoc) and input.chat_template:
        prompt_template = PromptTemplate.from_template(input.chat_template)
        input_variables = prompt_template.input_variables

    parameters = {
        "max_tokens": input.max_tokens,
        "top_p": input.top_p,
        "temperature": input.temperature,
        "frequency_penalty": input.frequency_penalty,
        "presence_penalty": input.presence_penalty,
    }

    if isinstance(input, SearchedDoc):
        if logflag:
            logger.info("[ SearchedDoc ] input from retriever microservice")
        prompt = input.initial_query
        if input.retrieved_docs:
            docs = [doc.text for doc in input.retrieved_docs]
            if logflag:
                logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}")
            prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs)
        # use default llm parameters for inference
        new_input = LLMParamsDoc(query=prompt)
        if logflag:
            logger.info(f"[ SearchedDoc ] final input: {new_input}")

        if new_input.streaming:

            async def stream_generator():
                chat_response = ""
                async for text in llm.astream(new_input.query, **parameters):
                    chat_response += text
                    chunk_repr = repr(text.encode("utf-8"))
                    if logflag:
                        logger.info(f"[ SearchedDoc ] chunk: {chunk_repr}")
                    yield f"data: {chunk_repr}\n\n"
                if logflag:
                    logger.info(f"[ SearchedDoc ] stream response: {chat_response}")
                yield "data: [DONE]\n\n"

            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            response = llm.invoke(new_input.query, **parameters)
            if logflag:
                logger.info(response)
            return GeneratedDoc(text=response, prompt=new_input.query)
    elif isinstance(input, LLMParamsDoc):
        if logflag:
            logger.info("[ LLMParamsDoc ] input from rerank microservice")
        prompt = input.query
        if prompt_template:
            # A custom template may reference {question} alone, or {question} plus {context}.
            if sorted(input_variables) == ["context", "question"]:
                prompt = prompt_template.format(question=input.query, context="\n".join(input.documents))
            elif input_variables == ["question"]:
                prompt = prompt_template.format(question=input.query)
            else:
                logger.info(
                    f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
                )
        else:
            if input.documents:
                # use rag default template
                prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)

        if input.streaming:

            async def stream_generator():
                chat_response = ""
                async for text in llm.astream(input.query, **parameters):
                    chat_response += text
                    chunk_repr = repr(text.encode("utf-8"))
                    if logflag:
                        logger.info(f"[ LLMParamsDoc ] chunk: {chunk_repr}")
                    yield f"data: {chunk_repr}\n\n"
                if logflag:
                    logger.info(f"[ LLMParamsDoc ] stream response: {chat_response}")
                yield "data: [DONE]\n\n"

            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            response = llm.invoke(input.query, **parameters)
            if logflag:
                logger.info(response)
            return GeneratedDoc(text=response, prompt=input.query)


if __name__ == "__main__":
    opea_microservices["opea_service@llm_vllm"].start()
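
For reference, a minimal client sketch against this microservice (illustrative, not part of llm.py; it assumes the service is reachable on the registered port 9000 and that the request body carries the LLMParamsDoc fields the handler reads, such as query, max_tokens and streaming):

import requests  # client-side dependency used only for this sketch

# Example payload; field names mirror the attributes accessed above (input.query,
# input.max_tokens, input.streaming, ...). Values are placeholders.
payload = {"query": "What is OPEA?", "max_tokens": 128, "streaming": True}

with requests.post("http://localhost:9000/v1/chat/completions", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # The streaming path emits server-sent events: one "data: <chunk repr>" line per chunk,
        # terminated by "data: [DONE]".
        if line:
            print(line)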