update

@@ -1,4 +1,5 @@
 import os
 import sys
+import uuid
 import time
 from pathlib import Path
@@ -15,9 +16,71 @@ from contextlib import asynccontextmanager
 
 # ── Config ────────────────────────────────────────────────────────────────────
 
-MODEL_PATH = "gemma-4-E4B-it.litertlm"
+MODELS_DIR = Path(__file__).parent / "models"
 TEMPLATE_DIR = Path(__file__).parent / "templates"
+
+AVAILABLE_MODELS = {
+    "gemma-4-E2B-it": {
+        "file": "gemma-4-E2B-it.litertlm",
+        "repo": "google/gemma-4-E2B-it",
+        "desc": "Gemma 4 Edge 2B — smaller, faster",
+    },
+    "gemma-4-E4B-it": {
+        "file": "gemma-4-E4B-it.litertlm",
+        "repo": "google/gemma-4-E4B-it",
+        "desc": "Gemma 4 Edge 4B — smarter, slower",
+    },
+}
+
+# ── CLI: select model at startup ─────────────────────────────────────────────
+
+def select_model() -> Path:
+    print("\n" + "="*52)
+    print(" LiteRT-LM Server — Select model")
+    print("="*52)
+
+    for i, (key, info) in enumerate(AVAILABLE_MODELS.items(), 1):
+        model_path = MODELS_DIR / info["file"]
+        status = "✓ available" if model_path.exists() else "✗ not downloaded"
+        print(f" [{i}] {key}")
+        print(f" {info['desc']}")
+        print(f" {status}")
+        print()
+
+    while True:
+        try:
+            choice = input("Select model (1/2): ").strip()
+            idx = int(choice) - 1
+            if 0 <= idx < len(AVAILABLE_MODELS):
+                key = list(AVAILABLE_MODELS.keys())[idx]
+                info = AVAILABLE_MODELS[key]
+                model_path = MODELS_DIR / info["file"]
+
+                if not model_path.exists():
+                    print("\n Model is not in the models/ directory yet.")
+                    print(" Download it with:\n")
+                    print(f" huggingface-cli download {info['repo']} \\")
+                    print(" --include '*.litertlm' \\")
+                    print(" --local-dir models/\n")
+                    retry = input(" Choose another model? (y/n): ").strip().lower()
+                    if retry == "y":
+                        continue
+                    else:
+                        sys.exit(0)
+
+                print(f"\n Selected: {key}")
+                print(f" Path: {model_path}\n")
+                return model_path
+            else:
+                print(" Please enter 1 or 2.")
+        except (ValueError, KeyboardInterrupt):
+            print("\n Exiting.")
+            sys.exit(0)
+
+# Select the model before FastAPI starts
+MODELS_DIR.mkdir(exist_ok=True)
+MODEL_PATH = select_model()
 
 # ── Models ───────────────────────────────────────────────────────────────────
 
 class PromptRequest(BaseModel):
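
Side note: the download that select_model() prints instructions for can also be scripted. A minimal sketch using the huggingface_hub Python API, equivalent to the huggingface-cli command printed above (assumes huggingface_hub is installed; the repo id is one of the entries in AVAILABLE_MODELS):

    from huggingface_hub import snapshot_download

    # Fetch only the .litertlm weights from one of the repos listed in
    # AVAILABLE_MODELS into the local models/ directory.
    snapshot_download(
        repo_id="google/gemma-4-E2B-it",
        allow_patterns=["*.litertlm"],
        local_dir="models/",
    )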
@@ -28,12 +91,23 @@ class PromptRequest(BaseModel):
 ml_models = {}
+sessions: dict = {}  # session_id -> conversation object
 
+# ── Helpers ───────────────────────────────────────────────────────────────────
 
+def count_tokens(engine, text: str) -> int:
+    try:
+        return len(engine.tokenize(text))
+    except Exception:
+        return max(1, len(text) // 4)  # fallback estimate: roughly 4 chars per token
 
 # ── Lifespan ─────────────────────────────────────────────────────────────────
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    engine = litert_lm.Engine(MODEL_PATH, backend=litert_lm.Backend.CPU)
+    print(f" Loading model: {MODEL_PATH.name} ...")
+    engine = litert_lm.Engine(str(MODEL_PATH), backend=litert_lm.Backend.CPU)
     ml_models["engine"] = engine
+    ml_models["model_name"] = MODEL_PATH.stem
+    print(f" Model ready: {MODEL_PATH.name}\n")
     yield
+    sessions.clear()
     del ml_models["engine"]
@@ -42,6 +116,16 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(title="LiteRT-LM API", lifespan=lifespan)
 
+# ── REST: info ────────────────────────────────────────────────────────────────
+
+@app.get("/info")
+async def info():
+    """Return current loaded model info."""
+    return {
+        "model": ml_models.get("model_name", "unknown"),
+        "sessions": len(sessions),
+    }
+
 # ── REST: stateless single-turn ───────────────────────────────────────────────
 
 @app.post("/generate")
@@ -56,7 +140,7 @@ async def generate_text(request: PromptRequest):
     result = conversation.send_message(request.prompt)
     elapsed = time.perf_counter() - t0
     text = result["content"][0]["text"]
-    num_tokens = len(engine.tokenize(text))
+    num_tokens = count_tokens(engine, text)
     tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
     return {
         "response": text,
@@ -93,7 +177,7 @@ async def chat(session_id: str, request: PromptRequest):
     result = sessions[session_id].send_message(request.prompt)
     elapsed = time.perf_counter() - t0
     text = result["content"][0]["text"]
-    num_tokens = len(engine.tokenize(text)) if engine else 0
+    num_tokens = count_tokens(engine, text)
     tps = round(num_tokens / elapsed, 2) if elapsed > 0 else 0
     return {
         "session_id": session_id,
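
Once the server is up, the endpoints this commit touches can be smoke-tested end to end. A sketch using only the standard library; it assumes uvicorn's default 127.0.0.1:8000 and that the chat route is mounted at /chat/{session_id} (the route decorator falls outside the hunks shown), so adjust both to match the actual app:

    import json
    import urllib.request
    import uuid

    BASE = "http://127.0.0.1:8000"

    def get(path: str) -> dict:
        with urllib.request.urlopen(f"{BASE}{path}") as resp:
            return json.load(resp)

    def post(path: str, payload: dict) -> dict:
        req = urllib.request.Request(
            f"{BASE}{path}",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            return json.load(resp)

    print(get("/info"))                             # loaded model name + session count
    print(post("/generate", {"prompt": "Hello!"}))  # stateless single turn
    session_id = str(uuid.uuid4())                  # reuse the same id to keep history
    print(post(f"/chat/{session_id}", {"prompt": "Hi there"}))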