move examples gateway (#992)

Co-authored-by: root <root@idc708073.jf.intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Author: lkk
Date: 2024-12-06 14:40:25 +08:00
Committed by: GitHub
Parent: f5c08d4fbb
Commit: bde285dfce
17 changed files with 1236 additions and 113 deletions
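Every example in this commit follows the same refactor: the per-example `*Gateway` wrapper is removed, the megaservice class inherits `Gateway` directly, implements its own `handle_request`, and registers the FastAPI endpoint by calling `super().__init__` from a new `start()` method, while `__main__` now drops the host argument and calls `start()`. A condensed sketch of that pattern is below; the `ExampleService` name, port, and LLM endpoint are illustrative placeholders, not part of the commit.

# Condensed sketch of the pattern applied in this commit.
# ExampleService, the port numbers, and the "/v1/chat/completions" endpoint are
# hypothetical; the real classes are AudioQnAService, ChatQnAService, etc.
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import ChatCompletionRequest, ChatCompletionResponse
from fastapi import Request


class ExampleService(Gateway):
    def __init__(self, host="0.0.0.0", port=8888):
        self.host = host
        self.port = port
        self.megaservice = ServiceOrchestrator()

    def add_remote_service(self):
        # wire one remote microservice into the orchestrator graph
        llm = MicroService(
            name="llm",
            host="0.0.0.0",
            port=9000,
            endpoint="/v1/chat/completions",
            use_remote_service=True,
            service_type=ServiceType.LLM,
        )
        self.megaservice.add(llm)

    async def handle_request(self, request: Request):
        # parse the request, schedule the graph, and return the last node's output
        data = await request.json()
        prompt = self._handle_message(ChatCompletionRequest.parse_obj(data).messages)
        result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
        last_node = runtime_graph.all_leaves()[-1]
        return result_dict[last_node]["text"]

    def start(self):
        # Gateway.__init__ registers the endpoint and serves the megaservice
        super().__init__(
            megaservice=self.megaservice,
            host=self.host,
            port=self.port,
            endpoint=str(MegaServiceEndpoint.CHAT_QNA),
            input_datatype=ChatCompletionRequest,
            output_datatype=ChatCompletionResponse,
        )


if __name__ == "__main__":
    service = ExampleService(port=8888)
    service.add_remote_service()
    service.start()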


@@ -4,9 +4,11 @@
import asyncio
import os
-from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099))
@@ -16,7 +18,7 @@ TTS_SERVICE_HOST_IP = os.getenv("TTS_SERVICE_HOST_IP", "0.0.0.0")
TTS_SERVICE_PORT = int(os.getenv("TTS_SERVICE_PORT", 9088))
-class AudioQnAService:
+class AudioQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -50,9 +52,43 @@ class AudioQnAService:
self.megaservice.add(asr).add(llm).add(tts)
self.megaservice.flow_to(asr, llm)
self.megaservice.flow_to(llm, tts)
-self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
chat_request = AudioChatCompletionRequest.parse_obj(data)
parameters = LLMParams(
# relatively lower max_tokens for audio conversation
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
)
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["byte_str"]
return response
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
input_datatype=AudioChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+audioqna = AudioQnAService(port=MEGA_SERVICE_PORT)
audioqna.add_remote_service()
audioqna.start()


@@ -5,9 +5,11 @@ import asyncio
import base64
import os
-from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
WHISPER_SERVER_HOST_IP = os.getenv("WHISPER_SERVER_HOST_IP", "0.0.0.0")
@@ -52,7 +54,7 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_di
return data
-class AudioQnAService:
+class AudioQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -90,9 +92,43 @@ class AudioQnAService:
self.megaservice.add(asr).add(llm).add(tts)
self.megaservice.flow_to(asr, llm)
self.megaservice.flow_to(llm, tts)
-self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
chat_request = AudioChatCompletionRequest.parse_obj(data)
parameters = LLMParams(
# relatively lower max_tokens for audio conversation
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
)
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["byte_str"]
return response
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
input_datatype=AudioChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+audioqna = AudioQnAService(port=MEGA_SERVICE_PORT)
audioqna.add_remote_service()
audioqna.start()


@@ -5,9 +5,11 @@ import asyncio
import os
import sys
-from comps import AvatarChatbotGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099))
@@ -27,7 +29,7 @@ def check_env_vars(env_var_list):
print("All environment variables are set.") print("All environment variables are set.")
class AvatarChatbotService: class AvatarChatbotService(Gateway):
def __init__(self, host="0.0.0.0", port=8000): def __init__(self, host="0.0.0.0", port=8000):
self.host = host self.host = host
self.port = port self.port = port
@@ -70,7 +72,39 @@ class AvatarChatbotService:
self.megaservice.flow_to(asr, llm)
self.megaservice.flow_to(llm, tts)
self.megaservice.flow_to(tts, animation)
-self.gateway = AvatarChatbotGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
chat_request = AudioChatCompletionRequest.model_validate(data)
parameters = LLMParams(
# relatively lower max_tokens for audio conversation
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
# print(parameters)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
)
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["video_path"]
return response
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.AVATAR_CHATBOT),
input_datatype=AudioChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
@@ -89,5 +123,6 @@ if __name__ == "__main__":
]
)
-avatarchatbot = AvatarChatbotService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+avatarchatbot = AvatarChatbotService(port=MEGA_SERVICE_PORT)
avatarchatbot.add_remote_service()
avatarchatbot.start()


@@ -6,7 +6,17 @@ import json
import os
import re
-from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_core.prompts import PromptTemplate
@@ -35,7 +45,6 @@ If you don't know the answer to a question, please don't share false information
return template.format(context=context_str, question=question)
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
GUARDRAIL_SERVICE_HOST_IP = os.getenv("GUARDRAIL_SERVICE_HOST_IP", "0.0.0.0")
GUARDRAIL_SERVICE_PORT = int(os.getenv("GUARDRAIL_SERVICE_PORT", 80))
@@ -178,13 +187,14 @@ def align_generator(self, gen, **kwargs):
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
class ChatQnAService: class ChatQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000): def __init__(self, host="0.0.0.0", port=8000):
self.host = host self.host = host
self.port = port self.port = port
ServiceOrchestrator.align_inputs = align_inputs ServiceOrchestrator.align_inputs = align_inputs
ServiceOrchestrator.align_outputs = align_outputs ServiceOrchestrator.align_outputs = align_outputs
ServiceOrchestrator.align_generator = align_generator ServiceOrchestrator.align_generator = align_generator
self.megaservice = ServiceOrchestrator() self.megaservice = ServiceOrchestrator()
def add_remote_service(self): def add_remote_service(self):
@@ -228,7 +238,6 @@ class ChatQnAService:
self.megaservice.flow_to(embedding, retriever)
self.megaservice.flow_to(retriever, rerank)
self.megaservice.flow_to(rerank, llm)
-self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
def add_remote_service_without_rerank(self):
@@ -261,7 +270,6 @@ class ChatQnAService:
self.megaservice.add(embedding).add(retriever).add(llm)
self.megaservice.flow_to(embedding, retriever)
self.megaservice.flow_to(retriever, llm)
-self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
def add_remote_service_with_guardrails(self):
guardrail_in = MicroService(
@@ -319,7 +327,66 @@ class ChatQnAService:
self.megaservice.flow_to(retriever, rerank)
self.megaservice.flow_to(rerank, llm)
# self.megaservice.flow_to(llm, guardrail_out)
-self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt},
llm_parameters=parameters,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)
for node, response in result_dict.items():
if isinstance(response, StreamingResponse):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CHAT_QNA),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
@@ -329,10 +396,12 @@ if __name__ == "__main__":
args = parser.parse_args()
-chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+chatqna = ChatQnAService(port=MEGA_SERVICE_PORT)
if args.without_rerank:
chatqna.add_remote_service_without_rerank()
elif args.with_guardrails:
chatqna.add_remote_service_with_guardrails()
else:
chatqna.add_remote_service()
chatqna.start()


@@ -3,7 +3,17 @@
import os
-from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
from fastapi import Request
from fastapi.responses import StreamingResponse
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
@@ -17,7 +27,7 @@ LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
-class ChatQnAService:
+class ChatQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -60,9 +70,69 @@ class ChatQnAService:
self.megaservice.flow_to(embedding, retriever)
self.megaservice.flow_to(retriever, rerank)
self.megaservice.flow_to(rerank, llm)
-self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt},
llm_parameters=parameters,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)
for node, response in result_dict.items():
if isinstance(response, StreamingResponse):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CHAT_QNA),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+chatqna = ChatQnAService(port=MEGA_SERVICE_PORT)
chatqna.add_remote_service()
chatqna.start()


@@ -4,15 +4,24 @@
import asyncio
import os
-from comps import CodeGenGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import Request
from fastapi.responses import StreamingResponse
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7778))
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
-class CodeGenService:
+class CodeGenService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -28,9 +37,58 @@ class CodeGenService:
service_type=ServiceType.LLM,
)
self.megaservice.add(llm)
-self.gateway = CodeGenGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="codegen", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CODE_GEN),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-chatqna = CodeGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+chatqna = CodeGenService(port=MEGA_SERVICE_PORT)
chatqna.add_remote_service()
chatqna.start()


@@ -4,15 +4,23 @@
import asyncio
import os
-from comps import CodeTransGateway, MicroService, ServiceOrchestrator
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from fastapi import Request
from fastapi.responses import StreamingResponse
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7777))
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
-class CodeTransService:
+class CodeTransService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -27,9 +35,59 @@ class CodeTransService:
use_remote_service=True,
)
self.megaservice.add(llm)
-self.gateway = CodeTransGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
language_from = data["language_from"]
language_to = data["language_to"]
source_code = data["source_code"]
prompt_template = """
### System: Please translate the following {language_from} codes into {language_to} codes.
### Original codes:
'''{language_from}
{source_code}
'''
### Translated codes:
"""
prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code)
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CODE_TRANS),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-service_ochestrator = CodeTransService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+service_ochestrator = CodeTransService(port=MEGA_SERVICE_PORT)
service_ochestrator.add_remote_service()
service_ochestrator.start()


@@ -3,10 +3,14 @@
import asyncio
import os
from typing import Union
-from comps import MicroService, RetrievalToolGateway, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from fastapi import Request
from fastapi.responses import StreamingResponse
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889)
EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
EMBEDDING_SERVICE_PORT = os.getenv("EMBEDDING_SERVICE_PORT", 6000)
@@ -16,7 +20,7 @@ RERANK_SERVICE_HOST_IP = os.getenv("RERANK_SERVICE_HOST_IP", "0.0.0.0")
RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
-class RetrievalToolService:
+class RetrievalToolService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -51,9 +55,77 @@ class RetrievalToolService:
self.megaservice.add(embedding).add(retriever).add(rerank)
self.megaservice.flow_to(embedding, retriever)
self.megaservice.flow_to(retriever, rerank)
-self.gateway = RetrievalToolGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
chat_request = None
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query, chat_request
data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query, chat_request = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
if chat_request is None:
raise ValueError(f"Unknown request type: {data}")
if isinstance(chat_request, ChatCompletionRequest):
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)
initial_inputs = {
"messages": query,
"input": query, # has to be input due to embedding expects either input or text
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
"k": chat_request.k if chat_request.k else 4,
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
"top_n": chat_request.top_n if chat_request.top_n else 1,
}
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)
else:
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
return response
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL),
input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
output_datatype=Union[RerankedDoc, LLMParamsDoc],
)
if __name__ == "__main__":
-chatqna = RetrievalToolService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+chatqna = RetrievalToolService(port=MEGA_SERVICE_PORT)
chatqna.add_remote_service()
chatqna.start()


@@ -3,10 +3,21 @@
import asyncio
import os
from typing import List
-from comps import DocSumGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.mega.gateway import read_text_from_file
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import File, Request, UploadFile
from fastapi.responses import StreamingResponse
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
DATA_SERVICE_HOST_IP = os.getenv("DATA_SERVICE_HOST_IP", "0.0.0.0")
@@ -16,7 +27,7 @@ LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
-class DocSumService:
+class DocSumService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -41,12 +52,114 @@ class DocSumService:
use_remote_service=True,
service_type=ServiceType.LLM,
)
-self.megaservice.add(data).add(llm)
-self.megaservice.flow_to(data, llm)
-self.gateway = DocSumGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+self.megaservice.add(llm)
async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
if "application/json" in request.headers.get("content-type"):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.model_validate(data)
prompt = self._handle_message(chat_request.messages)
initial_inputs_data = {data["type"]: prompt}
elif "multipart/form-data" in request.headers.get("content-type"):
data = await request.form()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.model_validate(data)
data_type = data.get("type")
file_summaries = []
if files:
for file in files:
file_path = f"/tmp/{file.filename}"
if data_type is not None and data_type in ["audio", "video"]:
raise ValueError(
"Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
)
else:
import aiofiles
async with aiofiles.open(file_path, "wb") as f:
await f.write(await file.read())
docs = read_text_from_file(file, file_path)
os.remove(file_path)
if isinstance(docs, list):
file_summaries.extend(docs)
else:
file_summaries.append(docs)
if file_summaries:
prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
else:
prompt = self._handle_message(chat_request.messages)
data_type = data.get("type")
if data_type is not None:
initial_inputs_data = {}
initial_inputs_data[data_type] = prompt
else:
initial_inputs_data = {"query": prompt}
else:
raise ValueError(f"Unknown request type: {request.headers.get('content-type')}")
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
model=chat_request.model if chat_request.model else None,
language=chat_request.language if chat_request.language else "auto",
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs_data, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="docsum", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.DOC_SUMMARY),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-docsum = DocSumService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+docsum = DocSumService(port=MEGA_SERVICE_PORT)
docsum.add_remote_service()
docsum.start()


@@ -5,7 +5,6 @@ import os
from comps import MicroService, ServiceOrchestrator, ServiceType
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "127.0.0.1")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 16011))
PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1")
PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))
@@ -23,11 +22,22 @@ from fastapi import Request
from fastapi.responses import StreamingResponse
-class EdgeCraftRagGateway(Gateway):
-    def __init__(self, megaservice, host="0.0.0.0", port=16011):
-        super().__init__(
-            megaservice, host, port, str(MegaServiceEndpoint.CHAT_QNA), ChatCompletionRequest, ChatCompletionResponse
-        )
+class EdgeCraftRagService(Gateway):
+    def __init__(self, host="0.0.0.0", port=16010):
+        self.host = host
+        self.port = port
+        self.megaservice = ServiceOrchestrator()
+    def add_remote_service(self):
+        edgecraftrag = MicroService(
+            name="pipeline",
+            host=PIPELINE_SERVICE_HOST_IP,
+            port=PIPELINE_SERVICE_PORT,
+            endpoint="/v1/chatqna",
+            use_remote_service=True,
+            service_type=ServiceType.LLM,
+        )
+        self.megaservice.add(edgecraftrag)
    async def handle_request(self, request: Request):
        input = await request.json()
@@ -61,26 +71,18 @@ class EdgeCraftRagGateway(Gateway):
        )
        return ChatCompletionResponse(model="edgecraftrag", choices=choices, usage=usage)
-class EdgeCraftRagService:
-    def __init__(self, host="0.0.0.0", port=16010):
-        self.host = host
-        self.port = port
-        self.megaservice = ServiceOrchestrator()
-    def add_remote_service(self):
-        edgecraftrag = MicroService(
-            name="pipeline",
-            host=PIPELINE_SERVICE_HOST_IP,
-            port=PIPELINE_SERVICE_PORT,
-            endpoint="/v1/chatqna",
-            use_remote_service=True,
-            service_type=ServiceType.LLM,
-        )
-        self.megaservice.add(edgecraftrag)
-        self.gateway = EdgeCraftRagGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.CHAT_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
if __name__ == "__main__":
-edgecraftrag = EdgeCraftRagService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+edgecraftrag = EdgeCraftRagService(port=MEGA_SERVICE_PORT)
edgecraftrag.add_remote_service()
edgecraftrag.start()


@@ -3,16 +3,27 @@
import asyncio
import os
from typing import List
-from comps import FaqGenGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.mega.gateway import read_text_from_file
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import File, Request, UploadFile
from fastapi.responses import StreamingResponse
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
-class FaqGenService:
+class FaqGenService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -28,9 +39,79 @@ class FaqGenService:
service_type=ServiceType.LLM,
)
self.megaservice.add(llm)
-self.gateway = FaqGenGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
data = await request.form()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
file_summaries = []
if files:
for file in files:
file_path = f"/tmp/{file.filename}"
import aiofiles
async with aiofiles.open(file_path, "wb") as f:
await f.write(await file.read())
docs = read_text_from_file(file, file_path)
os.remove(file_path)
if isinstance(docs, list):
file_summaries.extend(docs)
else:
file_summaries.append(docs)
if file_summaries:
prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
else:
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
model=chat_request.model if chat_request.model else None,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="faqgen", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.FAQ_GEN),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-faqgen = FaqGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+faqgen = FaqGenService(port=MEGA_SERVICE_PORT)
faqgen.add_remote_service()
faqgen.start()


@@ -6,7 +6,18 @@ import json
import os
import re
-from comps import GraphragGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
EmbeddingRequest,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams, RetrieverParms, TextDoc
from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_core.prompts import PromptTemplate
@@ -35,7 +46,6 @@ If you don't know the answer to a question, please don't share false information
return template.format(context=context_str, question=question)
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0")
RETRIEVER_SERVICE_PORT = int(os.getenv("RETRIEVER_SERVICE_PORT", 7000))
@@ -117,7 +127,7 @@ def align_generator(self, gen, **kwargs):
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
class GraphRAGService: class GraphRAGService(Gateway):
def __init__(self, host="0.0.0.0", port=8000): def __init__(self, host="0.0.0.0", port=8000):
self.host = host self.host = host
self.port = port self.port = port
@@ -146,9 +156,84 @@ class GraphRAGService:
)
self.megaservice.add(retriever).add(llm)
self.megaservice.flow_to(retriever, llm)
-self.gateway = GraphragGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
def parser_input(data, TypeClass, key):
chat_request = None
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query, chat_request
query = None
for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query, chat_request = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
if chat_request is None:
raise ValueError(f"Unknown request type: {data}")
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
initial_inputs = chat_request
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs,
llm_parameters=parameters,
retriever_parameters=retriever_parameters,
)
for node, response in result_dict.items():
if isinstance(response, StreamingResponse):
return response
last_node = runtime_graph.all_leaves()[-1]
response_content = result_dict[last_node]["choices"][0]["message"]["content"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response_content),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.GRAPH_RAG),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
-graphrag = GraphRAGService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+graphrag = GraphRAGService(port=MEGA_SERVICE_PORT)
graphrag.add_remote_service()
graphrag.start()


@@ -1,11 +1,24 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import base64
import os
from io import BytesIO
-from comps import MicroService, MultimodalQnAGateway, ServiceOrchestrator, ServiceType
+import requests
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import Request
from fastapi.responses import StreamingResponse
from PIL import Image
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
MM_EMBEDDING_SERVICE_HOST_IP = os.getenv("MM_EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
MM_EMBEDDING_PORT_MICROSERVICE = int(os.getenv("MM_EMBEDDING_PORT_MICROSERVICE", 6000))
@@ -15,12 +28,12 @@ LVM_SERVICE_HOST_IP = os.getenv("LVM_SERVICE_HOST_IP", "0.0.0.0")
LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9399))
-class MultimodalQnAService:
+class MultimodalQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
-self.mmrag_megaservice = ServiceOrchestrator()
self.lvm_megaservice = ServiceOrchestrator()
+self.megaservice = ServiceOrchestrator()
def add_remote_service(self):
mm_embedding = MicroService(
@@ -50,21 +63,186 @@ class MultimodalQnAService:
)
# for mmrag megaservice
-self.mmrag_megaservice.add(mm_embedding).add(mm_retriever).add(lvm)
-self.mmrag_megaservice.flow_to(mm_embedding, mm_retriever)
-self.mmrag_megaservice.flow_to(mm_retriever, lvm)
+self.megaservice.add(mm_embedding).add(mm_retriever).add(lvm)
+self.megaservice.flow_to(mm_embedding, mm_retriever)
+self.megaservice.flow_to(mm_retriever, lvm)
# for lvm megaservice
self.lvm_megaservice.add(lvm)
-self.gateway = MultimodalQnAGateway(
-    multimodal_rag_megaservice=self.mmrag_megaservice,
-    lvm_megaservice=self.lvm_megaservice,
-    host="0.0.0.0",
-    port=self.port,
-)
# this overrides _handle_message method of Gateway
def _handle_message(self, messages):
images = []
messages_dicts = []
if isinstance(messages, str):
prompt = messages
else:
messages_dict = {}
system_prompt = ""
prompt = ""
for message in messages:
msg_role = message["role"]
messages_dict = {}
if msg_role == "system":
system_prompt = message["content"]
elif msg_role == "user":
if type(message["content"]) == list:
text = ""
text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
text += "\n".join(text_list)
image_list = [
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
]
if image_list:
messages_dict[msg_role] = (text, image_list)
else:
messages_dict[msg_role] = text
else:
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
elif msg_role == "assistant":
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
else:
raise ValueError(f"Unknown role: {msg_role}")
if system_prompt:
prompt = system_prompt + "\n"
for messages_dict in messages_dicts:
for i, (role, message) in enumerate(messages_dict.items()):
if isinstance(message, tuple):
text, image_list = message
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if text:
prompt += text + "\n"
else:
if text:
prompt += role.upper() + ": " + text + "\n"
else:
prompt += role.upper() + ":"
for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Bytes
else:
img_b64_str = img
images.append(img_b64_str)
else:
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if message:
prompt += role.upper() + ": " + message + "\n"
else:
if message:
prompt += role.upper() + ": " + message + "\n"
else:
prompt += role.upper() + ":"
if images:
return prompt, images
else:
return prompt
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = bool(data.get("stream", False))
if stream_opt:
print("[ MultimodalQnAService ] stream=True not used, this has not support streaming yet!")
stream_opt = False
chat_request = ChatCompletionRequest.model_validate(data)
# Multimodal RAG QnA With Videos has not yet accepts image as input during QnA.
prompt_and_image = self._handle_message(chat_request.messages)
if isinstance(prompt_and_image, tuple):
# print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
prompt, images = prompt_and_image
cur_megaservice = self.lvm_megaservice
initial_inputs = {"prompt": prompt, "image": images[0]}
else:
# print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
prompt = prompt_and_image
cur_megaservice = self.megaservice
initial_inputs = {"text": prompt}
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
result_dict, runtime_graph = await cur_megaservice.schedule(
initial_inputs=initial_inputs, llm_parameters=parameters
)
for node, response in result_dict.items():
# the last microservice in this megaservice is LVM.
# checking if LVM returns StreamingResponse
# Currently, LVM with LLAVA has not yet supported streaming.
# @TODO: Will need to test this once LVM with LLAVA supports streaming
if (
isinstance(response, StreamingResponse)
and node == runtime_graph.all_leaves()[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response
last_node = runtime_graph.all_leaves()[-1]
if "text" in result_dict[last_node].keys():
response = result_dict[last_node]["text"]
else:
# text in not response message
# something wrong, for example due to empty retrieval results
if "detail" in result_dict[last_node].keys():
response = result_dict[last_node]["detail"]
else:
response = "The server fail to generate answer to your query!"
if "metadata" in result_dict[last_node].keys():
# from retrieval results
metadata = result_dict[last_node]["metadata"]
else:
# follow-up question, no retrieval
metadata = None
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
metadata=metadata,
)
)
return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
mmragwithvideos = MultimodalQnAService(port=MEGA_SERVICE_PORT)
mmragwithvideos.add_remote_service()
mmragwithvideos.start()
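For reference, a minimal client sketch for the MultimodalQnA gateway above (not part of this commit). It assumes the megaservice listens on localhost:8888 and that MegaServiceEndpoint.MULTIMODAL_QNA resolves to "/v1/multimodalqna"; adjust the host, port, and path to match your deployment.

import requests  # illustrative client; not included in the commit

URL = "http://localhost:8888/v1/multimodalqna"  # assumed endpoint path
# A first-turn, text-only query is routed through the full multimodal RAG megaservice.
payload = {"messages": "What is shown in the opening scene of the video?", "max_tokens": 128}
resp = requests.post(URL, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])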
@@ -3,9 +3,18 @@
import os
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import Request
from fastapi.responses import StreamingResponse
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000))
@@ -17,7 +26,7 @@ LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
class SearchQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -60,9 +69,58 @@ class SearchQnAService:
self.megaservice.flow_to(embedding, web_retriever)
self.megaservice.flow_to(web_retriever, rerank)
self.megaservice.flow_to(rerank, llm)
self.gateway = SearchQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here we assume the last microservice in the megaservice is the LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.SEARCH_QNA),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
searchqna = SearchQnAService(port=MEGA_SERVICE_PORT)
searchqna.add_remote_service()
searchqna.start()
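A hedged usage sketch for the SearchQnA gateway (assumes localhost:8888 and that MegaServiceEndpoint.SEARCH_QNA maps to "/v1/searchqna"). Because handle_request defaults stream to True, the example disables streaming to receive a single ChatCompletionResponse.

import requests  # illustrative client; not included in the commit

payload = {
    "messages": "What is the latest news about OPEA?",
    "stream": False,  # handle_request defaults stream to True
    "max_tokens": 256,
}
resp = requests.post("http://localhost:8888/v1/searchqna", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])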
@@ -15,15 +15,23 @@
import asyncio
import os
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from fastapi import Request
from fastapi.responses import StreamingResponse
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
class TranslationService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -39,9 +47,57 @@ class TranslationService:
service_type=ServiceType.LLM,
)
self.megaservice.add(llm)
self.gateway = TranslationGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
language_from = data["language_from"]
language_to = data["language_to"]
source_language = data["source_language"]
prompt_template = """
Translate this from {language_from} to {language_to}:
{language_from}:
{source_language}
{language_to}:
"""
prompt = prompt_template.format(
language_from=language_from, language_to=language_to, source_language=source_language
)
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
for node, response in result_dict.items():
# Here we assume the last microservice in the megaservice is the LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="translation", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.TRANSLATION),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
translation = TranslationService(port=MEGA_SERVICE_PORT)
translation.add_remote_service()
translation.start()
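A request sketch for the Translation gateway (assumed host/port localhost:8888 and endpoint path "/v1/translation"; verify against MegaServiceEndpoint.TRANSLATION in your deployment). Note that handle_request reads the text to translate from the source_language field.

import requests  # illustrative client; not included in the commit

payload = {
    "language_from": "English",
    "language_to": "Chinese",
    "source_language": "I love machine translation.",  # the text to translate
}
resp = requests.post("http://localhost:8888/v1/translation", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json())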
@@ -3,9 +3,18 @@
import os
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import Request
from fastapi.responses import StreamingResponse
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000))
@@ -17,7 +26,7 @@ LVM_SERVICE_HOST_IP = os.getenv("LVM_SERVICE_HOST_IP", "0.0.0.0")
LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9000))
class VideoQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8888):
self.host = host
self.port = port
@@ -60,9 +69,58 @@ class VideoQnAService:
self.megaservice.flow_to(embedding, retriever)
self.megaservice.flow_to(retriever, rerank)
self.megaservice.flow_to(rerank, lvm)
self.gateway = VideoQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", False)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here we assume the last microservice in the megaservice is the LVM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="videoqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.VIDEO_RAG_QNA),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
videoqna = VideoQnAService(port=MEGA_SERVICE_PORT)
videoqna.add_remote_service()
videoqna.start()
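A usage sketch for the VideoQnA gateway (assumes localhost:8888 and that MegaServiceEndpoint.VIDEO_RAG_QNA maps to "/v1/videoqna"; adjust to your deployment). Streaming defaults to False here, so a plain ChatCompletionResponse comes back.

import requests  # illustrative client; not included in the commit

payload = {"messages": "What does the man in the red jacket do?", "max_tokens": 256}
resp = requests.post("http://localhost:8888/v1/videoqna", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])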
@@ -3,15 +3,24 @@
import os
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from fastapi import Request
from fastapi.responses import StreamingResponse
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
LVM_SERVICE_HOST_IP = os.getenv("LVM_SERVICE_HOST_IP", "0.0.0.0")
LVM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9399))
class VisualQnAService(Gateway):
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
@@ -27,9 +36,58 @@ class VisualQnAService:
service_type=ServiceType.LVM,
)
self.megaservice.add(llm)
self.gateway = VisualQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", False)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt, images = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"prompt": prompt, "image": images[0]}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here we assume the last microservice in the megaservice is the LVM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)
def start(self):
super().__init__(
megaservice=self.megaservice,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.VISUAL_QNA),
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
if __name__ == "__main__":
visualqna = VisualQnAService(port=MEGA_SERVICE_PORT)
visualqna.add_remote_service()
visualqna.start()
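A sketch for the VisualQnA gateway (assumes localhost:8888 and that MegaServiceEndpoint.VISUAL_QNA maps to "/v1/visualqna"). Since _handle_message expects an image in the messages, the example sends an OpenAI-style message with an image_url part; the image URL below is a placeholder.

import requests  # illustrative client; not included in the commit

payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is unusual in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},  # placeholder URL
            ],
        }
    ],
    "max_tokens": 128,
}
resp = requests.post("http://localhost:8888/v1/visualqna", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])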