import os import sys import uuid import time import socket import argparse from pathlib import Path # Suppress verbose logs os.environ["GRPC_VERBOSITY"] = "ERROR" os.environ["GLOG_minloglevel"] = "3" import litert_lm from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse from pydantic import BaseModel from contextlib import asynccontextmanager # ── Config ──────────────────────────────────────────────────────────────────── MODELS_DIR = Path(__file__).parent / "models" TEMPLATE_DIR = Path(__file__).parent / "templates" AVAILABLE_MODELS = { "gemma-4-E2B-it": { "file": "gemma-4-E2B-it.litertlm", "repo": "litert-community/gemma-4-E2B-it-litert-lm", "desc": "Gemma 4 Edge 2B — nhỏ hơn, nhanh hơn", }, "gemma-4-E4B-it": { "file": "gemma-4-E4B-it.litertlm", "repo": "litert-community/gemma-4-E4B-it-litert-lm", "desc": "Gemma 4 Edge 4B — thông minh hơn, chậm hơn", }, } # ── CLI: chọn model khi khởi động ──────────────────────────────────────────── def is_port_available(port: int) -> bool: """Kiểm tra xem port có khả dụng không.""" try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("0.0.0.0", port)) return True except OSError: return False def select_port(default_port: int = 8000) -> int: """Chọn port khả dụng.""" if is_port_available(default_port): return default_port print(f"\n ⚠️ Port {default_port} đã bị chiếm!") while True: try: choice = input(f" Nhập port khác (hoặc Enter để dùng port tự động): ").strip() if not choice: # Tìm port tự động for port in range(8001, 9000): if is_port_available(port): print(f" ✓ Sử dụng port tự động: {port}") return port print(" ✗ Không tìm thấy port khả dụng trong khoảng 8001-8999") continue port = int(choice) if port < 1024 or port > 65535: print(" ✗ Port phải trong khoảng 1024-65535") continue if is_port_available(port): print(f" ✓ Sử dụng port: {port}") return port else: print(f" ✗ Port {port} đã bị chiếm, vui lòng chọn port khác") except ValueError: print(" ✗ Vui lòng nhập số port hợp lệ") except KeyboardInterrupt: print("\n Thoát.") sys.exit(0) def download_model(repo: str, local_dir: Path) -> bool: """Tải model từ Hugging Face về local.""" try: import subprocess print(f"\n Đang tải model từ {repo}...") print(f" Vui lòng đợi, quá trình này có thể mất vài phút...\n") cmd = [ "hf", "download", repo, "--include", "*.litertlm", "--local-dir", str(local_dir) ] result = subprocess.run(cmd, check=True, capture_output=False, text=True) print(f"\n ✓ Tải model thành công!") return True except subprocess.CalledProcessError as e: print(f"\n ✗ Lỗi khi tải model: {e}") return False except FileNotFoundError: print(f"\n ✗ Không tìm thấy lệnh 'hf'.") print(f" Cài đặt bằng lệnh: pip install -U huggingface-hub") return False except Exception as e: print(f"\n ✗ Lỗi không xác định: {e}") return False def select_model(custom_path: str = None) -> Path: """Chọn model để sử dụng.""" # Nếu có custom path, kiểm tra và sử dụng luôn if custom_path: custom_model = Path(custom_path) if custom_model.exists() and custom_model.suffix == ".litertlm": print(f"\n ✓ Sử dụng model từ đường dẫn: {custom_model}") return custom_model else: print(f"\n ✗ Không tìm thấy model tại: {custom_path}") print(f" Vui lòng kiểm tra lại đường dẫn.\n") sys.exit(1) print("\n" + "="*52) print(" LiteRT-LM Server — Chọn model") print("="*52) for i, (key, info) in enumerate(AVAILABLE_MODELS.items(), 1): model_path = MODELS_DIR / info["file"] status = "✓ có sẵn" if model_path.exists() else "✗ chưa tải" print(f" [{i}] {key}") print(f" {info['desc']}") print(f" {status}") print() print(f" [3] Sử dụng model từ đường dẫn khác") print() while True: try: choice = input("Chọn model (1/2/3): ").strip() # Tùy chọn 3: Đường dẫn tùy chỉnh if choice == "3": custom_path = input("\n Nhập đường dẫn đầy đủ tới file .litertlm: ").strip() custom_model = Path(custom_path) if custom_model.exists() and custom_model.suffix == ".litertlm": print(f"\n Đã chọn: {custom_model.name}") print(f" Path: {custom_model}\n") return custom_model else: print(f"\n ✗ Không tìm thấy file model hợp lệ tại: {custom_path}") retry = input(" Thử lại? (y/n): ").strip().lower() if retry == "y": continue else: sys.exit(0) idx = int(choice) - 1 if 0 <= idx < len(AVAILABLE_MODELS): key = list(AVAILABLE_MODELS.keys())[idx] info = AVAILABLE_MODELS[key] model_path = MODELS_DIR / info["file"] if not model_path.exists(): print(f"\n Model chưa có trong thư mục models/") print(f" Lệnh tải thủ công:\n") print(f" hf download {info['repo']} \\") print(f" --include '*.litertlm' \\") print(f" --local-dir models/\n") download_choice = input(" Bạn muốn tải model ngay bây giờ? (y/n): ").strip().lower() if download_choice == "y": if download_model(info['repo'], MODELS_DIR): # Kiểm tra lại xem file đã tồn tại chưa if model_path.exists(): print(f"\n Đã chọn: {key}") print(f" Path: {model_path}\n") return model_path else: print(f"\n ✗ Không tìm thấy file model sau khi tải.") retry = input(" Chọn model khác? (y/n): ").strip().lower() if retry == "y": continue else: sys.exit(0) else: retry = input("\n Chọn model khác? (y/n): ").strip().lower() if retry == "y": continue else: sys.exit(0) else: retry = input(" Chọn model khác? (y/n): ").strip().lower() if retry == "y": continue else: sys.exit(0) print(f"\n Đã chọn: {key}") print(f" Path: {model_path}\n") return model_path else: print(" Vui lòng nhập 1, 2 hoặc 3.") except (ValueError, KeyboardInterrupt): print("\n Thoát.") sys.exit(0) # Parse command line arguments def parse_args(): parser = argparse.ArgumentParser( description="LiteRT-LM Server - Local AI inference server", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Ví dụ: python server.py python server.py --port 8080 python server.py --model /path/to/model.litertlm python server.py --port 8080 --model /path/to/model.litertlm """ ) parser.add_argument( "--port", "-p", type=int, default=8000, help="Port để chạy server (mặc định: 8000)" ) parser.add_argument( "--model", "-m", type=str, default=None, help="Đường dẫn đầy đủ tới file model .litertlm" ) return parser.parse_args() # Parse arguments và chọn model trước khi FastAPI khởi động args = parse_args() MODELS_DIR.mkdir(exist_ok=True) MODEL_PATH = select_model(args.model) SERVER_PORT = select_port(args.port) # ── Models ─────────────────────────────────────────────────────────────────── class PromptRequest(BaseModel): prompt: str # ── State ──────────────────────────────────────────────────────────────────── ml_models = {} sessions: dict = {} # session_id -> conversation object # ── Helpers ─────────────────────────────────────────────────────────────────── def count_tokens(engine, text: str) -> int: try: return len(engine.tokenize(text)) except Exception: return max(1, len(text) // 4) # ── Lifespan ───────────────────────────────────────────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): print(f" Loading model: {MODEL_PATH.name} ...") engine = litert_lm.Engine(str(MODEL_PATH), backend=litert_lm.Backend.CPU) ml_models["engine"] = engine ml_models["model_name"] = MODEL_PATH.stem print(f" Model ready: {MODEL_PATH.name}\n") yield sessions.clear() del ml_models["engine"] # ── App ─────────────────────────────────────────────────────────────────────── app = FastAPI(title="LiteRT-LM API", lifespan=lifespan) # ── REST: info ──────────────────────────────────────────────────────────────── @app.get("/info") async def info(): """Return current loaded model info.""" return { "model": ml_models.get("model_name", "unknown"), "sessions": len(sessions), } # ── REST: stateless single-turn ─────────────────────────────────────────────── @app.post("/generate") async def generate_text(request: PromptRequest): """Single-turn generation. No memory between calls.""" engine = ml_models.get("engine") if not engine: raise HTTPException(status_code=503, detail="Model engine not initialized") try: conversation = engine.create_conversation() t0 = time.perf_counter() result = conversation.send_message(request.prompt) elapsed = time.perf_counter() - t0 text = result["content"][0]["text"] num_tokens = count_tokens(engine, text) tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0 return { "response": text, "tokens": num_tokens, "elapsed_s": round(elapsed, 2), "tokens_per_sec": tps, } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # ── REST: multi-turn chat sessions ──────────────────────────────────────────── @app.post("/chat/new") async def new_session(): """Create a new chat session. Returns session_id.""" engine = ml_models.get("engine") if not engine: raise HTTPException(status_code=503, detail="Model engine not initialized") session_id = str(uuid.uuid4()) sessions[session_id] = engine.create_conversation() return {"session_id": session_id} @app.post("/chat/{session_id}") async def chat(session_id: str, request: PromptRequest): """Send a message in an existing session (retains conversation history).""" if session_id not in sessions: raise HTTPException( status_code=404, detail="Session not found. Create one via POST /chat/new", ) try: engine = ml_models.get("engine") t0 = time.perf_counter() result = sessions[session_id].send_message(request.prompt) elapsed = time.perf_counter() - t0 text = result["content"][0]["text"] num_tokens = count_tokens(engine, text) tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0 return { "session_id": session_id, "response": text, "tokens": num_tokens, "elapsed_s": round(elapsed, 2), "tokens_per_sec": tps, } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.delete("/chat/{session_id}") async def clear_session(session_id: str): """Delete a session and free its memory.""" if session_id not in sessions: raise HTTPException(status_code=404, detail="Session not found") del sessions[session_id] return {"status": "cleared", "session_id": session_id} @app.get("/chat/sessions/list") async def list_sessions(): """List all active session IDs.""" return {"sessions": list(sessions.keys()), "count": len(sessions)} # ── WebUI ───────────────────────────────────────────────────────────────────── @app.get("/", response_class=HTMLResponse) async def web_ui(): html = (TEMPLATE_DIR / "index.html").read_text(encoding="utf-8") return HTMLResponse(content=html) # ── Run ─────────────────────────────────────────────────────────────────────── if __name__ == "__main__": import uvicorn print(f"\n{'='*52}") print(f" 🚀 Server đang khởi động...") print(f" 📍 URL: http://localhost:{SERVER_PORT}") print(f" 📦 Model: {MODEL_PATH.name}") print(f"{'='*52}\n") uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)