# litert-lm-orangepi/server.py
import os
import sys
import uuid
import time
from pathlib import Path
# Suppress verbose logs
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "3"
import litert_lm
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from contextlib import asynccontextmanager
# ── Config ────────────────────────────────────────────────────────────────────
MODELS_DIR = Path(__file__).parent / "models"
TEMPLATE_DIR = Path(__file__).parent / "templates"
AVAILABLE_MODELS = {
    "gemma-4-E2B-it": {
        "file": "gemma-4-E2B-it.litertlm",
        "repo": "litert-community/gemma-4-E2B-it-litert-lm",
        "desc": "Gemma 4 Edge 2B, smaller and faster",
    },
    "gemma-4-E4B-it": {
        "file": "gemma-4-E4B-it.litertlm",
        "repo": "litert-community/gemma-4-E4B-it-litert-lm",
        "desc": "Gemma 4 Edge 4B, smarter but slower",
    },
}
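# To serve another .litertlm model, add an entry here (hypothetical example):
#   "my-model": {
#       "file": "my-model.litertlm",
#       "repo": "some-org/my-model-litert-lm",
#       "desc": "description shown in the picker",
#   },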
# ── CLI: select a model at startup ──────────────────────────────────────────
def download_model(repo: str, local_dir: Path) -> bool:
    """Download the model from Hugging Face to the local models directory."""
    try:
        import subprocess
        print(f"\n Downloading model from {repo}...")
        print(" Please wait, this may take a few minutes...\n")
        cmd = [
            "huggingface-cli", "download", repo,
            "--include", "*.litertlm",
            "--local-dir", str(local_dir),
        ]
        subprocess.run(cmd, check=True, capture_output=False, text=True)
        print("\n ✓ Model downloaded successfully!")
        return True
    except subprocess.CalledProcessError as e:
        print(f"\n ✗ Error downloading the model: {e}")
        return False
    except FileNotFoundError:
        print("\n ✗ huggingface-cli not found.")
        print(" Install it with: pip install huggingface-hub")
        return False
    except Exception as e:
        print(f"\n ✗ Unexpected error: {e}")
        return False
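# A minimal in-process alternative (a sketch, assuming huggingface_hub is
# installed): the same files can be fetched without shelling out to the CLI.
#
#     from huggingface_hub import snapshot_download
#     snapshot_download(repo_id=repo, allow_patterns=["*.litertlm"],
#                       local_dir=str(local_dir))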
def select_model() -> Path:
    print("\n" + "=" * 52)
    print(" LiteRT-LM Server - select a model")
    print("=" * 52)
    for i, (key, info) in enumerate(AVAILABLE_MODELS.items(), 1):
        model_path = MODELS_DIR / info["file"]
        status = "✓ available" if model_path.exists() else "✗ not downloaded"
        print(f" [{i}] {key}")
        print(f"     {info['desc']}")
        print(f"     {status}")
        print()
    while True:
        try:
            choice = input("Select a model (1/2): ").strip()
            idx = int(choice) - 1
            if 0 <= idx < len(AVAILABLE_MODELS):
                key = list(AVAILABLE_MODELS.keys())[idx]
                info = AVAILABLE_MODELS[key]
                model_path = MODELS_DIR / info["file"]
                if not model_path.exists():
                    print("\n Model not found in the models/ directory.")
                    print(" Manual download command:\n")
                    print(f" huggingface-cli download {info['repo']} \\")
                    print("     --include '*.litertlm' \\")
                    print("     --local-dir models/\n")
                    download_choice = input(" Download the model now? (y/n): ").strip().lower()
                    downloaded = download_choice == "y" and download_model(info["repo"], MODELS_DIR)
                    # Re-check the file on disk; the download may have fetched nothing.
                    if downloaded and not model_path.exists():
                        print("\n ✗ Model file not found after download.")
                        downloaded = False
                    if not downloaded:
                        retry = input(" Choose another model? (y/n): ").strip().lower()
                        if retry == "y":
                            continue
                        sys.exit(0)
                print(f"\n Selected: {key}")
                print(f" Path: {model_path}\n")
                return model_path
            print(" Please enter 1 or 2.")
        except ValueError:
            # Non-numeric input: re-prompt instead of exiting.
            print(" Please enter 1 or 2.")
        except KeyboardInterrupt:
            print("\n Exiting.")
            sys.exit(0)
# Select the model before the FastAPI app starts
MODELS_DIR.mkdir(exist_ok=True)
MODEL_PATH = select_model()
# ── Models ───────────────────────────────────────────────────────────────────
class PromptRequest(BaseModel):
    prompt: str
# ── State ────────────────────────────────────────────────────────────────────
ml_models = {}
sessions: dict = {} # session_id -> conversation object
# ── Helpers ───────────────────────────────────────────────────────────────────
def count_tokens(engine, text: str) -> int:
    try:
        return len(engine.tokenize(text))
    except Exception:
        # Fallback heuristic: roughly 4 characters per token.
        return max(1, len(text) // 4)
# ── Lifespan ─────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    print(f" Loading model: {MODEL_PATH.name} ...")
    engine = litert_lm.Engine(str(MODEL_PATH), backend=litert_lm.Backend.CPU)
    ml_models["engine"] = engine
    ml_models["model_name"] = MODEL_PATH.stem
    print(f" Model ready: {MODEL_PATH.name}\n")
    yield
    sessions.clear()
    del ml_models["engine"]
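# Note: create_conversation()/send_message() are called synchronously in the
# handlers below, so a long generation blocks the event loop. A minimal sketch
# for offloading the blocking call to a worker thread (assuming the litert_lm
# objects are safe to call from another thread):
#
#     import asyncio
#     loop = asyncio.get_running_loop()
#     result = await loop.run_in_executor(None, conversation.send_message, request.prompt)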
# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(title="LiteRT-LM API", lifespan=lifespan)
# ── REST: info ────────────────────────────────────────────────────────────────
@app.get("/info")
async def info():
"""Return current loaded model info."""
return {
"model": ml_models.get("model_name", "unknown"),
"sessions": len(sessions),
}
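# Example (assuming the server is reachable on localhost:8000):
#   curl http://localhost:8000/info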
# ── REST: stateless single-turn ───────────────────────────────────────────────
@app.post("/generate")
async def generate_text(request: PromptRequest):
"""Single-turn generation. No memory between calls."""
engine = ml_models.get("engine")
if not engine:
raise HTTPException(status_code=503, detail="Model engine not initialized")
try:
conversation = engine.create_conversation()
t0 = time.perf_counter()
result = conversation.send_message(request.prompt)
elapsed = time.perf_counter() - t0
text = result["content"][0]["text"]
num_tokens = count_tokens(engine, text)
tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
return {
"response": text,
"tokens": num_tokens,
"elapsed_s": round(elapsed, 2),
"tokens_per_sec": tps,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
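# Example request (assuming localhost:8000):
#   curl -X POST http://localhost:8000/generate \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "Hello!"}'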
# ── REST: multi-turn chat sessions ────────────────────────────────────────────
@app.post("/chat/new")
async def new_session():
"""Create a new chat session. Returns session_id."""
engine = ml_models.get("engine")
if not engine:
raise HTTPException(status_code=503, detail="Model engine not initialized")
session_id = str(uuid.uuid4())
sessions[session_id] = engine.create_conversation()
return {"session_id": session_id}
@app.post("/chat/{session_id}")
async def chat(session_id: str, request: PromptRequest):
"""Send a message in an existing session (retains conversation history)."""
if session_id not in sessions:
raise HTTPException(
status_code=404,
detail="Session not found. Create one via POST /chat/new",
)
try:
engine = ml_models.get("engine")
t0 = time.perf_counter()
result = sessions[session_id].send_message(request.prompt)
elapsed = time.perf_counter() - t0
text = result["content"][0]["text"]
num_tokens = count_tokens(engine, text)
tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
return {
"session_id": session_id,
"response": text,
"tokens": num_tokens,
"elapsed_s": round(elapsed, 2),
"tokens_per_sec": tps,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
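# Example multi-turn flow (assuming localhost:8000, with jq installed):
#   SID=$(curl -s -X POST http://localhost:8000/chat/new | jq -r .session_id)
#   curl -X POST http://localhost:8000/chat/$SID \
#        -H 'Content-Type: application/json' -d '{"prompt": "Hi there"}'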
@app.delete("/chat/{session_id}")
async def clear_session(session_id: str):
    """Delete a session and free its memory."""
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    del sessions[session_id]
    return {"status": "cleared", "session_id": session_id}
@app.get("/chat/sessions/list")
async def list_sessions():
    """List all active session IDs."""
    return {"sessions": list(sessions.keys()), "count": len(sessions)}
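# Example session management (assuming localhost:8000):
#   curl http://localhost:8000/chat/sessions/list
#   curl -X DELETE http://localhost:8000/chat/<session_id>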
# ── WebUI ─────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def web_ui():
html = (TEMPLATE_DIR / "index.html").read_text(encoding="utf-8")
return HTMLResponse(content=html)
# ── Run ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
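# Run with: python server.py
# Note: select_model() executes at import time, so launching via
# `uvicorn server:app` will also trigger the interactive model prompt.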