Files
GenAIExamples/CodeTrans/langchain/docker/codetrans-app/server.py
Letong Han 0ac6fd4fef Add code translation example (#81)
* add example of Code Translation

Signed-off-by: LetongHan <letong.han@intel.com>
2024-04-17 10:03:06 +08:00

130 lines
4.5 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from langchain_community.llms import HuggingFaceEndpoint
from prompts import codetrans_prompt_template
from starlette.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
TGI_ENDPOINT = os.getenv("TGI_ENDPOINT", "http://localhost:8080")
SERVICE_PORT = os.getenv("SERVER_PORT", 8000)
class CodeTranslationAPIRouter(APIRouter):
"""The router for CodeTranslation example."""
def __init__(self, entrypoint: str, prompt_template: str) -> None:
super().__init__()
self.entrypoint = entrypoint
# setup TGI endpoint
self.llm = HuggingFaceEndpoint(
endpoint_url=entrypoint,
max_new_tokens=1024,
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=0.01,
repetition_penalty=1.03,
streaming=True,
)
self.prompt_template = prompt_template
def handle_code_translation(self, language_from: str, language_to: str, source_code: str):
prompt = self.prompt_template.format(
language_from=language_from, language_to=language_to, source_code=source_code
)
print(f"[codetrans - nonstream] prompt:{prompt}")
try:
response = self.llm(prompt)
except Exception as e:
print(f"[codetrans - nonstream] Error occurred: {e}")
raise Exception(f"[codetrans - nonstream] {e}")
print(f"[codetrans - nonstream] response:\n{response}")
return response
async def handle_code_translation_stream(self, language_from: str, language_to: str, source_code: str):
prompt = self.prompt_template.format(
language_from=language_from, language_to=language_to, source_code=source_code
)
print(f"[codetrans - stream] prompt:{prompt}")
async def stream_generator():
async for chunk in self.llm.astream_log(prompt):
print(f"[codetrans - stream] data: {chunk}")
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_generator(), media_type="text/event-stream")
router = CodeTranslationAPIRouter(entrypoint=TGI_ENDPOINT, prompt_template=codetrans_prompt_template)
@router.post("/v1/code_translation")
async def code_translation(request: Request):
params = await request.json()
print(f"[codetrans - nonstream] POST request: /v1/code_translation, params:{params}")
language_from = params["language_from"]
language_to = params["language_to"]
source_code = params["source_code"]
try:
return router.handle_code_translation(
language_from=language_from, language_to=language_to, source_code=source_code
)
except Exception as e:
print(f"[codetrans - nonstream] Error occurred: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/v1/code_translation_stream")
async def code_translation_stream(request: Request):
params = await request.json()
print(f"[codetrans - stream] POST request: /v1/code_translation_stream, params:{params}")
language_from = params["language_from"]
language_to = params["language_to"]
source_code = params["source_code"]
try:
return await router.handle_code_translation_stream(
language_from=language_from, language_to=language_to, source_code=source_code
)
except Exception as e:
print(f"[codetrans - stream] Error occurred: {e}")
raise HTTPException(status_code=500, detail=str(e))
app.include_router(router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=int(SERVICE_PORT))