GenAIExamples/comps/llms/summarization/tgi/langchain/llm.py

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os

from fastapi.responses import StreamingResponse
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceEndpoint
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

from comps import CustomLogger, DocSumLLMParams, GeneratedDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

logger = CustomLogger("llm_docsum")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", 2048))
MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", 4096))
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "Intel/neural-chat-7b-v3-3")
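
# Note: MAX_INPUT_TOKENS and MAX_TOTAL_TOKENS are expected to mirror the limits the TGI server was
# launched with; the chunking logic below uses them to keep each request within the model's context window.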
templ_en = """Write a concise summary of the following:
"{text}"
CONCISE SUMMARY:"""
templ_zh = """请简要概括以下内容:
"{text}"
概况:"""
templ_refine_en = """Your job is to produce a final summary.
We have provided an existing summary up to a certain point, then we will provide more context.
You need to refine the existing summary (only if needed) with new context and generate a final summary.
Existing Summary:
"{existing_answer}"
New Context:
"{text}"
Final Summary:
"""
templ_refine_zh = """\
你的任务是生成一个最终摘要。
我们已经处理好部分文本并生成初始摘要, 并提供了新的未处理文本
你需要根据新提供的文本,结合初始摘要,生成一个最终摘要。
初始摘要:
"{existing_answer}"
新的文本:
"{text}"
最终摘要:
"""


@register_microservice(
    name="opea_service@llm_docsum",
    service_type=ServiceType.LLM,
    endpoint="/v1/chat/docsum",
    host="0.0.0.0",
    port=9000,
)
async def llm_generate(input: DocSumLLMParams):
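    """Summarize the document in `input.query` via the TGI endpoint.

    summary_type selects the strategy: "auto" picks "stuff" or "refine" based on input length,
    "stuff"/"truncate" use a single prompt, and "map_reduce"/"refine" chunk the text first.
    language ("en", "zh", "auto") selects the prompt templates. Returns a GeneratedDoc with the
    summary, or a streaming (SSE) response when input.stream is True.
    """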
    if logflag:
        logger.info(input)

    ### check summary type
    summary_types = ["auto", "stuff", "truncate", "map_reduce", "refine"]
    if input.summary_type not in summary_types:
        raise NotImplementedError(f"Please specify the summary_type in {summary_types}")

    if input.summary_type == "auto":  ### Check input token length in auto mode
        token_len = len(tokenizer.encode(input.query))
        if token_len > MAX_INPUT_TOKENS + 50:
            input.summary_type = "refine"
            if logflag:
                logger.info(
                    f"Input token length {token_len} exceeds MAX_INPUT_TOKENS + 50 ({MAX_INPUT_TOKENS+50}); automatically switching to 'refine' mode."
                )
        else:
            input.summary_type = "stuff"
            if logflag:
                logger.info(
                    f"Input token length {token_len} does not exceed MAX_INPUT_TOKENS + 50 ({MAX_INPUT_TOKENS+50}); automatically switching to 'stuff' mode."
                )
if input.language in ["en", "auto"]:
templ = templ_en
templ_refine = templ_refine_en
elif input.language in ["zh"]:
templ = templ_zh
templ_refine = templ_refine_zh
else:
raise NotImplementedError('Please specify the input language in "en", "zh", "auto"')
## Prompt
PROMPT = PromptTemplate.from_template(templ)
if input.summary_type == "refine":
PROMPT_REFINE = PromptTemplate.from_template(templ_refine)
if logflag:
logger.info("After prompting:")
logger.info(PROMPT)
if input.summary_type == "refine":
logger.info(PROMPT_REFINE)

    ## Split text
    if input.summary_type == "stuff":
        text_splitter = CharacterTextSplitter()
    else:
        if input.summary_type == "refine":
            if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
                raise RuntimeError("In refine mode, please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128).")
            max_input_tokens = min(
                MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS
            )  # 128 tokens are reserved for the prompt
        else:
            if MAX_TOTAL_TOKENS <= input.max_tokens + 50:
                raise RuntimeError("Please set MAX_TOTAL_TOKENS larger than (max_tokens + 50).")
            max_input_tokens = min(
                MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS
            )  # 50 tokens are reserved for the prompt
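
        # For example, with the defaults MAX_TOTAL_TOKENS=4096 and MAX_INPUT_TOKENS=2048, a request with
        # max_tokens=1024 gives max_input_tokens = min(4096 - 2048 - 128, 2048) = 1920 in refine mode and
        # min(4096 - 1024 - 50, 2048) = 2048 otherwise.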
        chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
        chunk_overlap = input.chunk_overlap if input.chunk_overlap > 0 else int(0.1 * chunk_size)
        text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer=tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if logflag:
            logger.info(f"set chunk size to: {chunk_size}")
            logger.info(f"set chunk overlap to: {chunk_overlap}")

    texts = text_splitter.split_text(input.query)
    docs = [Document(page_content=t) for t in texts]
    if logflag:
        logger.info(f"Split input query into {len(docs)} chunks")
        logger.info(f"The character length of the first chunk is {len(texts[0])}")

    ## Access auth
    access_token = (
        get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
    )
    server_kwargs = {}
    if access_token:
        server_kwargs["headers"] = {"Authorization": f"Bearer {access_token}"}

    ## LLM
    if input.stream and input.summary_type == "map_reduce":
        logger.info("map_reduce mode does not support stream=True; setting stream=False.")
        input.stream = False
    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
    llm = HuggingFaceEndpoint(
        endpoint_url=llm_endpoint,
        max_new_tokens=input.max_tokens,
        top_k=input.top_k,
        top_p=input.top_p,
        typical_p=input.typical_p,
        temperature=input.temperature,
        repetition_penalty=input.repetition_penalty,
        streaming=input.stream,
        server_kwargs=server_kwargs,
    )

    ## LLM chain
    summary_type = input.summary_type
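    # How each summary_type maps onto a LangChain summarize chain:
    #   - "stuff":      one prompt over the whole (character-split) text
    #   - "truncate":   same single-prompt chain, but only the first token-sized chunk is kept
    #   - "map_reduce": summarize each chunk, then combine the partial summaries
    #   - "refine":     summarize the first chunk, then iteratively refine with each following chunk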
if summary_type == "stuff":
llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
elif summary_type == "truncate":
docs = [docs[0]]
llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
elif summary_type == "map_reduce":
llm_chain = load_summarize_chain(
llm=llm, map_prompt=PROMPT, combine_prompt=PROMPT, chain_type="map_reduce", return_intermediate_steps=True
)
elif summary_type == "refine":
llm_chain = load_summarize_chain(
llm=llm,
question_prompt=PROMPT,
refine_prompt=PROMPT_REFINE,
chain_type="refine",
return_intermediate_steps=True,
)
else:
raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')

    if input.stream:
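
        # astream_log yields incremental RunLogPatch objects; their JSON-Patch "ops" are serialized
        # with langserve's WellKnownLCSerializer and sent to the client as server-sent events.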
        async def stream_generator():
            from langserve.serialization import WellKnownLCSerializer

            _serializer = WellKnownLCSerializer()
            async for chunk in llm_chain.astream_log(docs):
                data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
                if logflag:
                    logger.info(data)
                yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        response = await llm_chain.ainvoke(docs)

        if input.summary_type in ["map_reduce", "refine"]:
            intermediate_steps = response["intermediate_steps"]
            if logflag:
                logger.info("intermediate_steps:")
                logger.info(intermediate_steps)

        output_text = response["output_text"]
        if logflag:
            logger.info("\n\noutput_text:")
            logger.info(output_text)

        return GeneratedDoc(text=output_text, prompt=input.query)
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
opea_microservices["opea_service@llm_docsum"].start()
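

# Example request (a minimal sketch for illustration, not part of the service). It assumes the
# microservice is running locally on the port registered above and that DocSumLLMParams exposes the
# fields referenced in llm_generate (query, summary_type, language, stream, ...); adjust to your deployment.
#
#   curl http://localhost:9000/v1/chat/docsum \
#       -X POST \
#       -H "Content-Type: application/json" \
#       -d '{"query": "<text to summarize>", "summary_type": "auto", "language": "en", "stream": false}'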