#!/usr/bin/env python3
"""Web interface for the RAG application with session management and chat history.

Exposes a small JSON API on top of a SQLite database:

* ``/api/sessions``                          list / create chat sessions
* ``/api/sessions/<session_id>``             delete a session
* ``/api/sessions/<session_id>/messages``    read history / ask a question
* ``/api/sessions/<session_id>/clear``       delete a session's messages
* ``/api/stats``                             counters and configuration
"""

import os
import json
import uuid
import sqlite3
import datetime as dt
from contextlib import closing
from pathlib import Path
from typing import Optional

from flask import Flask, render_template, request, jsonify, Response, stream_with_context
from dotenv import load_dotenv

from rag_app import Embedder, LLMClient, RAGPipeline, VectorStore, Chunk

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

load_dotenv()

DATA_DIR = Path(os.environ.get("RAG_DATA_DIR", "."))
INDEX_DIR = Path(os.environ.get("RAG_INDEX_DIR", "./rag_index"))
DB_DIR = Path(os.environ.get("RAG_DB_DIR", "."))
EMBED_MODEL = os.environ.get(
    "RAG_EMBED_MODEL",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1")
LLM_API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY", "")
TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
MAX_HISTORY = int(os.environ.get("RAG_MAX_HISTORY", "10"))

app = Flask(__name__)
app.secret_key = os.urandom(24)

# ---------------------------------------------------------------------------
# Database
# ---------------------------------------------------------------------------

DB_PATH = DB_DIR / "rag_chat.db"


def get_db() -> sqlite3.Connection:
    """Open a new SQLite connection to ``DB_PATH``.

    Rows are returned as ``sqlite3.Row`` so they can be converted with
    ``dict(row)``. The caller owns the connection and must close it
    (``with closing(get_db()) as conn:``).
    """
    conn = sqlite3.connect(str(DB_PATH))
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    # SQLite leaves foreign-key enforcement off unless it is enabled on each
    # connection; without this the ON DELETE CASCADE in the schema is inert.
    conn.execute("PRAGMA foreign_keys=ON")
    return conn


def init_db():
    """Create the ``sessions`` and ``messages`` tables if they do not exist."""
    # sqlite3.connect() does not create missing parent directories.
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    with closing(get_db()) as conn:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS sessions (
                id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                sources TEXT,
                created_at TEXT NOT NULL,
                FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
            );
            CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
        """)


def _utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return dt.datetime.now(dt.timezone.utc).isoformat()


def _session_exists(conn: sqlite3.Connection, session_id: str) -> bool:
    """Return True when a row with this id is present in ``sessions``."""
    row = conn.execute(
        "SELECT 1 FROM sessions WHERE id = ?", (session_id,)
    ).fetchone()
    return row is not None


def _json_body() -> dict:
    """Return the request's JSON object, or ``{}`` when absent or not an object.

    ``silent=True`` keeps a missing body or a non-JSON content type from
    aborting the request before the view can apply its own defaults.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


# ---------------------------------------------------------------------------
# RAG pipeline (singleton)
# ---------------------------------------------------------------------------

_pipeline: Optional[RAGPipeline] = None


def get_pipeline() -> RAGPipeline:
    """Return the process-wide pipeline, building it on first use.

    The LLM client is only created when ``LLM_API_KEY`` is non-empty;
    otherwise the pipeline is constructed with ``llm=None``.
    """
    global _pipeline
    if _pipeline is None:
        embedder = Embedder(EMBED_MODEL)
        llm = LLMClient(LLM_API_KEY, LLM_BASE_URL, LLM_MODEL) if LLM_API_KEY else None
        _pipeline = RAGPipeline(DATA_DIR, INDEX_DIR, embedder, llm, TOP_K)
        _pipeline.load_index()
    return _pipeline


# ---------------------------------------------------------------------------
# System prompt builder
# ---------------------------------------------------------------------------

def build_system_prompt(session_history: list[dict]) -> str:
    """Build the system prompt, embedding a digest of recent conversation.

    Args:
        session_history: Messages in chronological order, each a dict with
            ``"role"`` and ``"content"`` keys.

    Returns:
        The instruction text followed by at most ``MAX_HISTORY`` recent
        messages, each truncated to 200 characters, or a placeholder line
        when there is no history.
    """
    history_text = ""
    if session_history:
        lines = []
        for msg in session_history[-MAX_HISTORY:]:
            role = "Người dùng" if msg["role"] == "user" else "Trợ lý"
            lines.append(f"{role}: {msg['content'][:200]}")
        history_text = "\n".join(lines)

    return f"""Bạn là một trợ lý AI thông minh, hỗ trợ trả lời câu hỏi dựa trên dữ liệu được cung cấp.

NHIỆM VỤ:
- Trả lời câu hỏi của người dùng CHỈ DỰA TRÊN THÔNG TIN trong phần "NGUYÊN LIỆU".
- Nếu thông tin không có trong nguyên liệu, hãy trả lời: "Không có thông tin trong dữ liệu."
- KHÔNG được bịa đặt, suy diễn hoặc sử dụng kiến thức bên ngoài nguyên liệu.
- Trả lời ngắn gọn, chính xác, bằng ngôn ngữ của người dùng (tiếng Việt hoặc tiếng Anh).
- Trích nguồn (tên bài viết + URL) khi có thể.
- Nếu câu hỏi không liên quan đến nội dung dữ liệu (ví dụ: hỏi về thời tiết, nấu ăn, etc.), hãy trả lời: "Câu hỏi này nằm ngoài phạm vi dữ liệu. Vui lòng hỏi về các chủ đề liên quan đến nội dung blog."

ĐƯỜNG DẪN NGỮ CẢNH:
{history_text if history_text else "(Đây là câu hỏi đầu tiên trong phiên)"}
"""


# ---------------------------------------------------------------------------
# Retrieval / generation helpers
# ---------------------------------------------------------------------------

def _build_context(results) -> tuple[str, list[dict]]:
    """Turn retrieved chunks into an LLM context string and a source list.

    Each chunk contributes a numbered header (title, URL, optional section)
    followed by its content; the same header data is returned as JSON-ready
    dicts for the client.
    """
    context_parts = []
    sources = []
    for i, chunk in enumerate(results, 1):
        source = f"[{i}] {chunk.title} ({chunk.url})"
        if chunk.section:
            source += f" - Section: {chunk.section}"
        context_parts.append(f"{source}\n{chunk.content}")
        sources.append({
            "title": chunk.title,
            "url": chunk.url,
            "section": chunk.section,
            "score": round(chunk.metadata.get("similarity_score", 0), 4),
        })
    return "\n\n---\n\n".join(context_parts), sources


def _generate_answer(pipeline: RAGPipeline, system_prompt: str, prompt: str) -> str:
    """Ask the LLM for an answer; return an error string instead of raising."""
    if pipeline.llm is None:
        # get_pipeline() leaves llm unset when no API key is configured.
        return "Lỗi khi tạo câu trả lời: LLM chưa được cấu hình (thiếu API key)."

    payload = {
        "model": pipeline.llm.model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.1,
        "max_tokens": 1000,
    }
    try:
        resp = pipeline.llm._post(payload)
        return resp["choices"][0]["message"]["content"].strip()
    except Exception as e:
        # Request boundary: the failure is reported to the user as the
        # assistant's reply, and logged so it is not lost.
        app.logger.exception("LLM generation failed")
        return f"Lỗi khi tạo câu trả lời: {e}"


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------

@app.route("/")
def index():
    """Serve the single-page chat UI."""
    return render_template("index.html")


@app.route("/api/sessions", methods=["GET"])
def list_sessions():
    """Return all sessions, most recently updated first."""
    with closing(get_db()) as conn:
        rows = conn.execute(
            "SELECT id, title, created_at, updated_at FROM sessions ORDER BY updated_at DESC"
        ).fetchall()
    return jsonify([dict(r) for r in rows])


@app.route("/api/sessions", methods=["POST"])
def create_session():
    """Create a session; the optional JSON body may carry a ``title``."""
    data = _json_body()
    session_id = str(uuid.uuid4())[:8]
    # A null/blank title would violate the NOT NULL column, so fall back to
    # the generated default in that case as well as when the key is absent.
    title = str(data.get("title") or "").strip()
    if not title:
        title = f"Phiên {dt.datetime.now().strftime('%H:%M %d/%m')}"
    now = _utc_now()
    with closing(get_db()) as conn:
        conn.execute(
            "INSERT INTO sessions (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
            (session_id, title, now, now),
        )
        conn.commit()
    return jsonify({"id": session_id, "title": title, "created_at": now, "updated_at": now})


@app.route("/api/sessions/<session_id>", methods=["DELETE"])
def delete_session(session_id):
    """Delete a session together with its messages."""
    with closing(get_db()) as conn:
        # Messages are removed explicitly rather than relying on the cascade alone.
        conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
        conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
        conn.commit()
    return jsonify({"ok": True})


@app.route("/api/sessions/<session_id>/messages", methods=["GET"])
def get_messages(session_id):
    """Return a session's messages in insertion order, with sources decoded."""
    with closing(get_db()) as conn:
        rows = conn.execute(
            "SELECT id, role, content, sources, created_at FROM messages "
            "WHERE session_id = ? ORDER BY id",
            (session_id,),
        ).fetchall()

    result = []
    for r in rows:
        msg = dict(r)
        if msg["sources"]:
            # Sources are stored as a JSON string by send_message().
            msg["sources"] = json.loads(msg["sources"])
        result.append(msg)
    return jsonify(result)


@app.route("/api/sessions/<session_id>/messages", methods=["POST"])
def send_message(session_id):
    """Answer a user question with retrieval-augmented generation.

    Expects a JSON body ``{"content": "<question>"}``. Stores the question,
    retrieves the top-k chunks, asks the LLM, stores the answer with its
    sources and returns the assistant message.

    Returns 400 for an empty question and 404 for an unknown session.
    """
    data = _json_body()
    raw_question = data.get("content")
    question = raw_question.strip() if isinstance(raw_question, str) else ""
    if not question:
        return jsonify({"error": "Empty message"}), 400

    pipeline = get_pipeline()

    # Read history and persist the question, then release the connection so
    # it is not held open during embedding and the LLM call.
    with closing(get_db()) as conn:
        if not _session_exists(conn, session_id):
            return jsonify({"error": "Session not found"}), 404

        history_rows = conn.execute(
            "SELECT role, content FROM messages WHERE session_id = ? ORDER BY id DESC LIMIT ?",
            (session_id, MAX_HISTORY),
        ).fetchall()
        # The query returns newest-first; the prompt wants chronological order.
        session_history = [dict(r) for r in reversed(history_rows)]

        conn.execute(
            "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
            (session_id, "user", question, _utc_now()),
        )
        conn.commit()

    # Retrieve relevant chunks
    query_emb = pipeline.embedder.embed_query(question)
    results = pipeline.store.search(query_emb, top_k=pipeline.top_k)

    if not results:
        answer = "Không có thông tin trong dữ liệu."
        sources = []
    else:
        context, sources = _build_context(results)
        system_prompt = build_system_prompt(session_history)
        prompt = f"NGUYÊN LIỆU:\n{context}\n\nCÂU HỎI: {question}"
        answer = _generate_answer(pipeline, system_prompt, prompt)

    # Stamp the reply when it was produced, not when the question arrived.
    answered_at = _utc_now()
    with closing(get_db()) as conn:
        conn.execute(
            "INSERT INTO messages (session_id, role, content, sources, created_at) "
            "VALUES (?, ?, ?, ?, ?)",
            (session_id, "assistant", answer, json.dumps(sources, ensure_ascii=False), answered_at),
        )
        conn.execute(
            "UPDATE sessions SET updated_at = ? WHERE id = ?",
            (answered_at, session_id),
        )
        conn.commit()

    return jsonify({
        "role": "assistant",
        "content": answer,
        "sources": sources,
        "created_at": answered_at,
    })


@app.route("/api/sessions/<session_id>/clear", methods=["POST"])
def clear_messages(session_id):
    """Delete every message of a session while keeping the session itself."""
    with closing(get_db()) as conn:
        conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
        conn.commit()
    return jsonify({"ok": True})


@app.route("/api/stats", methods=["GET"])
def get_stats():
    """Return row counts, the number of indexed chunks and key settings."""
    with closing(get_db()) as conn:
        session_count = conn.execute("SELECT COUNT(*) FROM sessions").fetchone()[0]
        message_count = conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0]

    pipeline = get_pipeline()
    chunk_count = len(pipeline.store.chunks) if pipeline.store else 0

    return jsonify({
        "sessions": session_count,
        "messages": message_count,
        "chunks_indexed": chunk_count,
        "llm_model": LLM_MODEL,
        "data_dir": str(DATA_DIR),
        "index_dir": str(INDEX_DIR),
    })


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="RAG Web Interface")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind")
    parser.add_argument("--port", type=int, default=5000, help="Port to bind")
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    parser.add_argument("--data-dir", type=Path, default=DATA_DIR)
    parser.add_argument("--index-dir", type=Path, default=INDEX_DIR)
    args = parser.parse_args()

    # Rebind the module-level settings before the pipeline is first built,
    # so get_pipeline() picks up the command-line values.
    DATA_DIR = args.data_dir
    INDEX_DIR = args.index_dir

    init_db()

    print(f"Starting RAG Web Interface on http://{args.host}:{args.port}")
    print(f"Data dir: {DATA_DIR}")
    print(f"Index dir: {INDEX_DIR}")

    # NOTE(review): --debug combined with the default 0.0.0.0 bind makes the
    # interactive debugger reachable from the network; use 127.0.0.1 when debugging.
    app.run(host=args.host, port=args.port, debug=args.debug)