import os
import sys
import uuid
import time
from pathlib import Path

# Suppress verbose logs
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "3"

import litert_lm
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from contextlib import asynccontextmanager

# ── Config ────────────────────────────────────────────────────────────────────
MODELS_DIR = Path(__file__).parent / "models"
TEMPLATE_DIR = Path(__file__).parent / "templates"

AVAILABLE_MODELS = {
    "gemma-4-E2B-it": {
        "file": "gemma-4-E2B-it.litertlm",
        "repo": "litert-community/gemma-4-E2B-it-litert-lm",
        "desc": "Gemma 4 Edge 2B — smaller, faster",
    },
    "gemma-4-E4B-it": {
        "file": "gemma-4-E4B-it.litertlm",
        "repo": "litert-community/gemma-4-E4B-it-litert-lm",
        "desc": "Gemma 4 Edge 4B — smarter, slower",
    },
}

# ── CLI: select a model at startup ────────────────────────────────────────────
def download_model(repo: str, local_dir: Path) -> bool:
    """Download the model files from Hugging Face into `local_dir`."""
    import subprocess

    try:
        print(f"\n Downloading model from {repo}...")
        print(" Please wait, this may take a few minutes...\n")
        cmd = [
            "huggingface-cli", "download", repo,
            "--include", "*.litertlm",
            "--local-dir", str(local_dir),
        ]
        subprocess.run(cmd, check=True, capture_output=False, text=True)
        print("\n ✓ Model downloaded successfully!")
        return True
    except subprocess.CalledProcessError as e:
        print(f"\n ✗ Error while downloading model: {e}")
        return False
    except FileNotFoundError:
        print("\n ✗ huggingface-cli not found.")
        print("   Install it with: pip install huggingface-hub")
        return False
    except Exception as e:
        print(f"\n ✗ Unexpected error: {e}")
        return False
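# As a programmatic alternative to shelling out to huggingface-cli, the
# huggingface_hub package exposes `snapshot_download`. A minimal sketch is
# left here as a comment so it never runs on import (the function name
# `download_model_via_api` is illustrative; `snapshot_download` and its
# parameters are the real API):
#
#   from huggingface_hub import snapshot_download
#
#   def download_model_via_api(repo: str, local_dir: Path) -> bool:
#       """Fetch only the *.litertlm files from `repo` into `local_dir`."""
#       try:
#           snapshot_download(repo_id=repo, allow_patterns=["*.litertlm"],
#                             local_dir=str(local_dir))
#           return True
#       except Exception as e:
#           print(f" ✗ Download failed: {e}")
#           return False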
def select_model() -> Path:
    print("\n" + "=" * 52)
    print(" LiteRT-LM Server — Select a model")
    print("=" * 52)
    for i, (key, info) in enumerate(AVAILABLE_MODELS.items(), 1):
        model_path = MODELS_DIR / info["file"]
        status = "✓ available" if model_path.exists() else "✗ not downloaded"
        print(f" [{i}] {key}")
        print(f"     {info['desc']}")
        print(f"     {status}")
        print()

    while True:
        try:
            choice = input("Select a model (1/2): ").strip()
        except KeyboardInterrupt:
            print("\n Exiting.")
            sys.exit(0)

        # A non-numeric or out-of-range choice should re-prompt, not abort.
        if not choice.isdigit() or not 0 <= int(choice) - 1 < len(AVAILABLE_MODELS):
            print(" Please enter 1 or 2.")
            continue

        idx = int(choice) - 1
        key = list(AVAILABLE_MODELS.keys())[idx]
        info = AVAILABLE_MODELS[key]
        model_path = MODELS_DIR / info["file"]

        if not model_path.exists():
            print("\n Model not found in the models/ directory.")
            print(" Manual download command:\n")
            print(f" huggingface-cli download {info['repo']} \\")
            print("     --include '*.litertlm' \\")
            print("     --local-dir models/\n")
            download_choice = input(" Download the model now? (y/n): ").strip().lower()
            if download_choice == "y":
                if download_model(info["repo"], MODELS_DIR):
                    # Verify the file actually exists after the download.
                    if not model_path.exists():
                        print("\n ✗ Model file not found after download.")
                        retry = input(" Select a different model? (y/n): ").strip().lower()
                        if retry == "y":
                            continue
                        sys.exit(0)
                else:
                    retry = input("\n Select a different model? (y/n): ").strip().lower()
                    if retry == "y":
                        continue
                    sys.exit(0)
            else:
                retry = input(" Select a different model? (y/n): ").strip().lower()
                if retry == "y":
                    continue
                sys.exit(0)

        print(f"\n Selected: {key}")
        print(f" Path: {model_path}\n")
        return model_path

# Select the model before FastAPI starts.
MODELS_DIR.mkdir(exist_ok=True)
MODEL_PATH = select_model()

# ── Models ────────────────────────────────────────────────────────────────────
class PromptRequest(BaseModel):
    prompt: str

# ── State ─────────────────────────────────────────────────────────────────────
ml_models = {}
sessions: dict = {}  # session_id -> conversation object

# ── Helpers ───────────────────────────────────────────────────────────────────
def count_tokens(engine, text: str) -> int:
    """Count tokens with the engine's tokenizer, falling back to a rough
    ~4-characters-per-token estimate if tokenization fails."""
    try:
        return len(engine.tokenize(text))
    except Exception:
        return max(1, len(text) // 4)

# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    print(f" Loading model: {MODEL_PATH.name} ...")
    engine = litert_lm.Engine(str(MODEL_PATH), backend=litert_lm.Backend.CPU)
    ml_models["engine"] = engine
    ml_models["model_name"] = MODEL_PATH.stem
    print(f" Model ready: {MODEL_PATH.name}\n")
    yield
    sessions.clear()
    del ml_models["engine"]

# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(title="LiteRT-LM API", lifespan=lifespan)

# ── REST: info ────────────────────────────────────────────────────────────────
@app.get("/info")
async def info():
    """Return info about the currently loaded model."""
    return {
        "model": ml_models.get("model_name", "unknown"),
        "sessions": len(sessions),
    }

# ── REST: stateless single-turn ───────────────────────────────────────────────
@app.post("/generate")
async def generate_text(request: PromptRequest):
    """Single-turn generation. No memory between calls."""
    engine = ml_models.get("engine")
    if not engine:
        raise HTTPException(status_code=503, detail="Model engine not initialized")
    try:
        conversation = engine.create_conversation()
        t0 = time.perf_counter()
        result = conversation.send_message(request.prompt)
        elapsed = time.perf_counter() - t0
        text = result["content"][0]["text"]
        num_tokens = count_tokens(engine, text)
        tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
        return {
            "response": text,
            "tokens": num_tokens,
            "elapsed_s": round(elapsed, 2),
            "tokens_per_sec": tps,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
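# A quick client-side smoke test for /generate, kept as a comment so it does
# not run on import (a hedged sketch: assumes the server is on
# localhost:8000 and that `requests` is installed; the prompt is illustrative):
#
#   import requests
#   r = requests.post("http://localhost:8000/generate",
#                     json={"prompt": "Explain LiteRT-LM in one sentence."})
#   data = r.json()
#   print(data["response"], data["tokens_per_sec"])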
# ── REST: multi-turn chat sessions ────────────────────────────────────────────
@app.post("/chat/new")
async def new_session():
    """Create a new chat session. Returns session_id."""
    engine = ml_models.get("engine")
    if not engine:
        raise HTTPException(status_code=503, detail="Model engine not initialized")
    session_id = str(uuid.uuid4())
    sessions[session_id] = engine.create_conversation()
    return {"session_id": session_id}

@app.post("/chat/{session_id}")
async def chat(session_id: str, request: PromptRequest):
    """Send a message in an existing session (retains conversation history)."""
    if session_id not in sessions:
        raise HTTPException(
            status_code=404,
            detail="Session not found. Create one via POST /chat/new",
        )
    try:
        engine = ml_models.get("engine")
        t0 = time.perf_counter()
        result = sessions[session_id].send_message(request.prompt)
        elapsed = time.perf_counter() - t0
        text = result["content"][0]["text"]
        num_tokens = count_tokens(engine, text)
        tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
        return {
            "session_id": session_id,
            "response": text,
            "tokens": num_tokens,
            "elapsed_s": round(elapsed, 2),
            "tokens_per_sec": tps,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.delete("/chat/{session_id}")
async def clear_session(session_id: str):
    """Delete a session and free its memory."""
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    del sessions[session_id]
    return {"status": "cleared", "session_id": session_id}

@app.get("/chat/sessions/list")
async def list_sessions():
    """List all active session IDs."""
    return {"sessions": list(sessions.keys()), "count": len(sessions)}

# ── WebUI ─────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def web_ui():
    html = (TEMPLATE_DIR / "index.html").read_text(encoding="utf-8")
    return HTMLResponse(content=html)

# ── Run ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
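# Typical multi-turn client flow against the session endpoints above, kept as
# a comment so it never executes alongside the server (a hedged sketch:
# assumes localhost:8000 and `requests`; the prompts are illustrative):
#
#   import requests
#   base = "http://localhost:8000"
#   sid = requests.post(f"{base}/chat/new").json()["session_id"]
#   requests.post(f"{base}/chat/{sid}", json={"prompt": "My name is An."})
#   r = requests.post(f"{base}/chat/{sid}", json={"prompt": "What is my name?"})
#   print(r.json()["response"])            # history is retained per session
#   requests.delete(f"{base}/chat/{sid}")  # free the session's memory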