318 lines
11 KiB
Python
318 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""Web interface for the RAG application with session management and chat history."""
|
|
|
|
import os
|
|
import json
|
|
import uuid
|
|
import sqlite3
|
|
import datetime as dt
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
|
|
from dotenv import load_dotenv
|
|
|
|
from rag_app import Embedder, LLMClient, RAGPipeline, VectorStore, Chunk
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Pull variables from a .env file into the process environment before the
# settings below are read.
load_dotenv()

# Filesystem locations.
DATA_DIR = Path(os.getenv("RAG_DATA_DIR", "."))
INDEX_DIR = Path(os.getenv("RAG_INDEX_DIR", "./rag_index"))
DB_DIR = Path(os.getenv("RAG_DB_DIR", "."))

# Embedding model and LLM endpoint.
EMBED_MODEL = os.getenv("RAG_EMBED_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
# OPENAI_API_KEY wins when set and non-empty; otherwise fall back to LLM_API_KEY
# (empty string when neither is set, which disables the LLM client).
LLM_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY", "")

# Retrieval depth and how many past messages are fed back as chat context.
# Both raise ValueError at import time if the variable is not an integer.
TOP_K = int(os.getenv("RAG_TOP_K", "5"))
MAX_HISTORY = int(os.getenv("RAG_MAX_HISTORY", "10"))
|
|
|
|
app = Flask(__name__)
# NOTE(review): a fresh random key is generated every time this module is
# imported, so anything Flask signs with it (e.g. session cookies) will not
# survive a restart and may differ between worker processes.
app.config["SECRET_KEY"] = os.urandom(24)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Database
|
|
# ---------------------------------------------------------------------------
|
|
|
|
DB_PATH = DB_DIR.joinpath("rag_chat.db")


def get_db() -> sqlite3.Connection:
    """Open a new SQLite connection to ``DB_PATH``.

    Rows come back as ``sqlite3.Row`` so callers can index columns by name or
    turn a row into a dict. The database is switched to WAL journal mode.
    Foreign-key enforcement is NOT enabled here (no ``PRAGMA foreign_keys``).
    Each call returns a distinct connection; the caller must close it.
    """
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA journal_mode=WAL")
    return connection
|
|
|
|
|
|
def init_db():
    """Create the chat schema if it does not exist yet.

    Creates the ``sessions`` and ``messages`` tables and an index on
    ``messages.session_id``. Every statement uses ``IF NOT EXISTS``, so calling
    this repeatedly is harmless.

    The directory that will hold ``DB_PATH`` is created first: SQLite does not
    create missing parent directories, so pointing ``RAG_DB_DIR`` at a new path
    would otherwise make the connection fail.

    NOTE(review): within this file it is only called from the ``__main__``
    block, so serving ``app`` through another runner (``flask run``, a WSGI
    server) would skip it and leave the tables missing. Confirm the deployment
    path.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = get_db()
    try:
        # NOTE: SQLite only enforces FOREIGN KEY / ON DELETE CASCADE when
        # ``PRAGMA foreign_keys = ON`` is set on the connection. get_db() does
        # not set it, so the cascade below is declarative only.
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS sessions (
                id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                sources TEXT,
                created_at TEXT NOT NULL,
                FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
            );
            CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
        """)
    finally:
        # Always release the connection, even if the DDL fails.
        conn.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RAG pipeline (singleton)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Lazily built, process-wide RAG pipeline; populated by get_pipeline().
_pipeline: Optional[RAGPipeline] = None


def get_pipeline() -> RAGPipeline:
    """Return the shared RAG pipeline, building it and loading its index on first use.

    The pipeline is constructed from the module-level settings, which are read
    at call time so CLI overrides of DATA_DIR/INDEX_DIR apply. An ``LLMClient``
    is only created when ``LLM_API_KEY`` is non-empty; otherwise the pipeline
    receives ``None`` as its LLM.

    The pipeline is stored in the module cache only after ``load_index()``
    returns. If loading raises, nothing is cached and the next call retries,
    instead of every later request silently getting a pipeline whose index was
    never loaded.

    NOTE: there is no lock here, so concurrent first calls may each build a
    pipeline (the last one to finish wins).
    """
    global _pipeline
    if _pipeline is None:
        embedder = Embedder(EMBED_MODEL)
        llm = LLMClient(LLM_API_KEY, LLM_BASE_URL, LLM_MODEL) if LLM_API_KEY else None
        pipeline = RAGPipeline(DATA_DIR, INDEX_DIR, embedder, llm, TOP_K)
        pipeline.load_index()
        _pipeline = pipeline  # publish only a fully initialised pipeline
    return _pipeline
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# System prompt builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def build_system_prompt(
    session_history: list[dict],
    max_history: Optional[int] = None,
    max_chars: int = 200,
) -> str:
    """Build the Vietnamese system prompt for a RAG answer, including recent chat turns.

    The prompt tells the model to answer only from the "NGUYÊN LIỆU" (source
    material) section of the user message, and appends a short transcript of
    the session so far.

    Args:
        session_history: Prior messages, expected oldest first (the last
            entries are treated as the most recent). Each is a mapping with
            ``"role"`` and ``"content"`` keys. Role ``"user"`` is labelled
            "Người dùng" (user); any other role is labelled "Trợ lý" (assistant).
        max_history: How many of the most recent messages to include.
            Defaults to the module-level ``MAX_HISTORY``. ``0`` or a negative
            value includes none.
        max_chars: Each message's content is cut to at most this many
            characters (non-negative; default 200).

    Returns:
        The complete system prompt. When no history is included, a placeholder
        saying this is the first question of the session is used instead.
    """
    limit = MAX_HISTORY if max_history is None else max_history

    history_text = ""
    # Guard limit <= 0 explicitly: ``xs[-0:]`` is the WHOLE list, so a limit of
    # 0 would otherwise include every message instead of none.
    if session_history and limit > 0:
        lines = []
        for msg in session_history[-limit:]:
            role = "Người dùng" if msg["role"] == "user" else "Trợ lý"
            lines.append(f"{role}: {msg['content'][:max_chars]}")
        history_text = "\n".join(lines)

    return f"""Bạn là một trợ lý AI thông minh, hỗ trợ trả lời câu hỏi dựa trên dữ liệu được cung cấp.

