diff --git a/comps/llms/text-generation/vllm/langchain/llm.py b/comps/llms/text-generation/vllm/langchain/llm.py
index fdb245320..950c6d531 100644
--- a/comps/llms/text-generation/vllm/langchain/llm.py
+++ b/comps/llms/text-generation/vllm/langchain/llm.py
@@ -26,6 +26,7 @@ logflag = os.getenv("LOGFLAG", False)
 
 llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
 model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
+llm = VLLMOpenAI(openai_api_key="EMPTY", openai_api_base=llm_endpoint + "/v1", model_name=model_name)
 
 
 @opea_telemetry
@@ -56,6 +57,13 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
     if not isinstance(input, SearchedDoc) and input.chat_template:
         prompt_template = PromptTemplate.from_template(input.chat_template)
         input_variables = prompt_template.input_variables
+    parameters = {
+        "max_tokens": input.max_tokens,
+        "top_p": input.top_p,
+        "temperature": input.temperature,
+        "frequency_penalty": input.frequency_penalty,
+        "presence_penalty": input.presence_penalty,
+    }
 
     if isinstance(input, SearchedDoc):
         if logflag:
@@ -76,23 +84,11 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
         if logflag:
             logger.info(f"[ SearchedDoc ] final input: {new_input}")
 
-        llm = VLLMOpenAI(
-            openai_api_key="EMPTY",
-            openai_api_base=llm_endpoint + "/v1",
-            max_tokens=new_input.max_tokens,
-            model_name=model_name,
-            top_p=new_input.top_p,
-            temperature=new_input.temperature,
-            frequency_penalty=new_input.frequency_penalty,
-            presence_penalty=new_input.presence_penalty,
-            streaming=new_input.streaming,
-        )
-
         if new_input.streaming:
 
-            def stream_generator():
+            async def stream_generator():
                 chat_response = ""
-                for text in llm.stream(new_input.query):
+                async for text in llm.astream(new_input.query, **parameters):
                     chat_response += text
                     chunk_repr = repr(text.encode("utf-8"))
                     if logflag:
@@ -105,7 +101,7 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
         else:
-            response = llm.invoke(new_input.query)
+            response = llm.invoke(new_input.query, **parameters)
             if logflag:
                 logger.info(response)
@@ -131,23 +127,11 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
                 # use rag default template
                 prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)
 
-        llm = VLLMOpenAI(
-            openai_api_key="EMPTY",
-            openai_api_base=llm_endpoint + "/v1",
-            max_tokens=input.max_tokens,
-            model_name=model_name,
-            top_p=input.top_p,
-            temperature=input.temperature,
-            frequency_penalty=input.frequency_penalty,
-            presence_penalty=input.presence_penalty,
-            streaming=input.streaming,
-        )
-
         if input.streaming:
 
-            def stream_generator():
+            async def stream_generator():
                 chat_response = ""
-                for text in llm.stream(input.query):
+                async for text in llm.astream(input.query, **parameters):
                     chat_response += text
                     chunk_repr = repr(text.encode("utf-8"))
                     if logflag:
@@ -160,7 +144,7 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
         else:
-            response = llm.invoke(input.query)
+            response = llm.invoke(input.query, **parameters)
             if logflag:
                 logger.info(response)
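
For reference, a minimal standalone sketch of the pattern this diff moves to: one module-level `VLLMOpenAI` client reused across requests, with sampling parameters passed per call to `astream`/`invoke` instead of being baked into a freshly constructed client on every request. The endpoint and model name below mirror the module's defaults; the query string and parameter values are illustrative only, and a running vLLM server behind the OpenAI-compatible `/v1` API is assumed.

```python
import asyncio

from langchain_community.llms import VLLMOpenAI

# Defaults taken from the module under change; adjust to your deployment.
llm_endpoint = "http://localhost:8008"
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# One shared client; sampling parameters are no longer set in the constructor.
llm = VLLMOpenAI(openai_api_key="EMPTY", openai_api_base=llm_endpoint + "/v1", model_name=model_name)

# Per-request parameters, mirroring the `parameters` dict built inside llm_generate.
# The concrete values here are illustrative, not service defaults.
parameters = {
    "max_tokens": 128,
    "top_p": 0.95,
    "temperature": 0.01,
    "frequency_penalty": 0.0,
    "presence_penalty": 0.0,
}


async def stream_once(query: str) -> str:
    """Streaming path: consume llm.astream as an async generator of text chunks."""
    chat_response = ""
    async for text in llm.astream(query, **parameters):
        chat_response += text
        print(text, end="", flush=True)
    return chat_response


if __name__ == "__main__":
    asyncio.run(stream_once("What is deep learning?"))
    # Non-streaming path: same shared client, same per-call parameters.
    print(llm.invoke("What is deep learning?", **parameters))
```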