Fix vllm microservice performance issue. (#731)

* Fix vllm microservice performance issue.

Signed-off-by: Yao, Qing <qing.yao@intel.com>

* Refine llm generate parameters

Signed-off-by: Yao, Qing <qing.yao@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Yao, Qing <qing.yao@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yao Qing
2024-09-25 21:37:38 +08:00
committed by GitHub
parent f8f02e2e1d
commit 2159f9ad00

View File

@@ -26,6 +26,7 @@ logflag = os.getenv("LOGFLAG", False)
llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
llm = VLLMOpenAI(openai_api_key="EMPTY", openai_api_base=llm_endpoint + "/v1", model_name=model_name)
@opea_telemetry
@@ -56,6 +57,13 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
if not isinstance(input, SearchedDoc) and input.chat_template:
prompt_template = PromptTemplate.from_template(input.chat_template)
input_variables = prompt_template.input_variables
parameters = {
"max_tokens": input.max_tokens,
"top_p": input.top_p,
"temperature": input.temperature,
"frequency_penalty": input.frequency_penalty,
"presence_penalty": input.presence_penalty,
}
if isinstance(input, SearchedDoc):
if logflag:
@@ -76,23 +84,11 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
if logflag:
logger.info(f"[ SearchedDoc ] final input: {new_input}")
llm = VLLMOpenAI(
openai_api_key="EMPTY",
openai_api_base=llm_endpoint + "/v1",
max_tokens=new_input.max_tokens,
model_name=model_name,
top_p=new_input.top_p,
temperature=new_input.temperature,
frequency_penalty=new_input.frequency_penalty,
presence_penalty=new_input.presence_penalty,
streaming=new_input.streaming,
)
if new_input.streaming:
def stream_generator():
async def stream_generator():
chat_response = ""
for text in llm.stream(new_input.query):
async for text in llm.astream(new_input.query, **parameters):
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
@@ -105,7 +101,7 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
response = llm.invoke(new_input.query)
response = llm.invoke(new_input.query, **parameters)
if logflag:
logger.info(response)
@@ -131,23 +127,11 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
# use rag default template
prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)
llm = VLLMOpenAI(
openai_api_key="EMPTY",
openai_api_base=llm_endpoint + "/v1",
max_tokens=input.max_tokens,
model_name=model_name,
top_p=input.top_p,
temperature=input.temperature,
frequency_penalty=input.frequency_penalty,
presence_penalty=input.presence_penalty,
streaming=input.streaming,
)
if input.streaming:
def stream_generator():
async def stream_generator():
chat_response = ""
for text in llm.stream(input.query):
async for text in llm.astream(input.query, **parameters):
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
@@ -160,7 +144,7 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
response = llm.invoke(input.query)
response = llm.invoke(input.query, **parameters)
if logflag:
logger.info(response)