Files
GenAIExamples/AudioQnA/langchain/redis/rag_redis/chain.py
chen, suyue 509d5c66ca unify license copyright (#234)
Signed-off-by: chensuyue <suyue.chen@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-05-31 17:19:30 +08:00

77 lines
2.1 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from langchain.cache import RedisSemanticCache
from langchain.globals import set_llm_cache
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import Redis
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL, TGI_LLM_ENDPOINT


# Typed input for the chain (see `.with_types(input_type=Question)` below):
# wrapping the plain question string in a pydantic model gives it a named
# schema, which makes the generated docs read better.
class Question(BaseModel):
    # pydantic v1 custom root type: the whole chain input is one string.
    __root__: str


# Embedding model, used both by the LLM semantic cache here and by the Redis
# vector store below.
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

# Register a process-wide Redis-backed semantic cache for LLM calls: cached
# responses are looked up by embedding similarity of the prompt (computed with
# `embedder`) rather than by exact string match.
# (RedisSemanticCache / set_llm_cache are imported at the top of the module.)
set_llm_cache(RedisSemanticCache(embedding=embedder, redis_url=REDIS_URL))

# Attach to a Redis vector index that must already exist and hold the
# document chunks; nothing is created here. Run the ingest.py script first
# to populate it.
vectorstore = Redis.from_existing_index(
    redis_url=REDIS_URL,
    index_name=INDEX_NAME,
    schema=INDEX_SCHEMA,
    embedding=embedder,
)

# Retriever over that index using maximal-marginal-relevance ("mmr") search.
# TODO allow user to change parameters (search type, k, ...).
retriever = vectorstore.as_retriever(search_type="mmr")

# Prompt for the RAG step. The chain's input mapping fills {context} with the
# retriever output and {question} with the user's question. The LLM is told
# to answer only from that context and to cite the 'source' and 'start_index'
# metadata of the pieces it used.
template = """
Use the following pieces of context from retrieved
dataset to answer the question. Do not make up an answer if there is no
context provided to help answer it. Include the 'source' and 'start_index'
from the metadata included in the context you used to answer the question
Context:
---------
{context}
---------
Question: {question}
---------
Answer:
"""
prompt = ChatPromptTemplate.from_template(template=template)

# LLM for the RAG chain, served from the TGI_LLM_ENDPOINT URL configured in
# rag_redis.config.
model = HuggingFaceEndpoint(
    endpoint_url=TGI_LLM_ENDPOINT,
    # Output length limit, and input truncation length passed to the server
    # (presumably in tokens; confirm against the TGI parameter docs).
    max_new_tokens=512,
    truncate=1024,
    # Sampling controls; temperature 0.01 keeps decoding close to greedy.
    temperature=0.01,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    repetition_penalty=1.03,
    # Stream generated tokens back instead of waiting for the full answer.
    streaming=True,
)

# RAG chain: the input question is fanned out to the retriever ("context")
# and passed through unchanged ("question"). Both fill the prompt, which goes
# to the LLM, whose output is parsed to a plain string. The declared input
# type is the `Question` model.
chain = (
    RunnableParallel(context=retriever, question=RunnablePassthrough())
    | prompt
    | model
    | StrOutputParser()
).with_types(input_type=Question)