add files
This commit is contained in:
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Suppress verbose logs. These variables are written before `litert_lm` is
# imported below, so they are already present in the environment at that point.
os.environ.update(
    {
        "GRPC_VERBOSITY": "ERROR",
        "GLOG_minloglevel": "3",
    }
)
|
||||
|
||||
import litert_lm
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Path of the model file handed to the LiteRT-LM engine at startup.
MODEL_PATH = "gemma-4-E2B-it.litertlm"

# Directory named "templates" that sits next to this source file.
TEMPLATE_DIR = Path(__file__).parent.joinpath("templates")
|
||||
|
||||
# ── Models ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class PromptRequest(BaseModel):
    """Request body that carries a single prompt string."""

    # Text passed to the conversation's send_message call by the endpoints.
    prompt: str
|
||||
|
||||
# ── State ────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Registry of loaded model objects; the engine is stored under the "engine" key.
ml_models: dict = {}

# Active chat sessions: session_id -> conversation object.
sessions: dict = {}
|
||||
|
||||
# ── Lifespan ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Keep the model engine loaded for the lifetime of the application.

    Before the application starts serving, an engine is built from
    ``MODEL_PATH`` on the CPU backend and stored in ``ml_models["engine"]``.
    When the context exits, all chat sessions are dropped and the engine
    reference is removed from ``ml_models``.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used inside the function.
    """
    ml_models["engine"] = litert_lm.Engine(
        MODEL_PATH, backend=litert_lm.Backend.CPU
    )
    try:
        yield
    finally:
        # Run teardown even if an exception is thrown into the generator at
        # the yield point; without try/finally the cleanup would be skipped.
        sessions.clear()
        # pop() instead of del so teardown cannot fail if the key is missing.
        ml_models.pop("engine", None)
        # NOTE(review): if litert_lm.Engine exposes an explicit close or
        # context-manager API, release it here — confirm against the library.
|
||||
|
||||
# ── App ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Application object; `lifespan` manages engine setup and teardown.
app = FastAPI(
    title="LiteRT-LM API",
    lifespan=lifespan,
)
|
||||
|
||||
# ── REST: stateless single-turn ───────────────────────────────────────────────
|
||||
|
||||
@app.post("/generate")
async def generate_text(request: PromptRequest):
    """Single-turn generation. No memory between calls.

    A new conversation is created for every request, the prompt is sent once,
    and the reply is returned together with simple throughput figures.

    Args:
        request: Request body whose ``prompt`` is sent to the model.

    Returns:
        A dict with ``response`` (reply text), ``tokens`` (number of tokens
        the engine's tokenizer yields for the reply), ``elapsed_s`` (seconds
        spent in ``send_message``, rounded to two decimals) and
        ``tokens_per_sec`` (0 when no time elapsed).

    Raises:
        HTTPException: 503 when no engine is loaded; 500 when creating the
            conversation, generating, or tokenizing fails.
    """
    engine = ml_models.get("engine")
    if not engine:
        raise HTTPException(status_code=503, detail="Model engine not initialized")

    # NOTE(review): send_message is called synchronously inside an async
    # handler, so the event loop is occupied while it runs. Offloading it to a
    # worker thread is only safe if the engine tolerates concurrent use —
    # confirm before changing.
    try:
        conversation = engine.create_conversation()
        # Only the send_message call is timed; tokenizing the reply is not.
        t0 = time.perf_counter()
        result = conversation.send_message(request.prompt)
        elapsed = time.perf_counter() - t0
        text = result["content"][0]["text"]
        num_tokens = len(engine.tokenize(text))
    except Exception as exc:
        # Chain the cause so the original traceback stays attached to the
        # HTTP error instead of appearing as an unrelated secondary failure.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
    return {
        "response": text,
        "tokens": num_tokens,
        "elapsed_s": round(elapsed, 2),
        "tokens_per_sec": tps,
    }
|
||||
|
||||
# ── REST: multi-turn chat sessions ────────────────────────────────────────────
|
||||
|
||||
@app.post("/chat/new")
async def new_session():
    """Open a new chat session and return its identifier.

    Returns:
        A dict with the single key ``session_id`` holding a fresh UUID4 string
        under which the new conversation is stored in ``sessions``.

    Raises:
        HTTPException: 503 when no engine is loaded.
    """
    model_engine = ml_models.get("engine")
    if not model_engine:
        raise HTTPException(status_code=503, detail="Model engine not initialized")

    new_id = str(uuid.uuid4())
    sessions[new_id] = model_engine.create_conversation()
    return dict(session_id=new_id)
|
||||
|
||||
@app.post("/chat/{session_id}")
async def chat(session_id: str, request: PromptRequest):
    """Send a message in an existing session (retains conversation history).

    Args:
        session_id: Identifier previously returned by ``POST /chat/new``.
        request: Request body whose ``prompt`` is sent to the conversation.

    Returns:
        A dict with ``session_id``, ``response`` (reply text), ``tokens``
        (token count of the reply, or 0 when no engine is available to
        tokenize it), ``elapsed_s`` (seconds spent in ``send_message``,
        rounded to two decimals) and ``tokens_per_sec`` (0 when no time
        elapsed).

    Raises:
        HTTPException: 404 when the session does not exist; 500 when
            generation or tokenization fails.
    """
    # Single lookup: fetch the conversation and detect a missing id at once.
    try:
        conversation = sessions[session_id]
    except KeyError:
        raise HTTPException(
            status_code=404,
            detail="Session not found. Create one via POST /chat/new",
        ) from None

    # NOTE(review): send_message is called synchronously inside an async
    # handler, so the event loop is occupied while it runs — confirm whether
    # the engine can be used from a worker thread before offloading.
    try:
        engine = ml_models.get("engine")
        # Only the send_message call is timed; tokenizing the reply is not.
        t0 = time.perf_counter()
        result = conversation.send_message(request.prompt)
        elapsed = time.perf_counter() - t0
        text = result["content"][0]["text"]
        num_tokens = len(engine.tokenize(text)) if engine else 0
    except Exception as exc:
        # Chain the cause so the original traceback stays attached to the
        # HTTP error instead of appearing as an unrelated secondary failure.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
    return {
        "session_id": session_id,
        "response": text,
        "tokens": num_tokens,
        "elapsed_s": round(elapsed, 2),
        "tokens_per_sec": tps,
    }
|
||||
|
||||
@app.delete("/chat/{session_id}")
async def clear_session(session_id: str):
    """Remove a session from the session table.

    Args:
        session_id: Identifier of the session to drop.

    Returns:
        A dict with ``status`` set to ``"cleared"`` and the ``session_id``.

    Raises:
        HTTPException: 404 when no session with that id exists.
    """
    try:
        del sessions[session_id]
    except KeyError:
        # Unknown id: report it as a missing resource.
        raise HTTPException(status_code=404, detail="Session not found") from None

    return {"status": "cleared", "session_id": session_id}
|
||||
|
||||
@app.get("/chat/sessions/list")
async def list_sessions():
    """Report the ids of all sessions currently held in memory.

    Returns:
        A dict with ``sessions`` (list of session ids) and ``count`` (how
        many there are).
    """
    active_ids = list(sessions)
    return {"sessions": active_ids, "count": len(active_ids)}
|
||||
|
||||
# ── WebUI ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
async def web_ui():
    """Serve ``index.html`` from the template directory as the root page.

    The file is read from disk (UTF-8) on every request.
    """
    index_page = TEMPLATE_DIR.joinpath("index.html")
    return HTMLResponse(content=index_page.read_text(encoding="utf-8"))
|
||||
|
||||
# ── Run ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == "__main__":
    # uvicorn is imported here because it is only needed when this file is
    # executed directly as a script.
    import uvicorn

    # 0.0.0.0 listens on all network interfaces.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
||||
Reference in New Issue
Block a user