NHIỆM VỤ:
- Trả lời câu hỏi của người dùng CHỈ DỰA TRÊN THÔNG TIN trong phần "NGUYÊN LIỆU".
- Nếu thông tin không có trong nguyên liệu, hãy trả lời: "Không có thông tin trong dữ liệu."
- KHÔNG được bịa đặt, suy diễn hoặc sử dụng kiến thức bên ngoài nguyên liệu.
- Trả lời ngắn gọn, chính xác, bằng ngôn ngữ của người dùng (tiếng Việt hoặc tiếng Anh).
- Trích nguồn (tên bài viết + URL) khi có thể.
- Nếu câu hỏi không liên quan đến nội dung dữ liệu (ví dụ: hỏi về thời tiết, nấu ăn, etc.), hãy trả lời: "Câu hỏi này nằm ngoài phạm vi dữ liệu. Vui lòng hỏi về các chủ đề liên quan đến nội dung blog."

ĐƯỜNG DẪN NGỮ CẢNH:
{history_text if history_text else "(Đây là câu hỏi đầu tiên trong phiên)"}
"""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Routes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.route("/")
def index():
    """Render the ``index.html`` template as the application's landing page."""
    page = "index.html"
    return render_template(page)
|
|
|
|
|
|
@app.route("/api/sessions", methods=["GET"])
def list_sessions():
    """Return every session as a JSON list, most recently updated first.

    Each item has ``id``, ``title``, ``created_at`` and ``updated_at``.
    """
    db = get_db()
    cursor = db.execute(
        "SELECT id, title, created_at, updated_at FROM sessions ORDER BY updated_at DESC"
    )
    sessions = list(map(dict, cursor.fetchall()))
    db.close()
    return jsonify(sessions)
|
|
|
|
|
|
@app.route("/api/sessions", methods=["POST"])
def create_session():
    """Create a new chat session and return it as JSON.

    The request body is optional. If it is a JSON object with a non-blank
    string ``"title"``, that title (whitespace-stripped) is used. Otherwise the
    title defaults to ``"Phiên HH:MM dd/mm"`` in the server's local time.
    ``created_at``/``updated_at`` are UTC ISO-8601 strings.

    Returns:
        JSON ``{"id", "title", "created_at", "updated_at"}`` for the new session.
    """
    # get_json(silent=True) returns None (instead of raising) for a missing or
    # non-JSON body, so a bare POST still creates a session.
    data = request.get_json(silent=True)
    raw_title = data.get("title") if isinstance(data, dict) else None
    if isinstance(raw_title, str) and raw_title.strip():
        title = raw_title.strip()
    else:
        # Covers absent, null, blank and non-string titles. A null title would
        # otherwise hit the NOT NULL ``title`` column and fail the INSERT.
        title = f"Phiên {dt.datetime.now().strftime('%H:%M %d/%m')}"

    # NOTE(review): 8 hex chars of a uuid4 = 32 random bits. A collision would
    # violate the sessions PRIMARY KEY and make the INSERT fail.
    session_id = str(uuid.uuid4())[:8]
    now = dt.datetime.now(dt.timezone.utc).isoformat()

    conn = get_db()
    try:
        conn.execute(
            "INSERT INTO sessions (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
            (session_id, title, now, now),
        )
        conn.commit()
    finally:
        conn.close()

    return jsonify({"id": session_id, "title": title, "created_at": now, "updated_at": now})
