diff --git a/app.py b/app.py
new file mode 100644
index 0000000..53c5c6f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,37 @@
+import litert_lm
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from contextlib import asynccontextmanager
+
+MODEL_PATH = "gemma-4-E2B-it.litertlm"
+
+class PromptRequest(BaseModel):
+    prompt: str
+
+ml_models = {}
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    engine = litert_lm.Engine(MODEL_PATH, backend=litert_lm.Backend.CPU)
+    ml_models["engine"] = engine
+    yield
+    del ml_models["engine"]
+
+app = FastAPI(lifespan=lifespan)
+
+@app.post("/generate")
+def generate_text(request: PromptRequest):
+    engine = ml_models.get("engine")
+    if not engine:
+        raise HTTPException(status_code=503, detail="Model engine not initialized")
+    try:
+        # Create the conversation directly; no "with" block is needed here
+        conversation = engine.create_conversation()
+        result = conversation.send_message(request.prompt)
+        return {"response": result["content"][0]["text"]}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/default.profraw b/default.profraw
new file mode 100644
index 0000000..6e624b0
Binary files /dev/null and b/default.profraw differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f2ad05e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+fastapi
+uvicorn
+litert-lm-api-nightly
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..a5da892
--- /dev/null
+++ b/server.py
@@ -0,0 +1,132 @@
+import os
+import uuid
+import time
+from pathlib import Path
+
+# Suppress verbose logs
+os.environ["GRPC_VERBOSITY"] = "ERROR"
+os.environ["GLOG_minloglevel"] = "3"
+
+import litert_lm
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import HTMLResponse
+from pydantic import BaseModel
+from contextlib import asynccontextmanager
+
+# ── Config ────────────────────────────────────────────────────────────────────
+
+MODEL_PATH = "gemma-4-E2B-it.litertlm"
+TEMPLATE_DIR = Path(__file__).parent / "templates"
+
+# ── Models ───────────────────────────────────────────────────────────────────
+
+class PromptRequest(BaseModel):
+    prompt: str
+
+# ── State ────────────────────────────────────────────────────────────────────
+
+ml_models = {}
+sessions: dict = {}  # session_id -> conversation object
+
+# ── Lifespan ─────────────────────────────────────────────────────────────────
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    engine = litert_lm.Engine(MODEL_PATH, backend=litert_lm.Backend.CPU)
+    ml_models["engine"] = engine
+    yield
+    sessions.clear()
+    del ml_models["engine"]
+
+# ── App ───────────────────────────────────────────────────────────────────────
+
+app = FastAPI(title="LiteRT-LM API", lifespan=lifespan)
+
+# ── REST: stateless single-turn ───────────────────────────────────────────────
+
+@app.post("/generate")
+def generate_text(request: PromptRequest):
+    """Single-turn generation, no memory. Sync def: runs in FastAPI's threadpool."""
+    engine = ml_models.get("engine")
+    if not engine:
+        raise HTTPException(status_code=503, detail="Model engine not initialized")
+    try:
+        conversation = engine.create_conversation()
+        t0 = time.perf_counter()
+        result = conversation.send_message(request.prompt)
+        elapsed = time.perf_counter() - t0
+        text = result["content"][0]["text"]
+        num_tokens = len(engine.tokenize(text))
+        tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
+        return {
+            "response": text,
+            "tokens": num_tokens,
+            "elapsed_s": round(elapsed, 2),
+            "tokens_per_sec": tps,
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# ── REST: multi-turn chat sessions ────────────────────────────────────────────
+
+@app.post("/chat/new")
+def new_session():
+    """Create a new chat session. Returns session_id."""
+    engine = ml_models.get("engine")
+    if not engine:
+        raise HTTPException(status_code=503, detail="Model engine not initialized")
+    session_id = str(uuid.uuid4())
+    sessions[session_id] = engine.create_conversation()
+    return {"session_id": session_id}
+
+@app.post("/chat/{session_id}")
+def chat(session_id: str, request: PromptRequest):
+    """Send a message in an existing session (retains history). Sync def: threadpool."""
+    if session_id not in sessions:
+        raise HTTPException(
+            status_code=404,
+            detail="Session not found. Create one via POST /chat/new",
+        )
+    try:
+        engine = ml_models.get("engine")
+        t0 = time.perf_counter()
+        result = sessions[session_id].send_message(request.prompt)
+        elapsed = time.perf_counter() - t0
+        text = result["content"][0]["text"]
+        num_tokens = len(engine.tokenize(text)) if engine else 0
+        tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
+        return {
+            "session_id": session_id,
+            "response": text,
+            "tokens": num_tokens,
+            "elapsed_s": round(elapsed, 2),
+            "tokens_per_sec": tps,
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.delete("/chat/{session_id}")
+async def clear_session(session_id: str):
+    """Delete a session and free its memory."""
+    if session_id not in sessions:
+        raise HTTPException(status_code=404, detail="Session not found")
+    del sessions[session_id]
+    return {"status": "cleared", "session_id": session_id}
+
+@app.get("/chat/sessions/list")
+async def list_sessions():
+    """List all active session IDs."""
+    return {"sessions": list(sessions.keys()), "count": len(sessions)}
+
+# ── WebUI ─────────────────────────────────────────────────────────────────────
+
+@app.get("/", response_class=HTMLResponse)
+async def web_ui():
+    html = (TEMPLATE_DIR / "index.html").read_text(encoding="utf-8")
+    return HTMLResponse(content=html)
+
+# ── Run ───────────────────────────────────────────────────────────────────────
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/templates/index.html b/templates/index.html
new file mode 100644
index 0000000..2d7c950
--- /dev/null
+++ b/templates/index.html
@@ -0,0 +1,502 @@
+
+
+
+ + +Start a conversation.
Gemma 4 remembers context within a session.