379 lines
15 KiB
Python
379 lines
15 KiB
Python
import os
|
|
import sys
|
|
import uuid
|
|
import time
|
|
import socket
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
# Suppress verbose logs
|
|
os.environ["GRPC_VERBOSITY"] = "ERROR"
|
|
os.environ["GLOG_minloglevel"] = "3"
|
|
|
|
import litert_lm
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel
|
|
from contextlib import asynccontextmanager
|
|
|
|
# ── Config ────────────────────────────────────────────────────────────────────
|
|
|
|
MODELS_DIR = Path(__file__).parent / "models"
|
|
TEMPLATE_DIR = Path(__file__).parent / "templates"
|
|
|
|
AVAILABLE_MODELS = {
|
|
"gemma-4-E2B-it": {
|
|
"file": "gemma-4-E2B-it.litertlm",
|
|
"repo": "litert-community/gemma-4-E2B-it-litert-lm",
|
|
"desc": "Gemma 4 Edge 2B — nhỏ hơn, nhanh hơn",
|
|
},
|
|
"gemma-4-E4B-it": {
|
|
"file": "gemma-4-E4B-it.litertlm",
|
|
"repo": "litert-community/gemma-4-E4B-it-litert-lm",
|
|
"desc": "Gemma 4 Edge 4B — thông minh hơn, chậm hơn",
|
|
},
|
|
}
|
|
|
|
# ── CLI: chọn model khi khởi động ────────────────────────────────────────────
|
|
|
|
def is_port_available(port: int) -> bool:
|
|
"""Kiểm tra xem port có khả dụng không."""
|
|
try:
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
s.bind(("0.0.0.0", port))
|
|
return True
|
|
except OSError:
|
|
return False
|
|
|
|
def select_port(default_port: int = 8000) -> int:
|
|
"""Chọn port khả dụng."""
|
|
if is_port_available(default_port):
|
|
return default_port
|
|
|
|
print(f"\n ⚠️ Port {default_port} đã bị chiếm!")
|
|
while True:
|
|
try:
|
|
choice = input(f" Nhập port khác (hoặc Enter để dùng port tự động): ").strip()
|
|
if not choice:
|
|
# Tìm port tự động
|
|
for port in range(8001, 9000):
|
|
if is_port_available(port):
|
|
print(f" ✓ Sử dụng port tự động: {port}")
|
|
return port
|
|
print(" ✗ Không tìm thấy port khả dụng trong khoảng 8001-8999")
|
|
continue
|
|
|
|
port = int(choice)
|
|
if port < 1024 or port > 65535:
|
|
print(" ✗ Port phải trong khoảng 1024-65535")
|
|
continue
|
|
|
|
if is_port_available(port):
|
|
print(f" ✓ Sử dụng port: {port}")
|
|
return port
|
|
else:
|
|
print(f" ✗ Port {port} đã bị chiếm, vui lòng chọn port khác")
|
|
except ValueError:
|
|
print(" ✗ Vui lòng nhập số port hợp lệ")
|
|
except KeyboardInterrupt:
|
|
print("\n Thoát.")
|
|
sys.exit(0)
|
|
|
|
def download_model(repo: str, local_dir: Path) -> bool:
|
|
"""Tải model từ Hugging Face về local."""
|
|
try:
|
|
import subprocess
|
|
print(f"\n Đang tải model từ {repo}...")
|
|
print(f" Vui lòng đợi, quá trình này có thể mất vài phút...\n")
|
|
|
|
cmd = [
|
|
"hf", "download", repo,
|
|
"--include", "*.litertlm",
|
|
"--local-dir", str(local_dir)
|
|
]
|
|
|
|
result = subprocess.run(cmd, check=True, capture_output=False, text=True)
|
|
print(f"\n ✓ Tải model thành công!")
|
|
return True
|
|
except subprocess.CalledProcessError as e:
|
|
print(f"\n ✗ Lỗi khi tải model: {e}")
|
|
return False
|
|
except FileNotFoundError:
|
|
print(f"\n ✗ Không tìm thấy lệnh 'hf'.")
|
|
print(f" Cài đặt bằng lệnh: pip install -U huggingface-hub")
|
|
return False
|
|
except Exception as e:
|
|
print(f"\n ✗ Lỗi không xác định: {e}")
|
|
return False
|
|
|
|
def select_model(custom_path: str = None) -> Path:
|
|
"""Chọn model để sử dụng."""
|
|
|
|
# Nếu có custom path, kiểm tra và sử dụng luôn
|
|
if custom_path:
|
|
custom_model = Path(custom_path)
|
|
if custom_model.exists() and custom_model.suffix == ".litertlm":
|
|
print(f"\n ✓ Sử dụng model từ đường dẫn: {custom_model}")
|
|
return custom_model
|
|
else:
|
|
print(f"\n ✗ Không tìm thấy model tại: {custom_path}")
|
|
print(f" Vui lòng kiểm tra lại đường dẫn.\n")
|
|
sys.exit(1)
|
|
|
|
print("\n" + "="*52)
|
|
print(" LiteRT-LM Server — Chọn model")
|
|
print("="*52)
|
|
|
|
for i, (key, info) in enumerate(AVAILABLE_MODELS.items(), 1):
|
|
model_path = MODELS_DIR / info["file"]
|
|
status = "✓ có sẵn" if model_path.exists() else "✗ chưa tải"
|
|
print(f" [{i}] {key}")
|
|
print(f" {info['desc']}")
|
|
print(f" {status}")
|
|
print()
|
|
|
|
print(f" [3] Sử dụng model từ đường dẫn khác")
|
|
print()
|
|
|
|
while True:
|
|
try:
|
|
choice = input("Chọn model (1/2/3): ").strip()
|
|
|
|
# Tùy chọn 3: Đường dẫn tùy chỉnh
|
|
if choice == "3":
|
|
custom_path = input("\n Nhập đường dẫn đầy đủ tới file .litertlm: ").strip()
|
|
custom_model = Path(custom_path)
|
|
if custom_model.exists() and custom_model.suffix == ".litertlm":
|
|
print(f"\n Đã chọn: {custom_model.name}")
|
|
print(f" Path: {custom_model}\n")
|
|
return custom_model
|
|
else:
|
|
print(f"\n ✗ Không tìm thấy file model hợp lệ tại: {custom_path}")
|
|
retry = input(" Thử lại? (y/n): ").strip().lower()
|
|
if retry == "y":
|
|
continue
|
|
else:
|
|
sys.exit(0)
|
|
|
|
idx = int(choice) - 1
|
|
if 0 <= idx < len(AVAILABLE_MODELS):
|
|
key = list(AVAILABLE_MODELS.keys())[idx]
|
|
info = AVAILABLE_MODELS[key]
|
|
model_path = MODELS_DIR / info["file"]
|
|
|
|
if not model_path.exists():
|
|
print(f"\n Model chưa có trong thư mục models/")
|
|
print(f" Lệnh tải thủ công:\n")
|
|
print(f" hf download {info['repo']} \\")
|
|
print(f" --include '*.litertlm' \\")
|
|
print(f" --local-dir models/\n")
|
|
|
|
download_choice = input(" Bạn muốn tải model ngay bây giờ? (y/n): ").strip().lower()
|
|
if download_choice == "y":
|
|
if download_model(info['repo'], MODELS_DIR):
|
|
# Kiểm tra lại xem file đã tồn tại chưa
|
|
if model_path.exists():
|
|
print(f"\n Đã chọn: {key}")
|
|
print(f" Path: {model_path}\n")
|
|
return model_path
|
|
else:
|
|
print(f"\n ✗ Không tìm thấy file model sau khi tải.")
|
|
retry = input(" Chọn model khác? (y/n): ").strip().lower()
|
|
if retry == "y":
|
|
continue
|
|
else:
|
|
sys.exit(0)
|
|
else:
|
|
retry = input("\n Chọn model khác? (y/n): ").strip().lower()
|
|
if retry == "y":
|
|
continue
|
|
else:
|
|
sys.exit(0)
|
|
else:
|
|
retry = input(" Chọn model khác? (y/n): ").strip().lower()
|
|
if retry == "y":
|
|
continue
|
|
else:
|
|
sys.exit(0)
|
|
|
|
print(f"\n Đã chọn: {key}")
|
|
print(f" Path: {model_path}\n")
|
|
return model_path
|
|
else:
|
|
print(" Vui lòng nhập 1, 2 hoặc 3.")
|
|
except (ValueError, KeyboardInterrupt):
|
|
print("\n Thoát.")
|
|
sys.exit(0)
|
|
|
|
# Parse command line arguments
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(
|
|
description="LiteRT-LM Server - Local AI inference server",
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
epilog="""
|
|
Ví dụ:
|
|
python server.py
|
|
python server.py --port 8080
|
|
python server.py --model /path/to/model.litertlm
|
|
python server.py --port 8080 --model /path/to/model.litertlm
|
|
"""
|
|
)
|
|
parser.add_argument(
|
|
"--port", "-p",
|
|
type=int,
|
|
default=8000,
|
|
help="Port để chạy server (mặc định: 8000)"
|
|
)
|
|
parser.add_argument(
|
|
"--model", "-m",
|
|
type=str,
|
|
default=None,
|
|
help="Đường dẫn đầy đủ tới file model .litertlm"
|
|
)
|
|
return parser.parse_args()
|
|
|
|
# Parse arguments và chọn model trước khi FastAPI khởi động
|
|
args = parse_args()
|
|
MODELS_DIR.mkdir(exist_ok=True)
|
|
MODEL_PATH = select_model(args.model)
|
|
SERVER_PORT = select_port(args.port)
|
|
|
|
# ── Models ───────────────────────────────────────────────────────────────────
|
|
|
|
class PromptRequest(BaseModel):
|
|
prompt: str
|
|
|
|
# ── State ────────────────────────────────────────────────────────────────────
|
|
|
|
ml_models = {}
|
|
sessions: dict = {} # session_id -> conversation object
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
def count_tokens(engine, text: str) -> int:
|
|
try:
|
|
return len(engine.tokenize(text))
|
|
except Exception:
|
|
return max(1, len(text) // 4)
|
|
|
|
# ── Lifespan ─────────────────────────────────────────────────────────────────
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
print(f" Loading model: {MODEL_PATH.name} ...")
|
|
engine = litert_lm.Engine(str(MODEL_PATH), backend=litert_lm.Backend.CPU)
|
|
ml_models["engine"] = engine
|
|
ml_models["model_name"] = MODEL_PATH.stem
|
|
print(f" Model ready: {MODEL_PATH.name}\n")
|
|
yield
|
|
sessions.clear()
|
|
del ml_models["engine"]
|
|
|
|
# ── App ───────────────────────────────────────────────────────────────────────
|
|
|
|
app = FastAPI(title="LiteRT-LM API", lifespan=lifespan)
|
|
|
|
# ── REST: info ────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/info")
|
|
async def info():
|
|
"""Return current loaded model info."""
|
|
return {
|
|
"model": ml_models.get("model_name", "unknown"),
|
|
"sessions": len(sessions),
|
|
}
|
|
|
|
# ── REST: stateless single-turn ───────────────────────────────────────────────
|
|
|
|
@app.post("/generate")
|
|
async def generate_text(request: PromptRequest):
|
|
"""Single-turn generation. No memory between calls."""
|
|
engine = ml_models.get("engine")
|
|
if not engine:
|
|
raise HTTPException(status_code=503, detail="Model engine not initialized")
|
|
try:
|
|
conversation = engine.create_conversation()
|
|
t0 = time.perf_counter()
|
|
result = conversation.send_message(request.prompt)
|
|
elapsed = time.perf_counter() - t0
|
|
text = result["content"][0]["text"]
|
|
num_tokens = count_tokens(engine, text)
|
|
tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
|
|
return {
|
|
"response": text,
|
|
"tokens": num_tokens,
|
|
"elapsed_s": round(elapsed, 2),
|
|
"tokens_per_sec": tps,
|
|
}
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# ── REST: multi-turn chat sessions ────────────────────────────────────────────
|
|
|
|
@app.post("/chat/new")
|
|
async def new_session():
|
|
"""Create a new chat session. Returns session_id."""
|
|
engine = ml_models.get("engine")
|
|
if not engine:
|
|
raise HTTPException(status_code=503, detail="Model engine not initialized")
|
|
session_id = str(uuid.uuid4())
|
|
sessions[session_id] = engine.create_conversation()
|
|
return {"session_id": session_id}
|
|
|
|
@app.post("/chat/{session_id}")
|
|
async def chat(session_id: str, request: PromptRequest):
|
|
"""Send a message in an existing session (retains conversation history)."""
|
|
if session_id not in sessions:
|
|
raise HTTPException(
|
|
status_code=404,
|
|
detail="Session not found. Create one via POST /chat/new",
|
|
)
|
|
try:
|
|
engine = ml_models.get("engine")
|
|
t0 = time.perf_counter()
|
|
result = sessions[session_id].send_message(request.prompt)
|
|
elapsed = time.perf_counter() - t0
|
|
text = result["content"][0]["text"]
|
|
num_tokens = count_tokens(engine, text)
|
|
tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
|
|
return {
|
|
"session_id": session_id,
|
|
"response": text,
|
|
"tokens": num_tokens,
|
|
"elapsed_s": round(elapsed, 2),
|
|
"tokens_per_sec": tps,
|
|
}
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.delete("/chat/{session_id}")
|
|
async def clear_session(session_id: str):
|
|
"""Delete a session and free its memory."""
|
|
if session_id not in sessions:
|
|
raise HTTPException(status_code=404, detail="Session not found")
|
|
del sessions[session_id]
|
|
return {"status": "cleared", "session_id": session_id}
|
|
|
|
@app.get("/chat/sessions/list")
|
|
async def list_sessions():
|
|
"""List all active session IDs."""
|
|
return {"sessions": list(sessions.keys()), "count": len(sessions)}
|
|
|
|
# ── WebUI ─────────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/", response_class=HTMLResponse)
|
|
async def web_ui():
|
|
html = (TEMPLATE_DIR / "index.html").read_text(encoding="utf-8")
|
|
return HTMLResponse(content=html)
|
|
|
|
# ── Run ───────────────────────────────────────────────────────────────────────
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
print(f"\n{'='*52}")
|
|
print(f" 🚀 Server đang khởi động...")
|
|
print(f" 📍 URL: http://localhost:{SERVER_PORT}")
|
|
print(f" 📦 Model: {MODEL_PATH.name}")
|
|
print(f"{'='*52}\n")
|
|
uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)
|