# Signed-off-by: BaoHuiling <huiling.bao@intel.com>
# Signed-off-by: chen, suyue <suyue.chen@intel.com>
#
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import codecs
import json
import os
import time
from io import BytesIO

import requests
import streamlit as st
|
|
|
|
# Backend URLs. Each one falls back to a localhost default when the environment
# variable of the same name is not set.
BACKEND_SERVICE_ENDPOINT = os.environ.get(
    "BACKEND_SERVICE_ENDPOINT",
    "http://localhost:8888/v1/videoqna",
)
BACKEND_HEALTH_CHECK_ENDPOINT = os.environ.get(
    "BACKEND_HEALTH_CHECK_ENDPOINT",
    "http://localhost:8888/v1/health_check",
)
|
|
|
|
|
|
def perform_health_check(timeout=10):
    """Send a GET request to the backend health-check endpoint.

    Args:
        timeout: Seconds ``requests`` waits for the backend to connect or to send
            data before raising ``requests.Timeout``. Without a timeout a stalled
            backend would block the page indefinitely.

    Returns:
        The ``requests.Response`` obtained from ``BACKEND_HEALTH_CHECK_ENDPOINT``.
        The status code is not inspected here; that is left to the caller.

    Raises:
        requests.RequestException: If the request cannot be completed
            (connection failure, timeout, ...).
    """
    return requests.get(
        BACKEND_HEALTH_CHECK_ENDPOINT,
        headers={"accept": "application/json"},
        timeout=timeout,
    )
