import os
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path

# Suppress verbose native-library logs.
# NOTE: these must be set BEFORE importing litert_lm, which is why the
# environment assignments precede the third-party imports.
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "3"

import litert_lm
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_PATH = "gemma-4-E2B-it.litertlm"
TEMPLATE_DIR = Path(__file__).parent / "templates"


# ── Models ────────────────────────────────────────────────────────────────────
class PromptRequest(BaseModel):
    """Request body shared by the /generate and /chat endpoints."""

    prompt: str


# ── State ─────────────────────────────────────────────────────────────────────
# "engine" -> litert_lm.Engine; populated during lifespan startup.
ml_models: dict = {}
# session_id -> conversation object.
# NOTE(review): sessions grow without bound until explicitly DELETEd;
# consider a TTL/LRU eviction policy if this runs long-lived.
sessions: dict = {}


# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model engine at startup; release sessions and engine at shutdown."""
    engine = litert_lm.Engine(MODEL_PATH, backend=litert_lm.Backend.CPU)
    ml_models["engine"] = engine
    yield
    sessions.clear()
    # pop() instead of del: teardown must not raise if startup failed partway.
    ml_models.pop("engine", None)


# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(title="LiteRT-LM API", lifespan=lifespan)


# ── Helpers ───────────────────────────────────────────────────────────────────
def _get_engine():
    """Return the loaded engine, or raise 503 if the model is not ready."""
    engine = ml_models.get("engine")
    if not engine:
        raise HTTPException(status_code=503, detail="Model engine not initialized")
    return engine


def _run_turn(engine, conversation, prompt: str) -> dict:
    """Run one blocking inference turn and return text plus timing metrics.

    Shared by /generate and /chat so both endpoints report identical metrics.
    """
    t0 = time.perf_counter()
    result = conversation.send_message(prompt)
    elapsed = time.perf_counter() - t0
    text = result["content"][0]["text"]
    num_tokens = len(engine.tokenize(text))
    tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
    return {
        "response": text,
        "tokens": num_tokens,
        "elapsed_s": round(elapsed, 2),
        "tokens_per_sec": tps,
    }


# ── REST: stateless single-turn ───────────────────────────────────────────────
# Handlers below are plain `def` (not `async def`): send_message/tokenize are
# blocking calls, so FastAPI runs them in its threadpool instead of stalling
# the event loop. The HTTP interface is unchanged.
@app.post("/generate")
def generate_text(request: PromptRequest):
    """Single-turn generation. No memory between calls."""
    engine = _get_engine()
    try:
        conversation = engine.create_conversation()
        return _run_turn(engine, conversation, request.prompt)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# ── REST: multi-turn chat sessions ────────────────────────────────────────────
@app.post("/chat/new")
def new_session():
    """Create a new chat session. Returns session_id."""
    engine = _get_engine()
    session_id = str(uuid.uuid4())
    sessions[session_id] = engine.create_conversation()
    return {"session_id": session_id}


@app.post("/chat/{session_id}")
def chat(session_id: str, request: PromptRequest):
    """Send a message in an existing session (retains conversation history)."""
    if session_id not in sessions:
        raise HTTPException(
            status_code=404,
            detail="Session not found. Create one via POST /chat/new",
        )
    engine = _get_engine()
    try:
        turn = _run_turn(engine, sessions[session_id], request.prompt)
        return {"session_id": session_id, **turn}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.delete("/chat/{session_id}")
def clear_session(session_id: str):
    """Delete a session and free its memory."""
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    del sessions[session_id]
    return {"status": "cleared", "session_id": session_id}


@app.get("/chat/sessions/list")
def list_sessions():
    """List all active session IDs."""
    return {"sessions": list(sessions.keys()), "count": len(sessions)}


# ── WebUI ─────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
def web_ui():
    """Serve the static chat UI from the templates directory."""
    try:
        html = (TEMPLATE_DIR / "index.html").read_text(encoding="utf-8")
    except FileNotFoundError:
        # Explicit 404 instead of an unhandled 500 when the template is absent.
        raise HTTPException(status_code=404, detail="index.html template not found")
    return HTMLResponse(content=html)


# ── Run ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)