|
|
|
|
|
|
@app.route("/api/sessions/<session_id>", methods=["DELETE"])
def delete_session(session_id):
    """Delete a session and all of its messages; always answers ``{"ok": true}``.

    Messages are removed explicitly before the session row. The schema's
    ``ON DELETE CASCADE`` only takes effect when SQLite foreign keys are
    enabled, which get_db() does not do (presumably the reason for the
    explicit delete). Both deletes are committed together. A missing session
    is not reported as an error.
    """
    db = get_db()
    params = (session_id,)
    for statement in (
        "DELETE FROM messages WHERE session_id = ?",
        "DELETE FROM sessions WHERE id = ?",
    ):
        db.execute(statement, params)
    db.commit()
    db.close()
    return jsonify({"ok": True})
|
|
|
|
|
|
@app.route("/api/sessions/<session_id>/messages", methods=["GET"])
def get_messages(session_id):
    """Return all messages of a session as JSON, in insertion (id) order.

    ``sources`` is stored as a JSON string. When it is non-empty it is decoded
    back into its JSON value; otherwise (e.g. NULL for user messages) it is
    returned unchanged.
    """
    db = get_db()
    rows = db.execute(
        "SELECT id, role, content, sources, created_at FROM messages WHERE session_id = ? ORDER BY id",
        (session_id,),
    ).fetchall()
    db.close()

    def decode(row):
        item = dict(row)
        if item["sources"]:
            item["sources"] = json.loads(item["sources"])
        return item

    return jsonify([decode(row) for row in rows])
|
|
|
|
|
|
def _build_context(results) -> tuple[str, list[dict]]:
    """Render retrieved chunks as LLM context plus JSON-ready source records.

    Each chunk becomes ``[i] title (url)`` (plus `` - Section: ...`` when the
    chunk has a section) followed by its content. Chunks are separated by
    ``---`` rules. ``results`` is presumably the list of Chunk-like objects
    (title/url/section/content/metadata) returned by ``pipeline.store.search``
    — verify against VectorStore.
    """
    context_parts = []
    sources = []
    for i, chunk in enumerate(results, 1):
        header = f"[{i}] {chunk.title} ({chunk.url})"
        if chunk.section:
            header += f" - Section: {chunk.section}"
        context_parts.append(f"{header}\n{chunk.content}")
        sources.append({
            "title": chunk.title,
            "url": chunk.url,
            "section": chunk.section,
            "score": round(chunk.metadata.get("similarity_score", 0), 4),
        })
    return "\n\n---\n\n".join(context_parts), sources


def _generate_answer(pipeline: RAGPipeline, system_prompt: str, prompt: str) -> str:
    """Ask the pipeline's LLM to answer ``prompt`` under ``system_prompt``.

    Never raises. A missing LLM client or any request failure is turned into a
    user-facing Vietnamese error string, which the caller stores like a normal
    answer. Failures are logged with their traceback.
    """
    if pipeline.llm is None:
        # get_pipeline() passes no LLM client when no API key is configured.
        return "Lỗi khi tạo câu trả lời: chưa cấu hình API key cho LLM (OPENAI_API_KEY hoặc LLM_API_KEY)."
    try:
        payload = {
            "model": pipeline.llm.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.1,
            "max_tokens": 1000,
        }
        # NOTE(review): relies on the private LLMClient._post helper and an
        # OpenAI-style chat-completions response shape.
        resp = pipeline.llm._post(payload)
        return resp["choices"][0]["message"]["content"].strip()
    except Exception as e:  # boundary: any LLM failure becomes a chat message, not a 500
        app.logger.exception("LLM request failed")
        return f"Lỗi khi tạo câu trả lời: {e}"