|
|
|
|
|
|
def download_video(url, timeout=30):
    """Download video from URL and return as bytes.

    Args:
        url: Address of the video, fetched with an HTTP GET.
        timeout: Seconds ``requests`` waits for the server to connect or to send
            data before raising ``requests.Timeout``. Without a timeout a stalled
            server would leave the "Loading Video" spinner up forever.

    Returns:
        A ``BytesIO`` wrapping the response body when the status code is 200.
        For any other status the code is reported through ``st.error`` and
        ``None`` is returned.

    Raises:
        requests.RequestException: If the request itself fails
            (connection failure, timeout, ...).
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        st.error(f"Failed to download video. Status code: {response.status_code}")
        return None
    return BytesIO(response.content)
|
|
|
|
|
|
def play_video(url, offset):
    """Download the video at ``url`` and render it with ``st.video``.

    ``offset`` is converted with ``int()`` and passed as ``start_time``.
    Nothing is rendered when the download yields no data.
    """
    with st.spinner("Loading Video ..."):
        stream = download_video(url)
        if not stream:
            return
        st.video(stream, start_time=int(offset))
|
|
|
|
|
|
def clear_chat_history():
    """Reset the example-prompt widget value and restart the chat with the greeting."""
    greeting = {"role": "assistant", "content": "How may I assist you today?"}
    state = st.session_state
    state.example_video = "Enter Text"
    state.messages = [greeting]
|
|
|
|
|
|
def handle_selectbox_change():
    """Queue the example prompt chosen in the select box as a new user message.

    Does nothing when the select box currently holds no selection (``None``).
    """
    selected = st.session_state.example_video
    if selected is None:
        return

    st.session_state["prompt"] = selected
    user_turn = {"role": "user", "content": selected}
    st.session_state.messages.append(user_turn)
|
|
|
|
|
|
def handle_chat_input():
    """Queue the text submitted through the chat input as a new user message."""
    typed = st.session_state.custom_prompt
    print("st.session_state.custom_prompt update", typed)

    st.session_state["prompt"] = typed
    user_turn = {"role": "user", "content": typed}
    st.session_state.messages.append(user_turn)
|
|
|
|
|
|
def handle_message(col):
    """Answer the pending user prompt by streaming a reply from the backend.

    Does nothing when the last entry of ``st.session_state.messages`` already
    comes from the assistant. Otherwise ``st.session_state["prompt"]`` is posted
    to ``BACKEND_SERVICE_ENDPOINT`` and the streamed body is consumed chunk by
    chunk:

    * The first chunk that parses as JSON is treated as retrieval metadata. Its
      ``video_url`` and ``chunk_start`` entries are used to play the video inside
      ``col``, and the video file name is written at the top of the reply.
    * Every later chunk is appended to the reply as text.

    Request failures are reported with ``st.error`` instead of being raised. In
    every case the elapsed time is appended to the reply and the reply is stored
    in ``st.session_state.messages``.

    Args:
        col: Streamlit container, used as a context manager, in which the
            retrieved video is rendered.
    """
    # Generate a new response only if the last message is not from the assistant.
    if st.session_state.messages[-1]["role"] == "assistant":
        return

    params = None
    full_response = ""
    # A multi-byte UTF-8 character can be split across two 8192-byte chunks.
    # Decoding each chunk on its own would raise UnicodeDecodeError, so an
    # incremental decoder carries the partial bytes over to the next chunk.
    text_decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    with st.chat_message("assistant"):
        placeholder = st.empty()
        start = time.time()
        request_data = {"messages": st.session_state["prompt"], "stream": "True"}
        try:
            # The context manager releases the streamed connection even when
            # parsing or rendering fails part-way through the body.
            with requests.post(BACKEND_SERVICE_ENDPOINT, data=json.dumps(request_data), stream=True) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue

                    if params is not None:
                        # Metadata already received: everything else is answer text.
                        full_response += text_decoder.decode(chunk)
                        placeholder.markdown(full_response)
                        continue

                    try:
                        # NOTE(review): single quotes are swapped for double quotes before
                        # parsing, presumably because the backend emits a Python-style dict
                        # rather than strict JSON - confirm against the backend.
                        chunk_str = chunk.decode("utf-8").replace("'", '"')
                        params = json.loads(chunk_str)
                    except json.JSONDecodeError:
                        # Not metadata; params stays None so the next chunk is tried again.
                        print("In the param decode error branch")
                        print(chunk.decode("utf-8"))
                        continue

                    video_url = params["video_url"]
                    chunk_start = params["chunk_start"]
                    print("VIDEO NAME USED IN PLAYBACK: ", video_url)

                    video_name = video_url.split("/")[-1]
                    full_response += f"Most relevant retrieved video is **{video_name}** \n\n"
                    placeholder.markdown(full_response)

                    with col:
                        play_video(video_url, chunk_start)

                # Flush bytes still buffered if the stream ended mid-character.
                full_response += text_decoder.decode(b"", final=True)
        except requests.HTTPError as http_err:
            st.error(f"HTTP error occurred: {http_err}")
        except requests.RequestException as req_err:
            st.error(f"Error occurred: {req_err}")
        except Exception as err:
            # UI boundary: show the failure in the page instead of crashing the script run.
            st.error(f"An unexpected error occurred: {err}")

        end = time.time()
        full_response += f"\n\n🚀 Generated in {(end - start):.4f} seconds."
        placeholder.markdown(full_response)

    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)
|
|
|
|
|
|
def display_messages():
    """Render every message in ``st.session_state.messages`` under its role."""
    for entry in st.session_state.messages:
        bubble = st.chat_message(entry["role"])
        with bubble:
            st.write(entry["content"])
|
|
|
|
|
|
def main():
    """Build the VideoQnA page and process one Streamlit script run.

    Steps, in order:

    1. Configure the page, inject the CSS and add the "Clear Chat History"
       sidebar button.
    2. If no health check is recorded in the session yet, call the backend
       health check. When it fails, whether by a non-200 status or by the
       request raising, an error is shown and the run is stopped with
       ``st.stop()``.
    3. Initialise ``messages`` and ``prompt`` in the session state if missing.
    4. Render the example-prompt select box, the chat input and the chat
       history, then answer any pending user prompt via ``handle_message``.
    """
    st.set_page_config(initial_sidebar_state="collapsed", layout="wide")
    st.title("VideoQnA")
    title_alignment = """
    <style>
    h1 {
      text-align: center
    }

    video.stVideo {
      width: 200px;
      height: 500px;
    }
    </style>
    """
    st.markdown(title_alignment, unsafe_allow_html=True)
    st.sidebar.button("Clear Chat History", on_click=clear_chat_history)

    placeholder = st.empty()

    # Check server health once per session.
    if "health_check" not in st.session_state:
        response = None
        health_error = None
        with st.spinner("Checking health of the server..."):
            time.sleep(1)
            try:
                response = perform_health_check()
            except requests.RequestException as err:
                # An unreachable backend raises instead of returning a status code;
                # report it like any other failed check rather than showing a traceback.
                health_error = err

        if health_error is not None:
            st.error(f"Server health check failed: {health_error}")
            st.stop()
        elif response.status_code == 200:
            placeholder.success("Server is healthy!", icon="✅")
            time.sleep(1)
            placeholder.empty()  # Remove the message
            st.session_state["health_check"] = True
        else:
            st.error(f"Server health check failed with status code {response.status_code}")
            st.stop()

    # Initialize conversation state
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
    if "prompt" not in st.session_state:
        st.session_state["prompt"] = ""

    col1, col2 = st.columns([2, 1])

    with col1:
        st.selectbox(
            "Example Prompts",
            (
                "Man wearing glasses",
                "People reading item description",
                "Man holding red shopping basket",
                "Was there any person wearing a blue shirt seen today?",
                "Was there any person wearing a blue shirt seen in the last 6 hours?",
                "Was there any person wearing a blue shirt seen last Sunday?",
                "Was a person wearing glasses seen in the last 30 minutes?",
                "Was a person wearing glasses seen in the last 72 hours?",
            ),
            key="example_video",
            index=None,
            placeholder="--- Options ---",
            on_change=handle_selectbox_change,
        )

    st.chat_input(disabled=False, key="custom_prompt", on_submit=handle_chat_input)

    with col1:
        display_messages()
        handle_message(col2)
|
|
|
|
|
|
# Build the page only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|