Files
GenAIExamples/CodeGen/codegen/codegen-app/server.py
lvliang-intel b91a010fd8 Enhance the CodeGen example for the VSCode plugin's public release (#18)
* update codegen readme and code

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* update readme

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* update readme

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean the server code

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* refine document

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update readme

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

---------

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-03-28 23:12:26 +08:00

167 lines
5.8 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
from fastapi import APIRouter, FastAPI
from fastapi.responses import RedirectResponse, StreamingResponse
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.pydantic_v1 import BaseModel
from openai_protocol import ChatCompletionRequest, ChatCompletionResponse
from starlette.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
def filter_code_format(code):
language_prefixes = {
"go": "```go",
"c": "```c",
"cpp": "```cpp",
"java": "```java",
"python": "```python",
"typescript": "```typescript",
}
suffix = "\n```"
# Find the first occurrence of a language prefix
first_prefix_pos = len(code)
for prefix in language_prefixes.values():
pos = code.find(prefix)
if pos != -1 and pos < first_prefix_pos:
first_prefix_pos = pos + len(prefix) + 1
# Find the first occurrence of the suffix after the first language prefix
first_suffix_pos = code.find(suffix, first_prefix_pos + 1)
# Extract the code block
if first_prefix_pos != -1 and first_suffix_pos != -1:
return code[first_prefix_pos:first_suffix_pos]
elif first_prefix_pos != -1:
return code[first_prefix_pos:]
return code
class CodeGenAPIRouter(APIRouter):
def __init__(self, entrypoint) -> None:
super().__init__()
self.entrypoint = entrypoint
print(f"[codegen - router] Initializing API Router, entrypoint={entrypoint}")
# Define LLM
callbacks = [StreamingStdOutCallbackHandler()]
self.llm = HuggingFaceEndpoint(
endpoint_url=entrypoint,
max_new_tokens=1024,
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=0.01,
repetition_penalty=1.03,
streaming=True,
callbacks=callbacks,
)
print("[codegen - router] LLM initialized.")
def handle_chat_completion_request(self, request: ChatCompletionRequest):
try:
print(f"Predicting chat completion using prompt '{request.prompt}'")
if request.stream:
async def stream_generator():
for chunk in self.llm.stream(request.prompt):
yield f"data: {chunk.encode()}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
result = self.llm(request.prompt)
response = filter_code_format(result)
except Exception as e:
print(f"An error occurred: {e}")
else:
print("Chat completion finished.")
return ChatCompletionResponse(response=response)
tgi_endpoint = os.getenv("TGI_ENDPOINT", "http://localhost:8080")
router = CodeGenAPIRouter(tgi_endpoint)
def check_completion_request(request: BaseModel) -> Optional[str]:
if request.temperature is not None and request.temperature < 0:
return f"Param Error: {request.temperature} is less than the minimum of 0 --- 'temperature'"
if request.temperature is not None and request.temperature > 2:
return f"Param Error: {request.temperature} is greater than the maximum of 2 --- 'temperature'"
if request.top_p is not None and request.top_p < 0:
return f"Param Error: {request.top_p} is less than the minimum of 0 --- 'top_p'"
if request.top_p is not None and request.top_p > 1:
return f"Param Error: {request.top_p} is greater than the maximum of 1 --- 'top_p'"
if request.top_k is not None and (not isinstance(request.top_k, int)):
return f"Param Error: {request.top_k} is not valid under any of the given schemas --- 'top_k'"
if request.top_k is not None and request.top_k < 1:
return f"Param Error: {request.top_k} is greater than the minimum of 1 --- 'top_k'"
if request.max_new_tokens is not None and (not isinstance(request.max_new_tokens, int)):
return f"Param Error: {request.max_new_tokens} is not valid under any of the given schemas --- 'max_new_tokens'"
return None
# router /v1/code_generation only supports non-streaming mode.
@router.post("/v1/code_generation")
async def code_generation_endpoint(chat_request: ChatCompletionRequest):
ret = check_completion_request(chat_request)
if ret is not None:
raise RuntimeError("Invalid parameter.")
return router.handle_chat_completion_request(chat_request)
# router /v1/code_chat supports both non-streaming and streaming mode.
@router.post("/v1/code_chat")
async def code_chat_endpoint(chat_request: ChatCompletionRequest):
ret = check_completion_request(chat_request)
if ret is not None:
raise RuntimeError("Invalid parameter.")
return router.handle_chat_completion_request(chat_request)
app.include_router(router)
@app.get("/")
async def redirect_root_to_docs():
return RedirectResponse("/docs")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)