@app.route("/api/sessions/<session_id>/messages", methods=["POST"])
def send_message(session_id):
    """Answer a question within a chat session using retrieval-augmented generation.

    Expects a JSON object body with a non-blank string ``"content"``. The
    steps are:

    1. Read the session's recent history and store the user's question.
    2. Retrieve the top-k chunks and generate an answer. If nothing is
       retrieved, a fixed "no information" reply is used instead.
    3. Store the answer with its sources and bump the session's ``updated_at``.

    Returns:
        JSON ``{"role": "assistant", "content", "sources", "created_at"}``, or
        ``({"error": "Empty message"}, 400)`` when no usable question is given.

    NOTE(review): the session id is not validated. Messages for an unknown
    session are still stored, because foreign keys are not enabled by get_db().
    """
    data = request.get_json(silent=True)
    raw = data.get("content") if isinstance(data, dict) else None
    question = raw.strip() if isinstance(raw, str) else ""
    if not question:
        return jsonify({"error": "Empty message"}), 400

    pipeline = get_pipeline()

    # Phase 1: history + user message. The connection is closed before the
    # slow retrieval/LLM work, so it is neither held open during that work nor
    # leaked if it raises.
    asked_at = dt.datetime.now(dt.timezone.utc).isoformat()
    conn = get_db()
    try:
        history_rows = conn.execute(
            "SELECT role, content FROM messages WHERE session_id = ? ORDER BY id DESC LIMIT ?",
            (session_id, MAX_HISTORY),
        ).fetchall()
        # Rows arrive newest-first; restore chronological order for the prompt.
        session_history = [dict(r) for r in reversed(history_rows)]
        conn.execute(
            "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
            (session_id, "user", question, asked_at),
        )
        conn.commit()
    finally:
        conn.close()

    # Phase 2: retrieval and generation (no DB connection held).
    query_emb = pipeline.embedder.embed_query(question)
    results = pipeline.store.search(query_emb, top_k=pipeline.top_k)

    if not results:
        answer = "Không có thông tin trong dữ liệu."
        sources = []
    else:
        context, sources = _build_context(results)
        system_prompt = build_system_prompt(session_history)
        prompt = f"NGUYÊN LIỆU:\n{context}\n\nCÂU HỎI: {question}"
        answer = _generate_answer(pipeline, system_prompt, prompt)

    # Phase 3: persist the answer, timestamped when it was actually produced.
    answered_at = dt.datetime.now(dt.timezone.utc).isoformat()
    conn = get_db()
    try:
        conn.execute(
            "INSERT INTO messages (session_id, role, content, sources, created_at) VALUES (?, ?, ?, ?, ?)",
            (session_id, "assistant", answer, json.dumps(sources, ensure_ascii=False), answered_at),
        )
        # Bump the session so list_sessions() (ordered by updated_at) shows it first.
        conn.execute("UPDATE sessions SET updated_at = ? WHERE id = ?", (answered_at, session_id))
        conn.commit()
    finally:
        conn.close()

    return jsonify({
        "role": "assistant",
        "content": answer,
        "sources": sources,
        "created_at": answered_at,
    })
|
|
|
|
|
|
@app.route("/api/sessions/<session_id>/clear", methods=["POST"])
def clear_messages(session_id):
    """Delete every message in a session while keeping the session row.

    The session's ``updated_at`` is left unchanged. Always answers
    ``{"ok": true}``, whether or not anything was deleted.
    """
    db = get_db()
    db.execute("DELETE FROM messages WHERE session_id = ?", [session_id])
    db.commit()
    db.close()
    return jsonify({"ok": True})
|
|
|
|
|
|
@app.route("/api/stats", methods=["GET"])
def get_stats():
    """Report row counts, index size and key configuration as JSON.

    This calls get_pipeline(), so the first request may build the pipeline and
    load the index.
    """
    db = get_db()
    (session_count,) = db.execute("SELECT COUNT(*) FROM sessions").fetchone()
    (message_count,) = db.execute("SELECT COUNT(*) FROM messages").fetchone()
    db.close()

    store = get_pipeline().store
    # A falsy store (presumably no index loaded) is reported as zero chunks.
    chunk_count = len(store.chunks) if store else 0

    stats = {
        "sessions": session_count,
        "messages": message_count,
        "chunks_indexed": chunk_count,
        "llm_model": LLM_MODEL,
        "data_dir": str(DATA_DIR),
        "index_dir": str(INDEX_DIR),
    }
    return jsonify(stats)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="RAG Web Interface")
    # NOTE(review): the default host 0.0.0.0 listens on all interfaces. Combined
    # with --debug this exposes the Werkzeug debugger; use only on trusted networks.
    cli.add_argument("--host", default="0.0.0.0", help="Host to bind")
    cli.add_argument("--port", type=int, default=5000, help="Port to bind")
    cli.add_argument("--debug", action="store_true", help="Debug mode")
    cli.add_argument("--data-dir", type=Path, default=DATA_DIR)
    cli.add_argument("--index-dir", type=Path, default=INDEX_DIR)
    opts = cli.parse_args()

    # Rebind the module-level paths. get_pipeline() reads them lazily, so the
    # CLI values take effect for the pipeline built on the first request.
    DATA_DIR, INDEX_DIR = opts.data_dir, opts.index_dir

    init_db()
    for banner_line in (
        f"Starting RAG Web Interface on http://{opts.host}:{opts.port}",
        f"Data dir: {DATA_DIR}",
        f"Index dir: {INDEX_DIR}",
    ):
        print(banner_line)
    app.run(host=opts.host, port=opts.port, debug=opts.debug)
|