526b099dbc
Merge manual edit
335 lines
12 KiB
Python
335 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
OrangePi RAG Application
|
|
A Vietnamese-language RAG system for querying Orange Pi blog articles.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any, Optional
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import faiss
|
|
from sentence_transformers import SentenceTransformer
|
|
import requests
|
|
from dotenv import load_dotenv
|
|
from tqdm import tqdm
|
|
|
|
|
|
@dataclass
class Chunk:
    """One passage of an article together with the fields used to cite it.

    Records of this shape are read from and written to ``chunks.jsonl``
    (every field except ``embedding``).
    """

    # Identifiers of the passage itself and of the article it belongs to.
    chunk_id: str
    article_id: str
    # Passage text; this is the string that gets embedded and shown to the LLM.
    content: str
    # Section heading of the passage, or None when the record has none.
    section: Optional[str]
    # Article title and URL, used when printing or citing a source.
    title: str
    url: str
    language: str
    # Free-form extras; the vector store records "similarity_score" in here
    # for chunks returned from a search.
    metadata: Dict[str, Any]
    # Optional embedding vector; left as None unless a caller sets it.
    embedding: Optional[np.ndarray] = None
|
|
|
|
|
|
class VectorStore:
    """FAISS inner-product index paired with the chunks it indexes.

    Vectors are L2-normalised before they are added or searched, so the
    inner-product score returned by the index is a cosine similarity.
    """

    def __init__(self, dim: int):
        """Create an empty store for vectors of dimension ``dim``."""
        self.dim = dim
        self.index = faiss.IndexFlatIP(dim)
        # chunks[i] is the chunk whose vector occupies row i of the index.
        self.chunks: List[Chunk] = []

    def add(self, chunks: List[Chunk], embeddings: np.ndarray):
        """Append ``chunks`` and their ``embeddings`` (one row per chunk).

        Raises:
            ValueError: if the number of rows differs from the number of
                chunks, or the embeddings are not a 2-D array.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings"
            )
        if not chunks:
            return
        # Work on a private float32, C-contiguous copy: normalize_L2 operates
        # in place, and the caller's array must not be modified.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.ndim != 2:
            raise ValueError(
                f"Expected a 2-D embeddings array, got shape {vectors.shape}"
            )
        faiss.normalize_L2(vectors)
        self.index.add(vectors)
        self.chunks.extend(chunks)

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Chunk]:
        """Return up to ``top_k`` chunks ranked by similarity to the query.

        Each returned chunk has its score stored under
        ``chunk.metadata["similarity_score"]``. The chunk objects are the
        ones held by the store, so that key is overwritten by later searches.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        # Private copy for the same reason as in add(): normalisation is
        # in place and must not leak into the caller's vector.
        query = np.array(query_embedding, dtype=np.float32, order="C").reshape(1, -1)
        faiss.normalize_L2(query)
        scores, indices = self.index.search(query, min(top_k, self.index.ntotal))
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # Skip rows that have no matching chunk (e.g. -1 padding).
            if 0 <= idx < len(self.chunks):
                chunk = self.chunks[idx]
                chunk.metadata["similarity_score"] = float(score)
                results.append(chunk)
        return results

    def save(self, path: Path):
        """Write ``faiss.index`` and ``chunks.jsonl`` into directory ``path``."""
        path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(path / "faiss.index"))
        with open(path / "chunks.jsonl", "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                # The embedding field is intentionally not serialised; the
                # vectors live in the FAISS index file.
                data = {
                    "chunk_id": chunk.chunk_id,
                    "article_id": chunk.article_id,
                    "content": chunk.content,
                    "section": chunk.section,
                    "title": chunk.title,
                    "url": chunk.url,
                    "language": chunk.language,
                    "metadata": chunk.metadata,
                }
                f.write(json.dumps(data, ensure_ascii=False) + "\n")

    @classmethod
    def load(cls, path: Path, dim: int) -> "VectorStore":
        """Load a store previously written by :meth:`save` from ``path``.

        Raises:
            ValueError: if the stored index dimension differs from ``dim``
                (for example the index was built with another embedding model).
        """
        store = cls(dim)
        store.index = faiss.read_index(str(path / "faiss.index"))
        if store.index.d != dim:
            raise ValueError(
                f"Index dimension {store.index.d} does not match expected {dim}"
            )
        with open(path / "chunks.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                data = json.loads(line)
                store.chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))
        return store
|
|
|
|
|
|
class Embedder:
    """Wrapper around a SentenceTransformer model that produces numpy vectors."""

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        """Load ``model_name`` and record its embedding dimension in ``dim``."""
        self.model = SentenceTransformer(model_name)
        self.dim = self._embedding_dim(self.model)

    @staticmethod
    def _embedding_dim(model) -> int:
        """Return the embedding dimension reported by ``model``.

        NOTE(review): the accessor appears to be named differently across
        sentence-transformers releases; both spellings are tried so the class
        does not depend on one specific version. Confirm against the pinned
        library version.

        Raises:
            AttributeError: if the model exposes neither accessor.
        """
        for accessor in ("get_embedding_dimension", "get_sentence_embedding_dimension"):
            getter = getattr(model, accessor, None)
            if callable(getter):
                return getter()
        raise AttributeError(
            "Embedding model exposes neither get_embedding_dimension() "
            "nor get_sentence_embedding_dimension()"
        )

    def embed(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` into a numpy array, showing a progress bar."""
        return self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    def embed_query(self, query: str) -> np.ndarray:
        """Encode a single query string and return its vector."""
        # encode() is given a one-element batch; [0] unwraps the single row.
        return self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
|
|
|
|
|
|
class LLMClient:
    """Minimal client for a ``/chat/completions`` HTTP endpoint."""

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        model: str = "gpt-4o-mini",
    ):
        """Store credentials and prepare the request headers.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            base_url: API root; a trailing slash is removed.
            model: Model name placed in every request payload.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def _post(self, payload: dict) -> dict:
        """POST ``payload`` to the chat-completions URL and return the JSON body.

        Raises whatever ``requests`` raises, including the HTTP error from
        ``raise_for_status()`` on a non-success status code.
        """
        resp = requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        return resp.json()

    def generate(self, prompt: str, temperature: float = 0.1, max_tokens: int = 1000) -> str:
        """Send ``prompt`` as the user message and return the reply text.

        The system message is the fixed instruction text from
        :meth:`_system_prompt`. Returns an empty string when the first
        choice carries no text content.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._system_prompt()},
                {"role": "user", "content": prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        content = self._post(payload)["choices"][0]["message"]["content"]
        # The content field may be null; avoid calling .strip() on None.
        return (content or "").strip()

    def _system_prompt(self) -> str:
        """Return the Vietnamese system instructions for grounded answering.

        The retrieved material itself is supplied in the user message, so this
        text only describes the rules. It must not contain a ``{context}``
        placeholder, because the string is sent as-is and never formatted.
        """
        return (
            "Bạn là một trợ lý AI chuyên về Orange Pi, sử dụng dữ liệu từ blog orangepi.vn (nhà phân phối chính thức Orange Pi tại Việt Nam).\n"
            "\n"
            "NHIỆM VỤ:\n"
            "- Trả lời câu hỏi của người dùng CHỈ DỰA TRÊN THÔNG TIN ĐƯỢC CUNG CẤP trong phần \"NGUYÊN LIỆU\".\n"
            "- Nếu thông tin không có trong nguyên liệu, hãy trả lời: \"Không có thông tin trong dữ liệu.\"\n"
            "- KHÔNG được bịa đặt, suy diễn hoặc sử dụng kiến thức bên ngoài.\n"
            "- Trả lời bằng tiếng Việt, ngắn gọn, chính xác.\n"
            "- Trích dẫn nguồn (title + URL) khi có thể."
        )
|
|
|
|
|
|
class RAGPipeline:
    """Ties together chunk loading, embedding, vector search and the LLM call."""

    def __init__(
        self,
        data_dir: Path,
        index_dir: Path,
        embedder: Embedder,
        llm: LLMClient,
        top_k: int = 5,
    ):
        """Store the collaborators; no index is loaded yet.

        Args:
            data_dir: Directory containing ``chunks.jsonl``.
            index_dir: Directory the vector store is saved to / loaded from.
            embedder: Produces vectors for chunks and queries.
            llm: Client used by :meth:`query`; may be None when only building
                or inspecting the index.
            top_k: Number of chunks retrieved per question.
        """
        self.data_dir = data_dir
        self.index_dir = index_dir
        self.embedder = embedder
        self.llm = llm
        self.top_k = top_k
        self.store: Optional[VectorStore] = None

    def build_index(self):
        """Embed every chunk from ``data_dir`` and save the index to ``index_dir``.

        Raises:
            ValueError: if ``chunks.jsonl`` contains no chunks.
        """
        print("Loading chunks...")
        chunks = self._load_chunks()
        print(f"Loaded {len(chunks)} chunks")
        if not chunks:
            # Fail with a clear message rather than deep inside the
            # embedding / FAISS calls.
            raise ValueError(f"No chunks found in {self.data_dir / 'chunks.jsonl'}")

        print("Generating embeddings...")
        texts = [c.content for c in chunks]
        embeddings = self.embedder.embed(texts)

        print("Building FAISS index...")
        self.store = VectorStore(self.embedder.dim)
        self.store.add(chunks, embeddings)

        print("Saving index...")
        self.index_dir.mkdir(parents=True, exist_ok=True)
        self.store.save(self.index_dir)
        print(f"Index saved to {self.index_dir}")

    def load_index(self):
        """Load a previously built index from ``index_dir`` into ``self.store``."""
        print("Loading index...")
        self.store = VectorStore.load(self.index_dir, self.embedder.dim)
        print(f"Loaded {len(self.store.chunks)} chunks")

    def _load_chunks(self) -> List[Chunk]:
        """Parse ``data_dir/chunks.jsonl`` (one JSON object per line) into chunks."""
        chunks = []
        chunks_path = self.data_dir / "chunks.jsonl"
        with open(chunks_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Blank lines would make json.loads raise; skip them.
                    continue
                data = json.loads(line)
                chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))
        return chunks

    @staticmethod
    def _format_context(results: List[Chunk]) -> str:
        """Render retrieved chunks as numbered, source-labelled passages."""
        parts = []
        for i, chunk in enumerate(results, 1):
            source = f"[{i}] {chunk.title} ({chunk.url})"
            if chunk.section:
                source += f" - Section: {chunk.section}"
            parts.append(f"{source}\n{chunk.content}")
        return "\n\n---\n\n".join(parts)

    def query(self, question: str) -> str:
        """Answer ``question`` from the ``top_k`` most similar chunks.

        Returns the fixed "no information" sentence when retrieval finds
        nothing; otherwise returns the LLM's reply.

        Raises:
            RuntimeError: if no index has been built/loaded, or if chunks
                were retrieved but the pipeline has no LLM client.
        """
        if self.store is None:
            raise RuntimeError("Index not loaded. Call build_index() or load_index() first.")

        query_emb = self.embedder.embed_query(question)
        results = self.store.search(query_emb, top_k=self.top_k)

        if not results:
            return "Không có thông tin trong dữ liệu."

        if self.llm is None:
            # The pipeline can be constructed without an LLM (index building,
            # retrieval tests); report that explicitly instead of failing
            # with an AttributeError on None.
            raise RuntimeError("No LLM client configured for this pipeline.")

        context = self._format_context(results)
        prompt = f"NGUYÊN LIỆU:\n{context}\n\nCÂU HỎI: {question}"

        return self.llm.generate(prompt)
|
|
|
|
|
|
def main():
    """Command-line entry point; returns the process exit code.

    Modes, checked in this order: ``--build`` (build and save the index),
    ``--retrieve-only`` (print retrieved chunks for ``--query`` without
    calling the LLM), ``--query`` (answer one question), ``--interactive``
    (question loop). With none of them the help text is printed and 1 is
    returned. The LLM modes need an API key from ``OPENAI_API_KEY`` or
    ``LLM_API_KEY``; ``LLM_BASE_URL`` and ``LLM_MODEL`` from the environment
    take precedence over the corresponding command-line values.
    """
    # Force UTF-8 output on Windows
    for stream in (sys.stdout, sys.stderr):
        if hasattr(stream, "reconfigure"):
            stream.reconfigure(encoding="utf-8")

    parser = argparse.ArgumentParser(description="OrangePi RAG Application")
    parser.add_argument("--data-dir", type=Path, default=Path("."), help="Directory with chunks.jsonl")
    parser.add_argument("--index-dir", type=Path, default=Path("./rag_index"), help="FAISS index directory")
    parser.add_argument("--build", action="store_true", help="Build index from chunks")
    parser.add_argument("--query", type=str, help="Query to answer")
    parser.add_argument("--interactive", action="store_true", help="Interactive chat mode")
    parser.add_argument("--retrieve-only", action="store_true", help="Test retrieval without LLM")
    parser.add_argument("--top-k", type=int, default=5, help="Number of chunks to retrieve")
    parser.add_argument("--embed-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    parser.add_argument("--llm-model", default="gpt-4o-mini")
    parser.add_argument("--llm-base-url", default="https://api.openai.com/v1")
    args = parser.parse_args()

    # Nothing to do: show usage before paying for the model and index load.
    if not (args.build or args.retrieve_only or args.query or args.interactive):
        parser.print_help()
        return 1

    # Retrieval needs a query string; reject the combination up front rather
    # than handing None to the embedder.
    if args.retrieve_only and not args.build and not args.query:
        parser.error("--retrieve-only requires --query")

    load_dotenv()

    embedder = Embedder(args.embed_model)

    if args.build:
        # Build doesn't need LLM
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)
        pipeline.build_index()
        return 0

    if args.retrieve_only:
        # Retrieval test without LLM
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)
        pipeline.load_index()
        query_emb = embedder.embed_query(args.query)
        results = pipeline.store.search(query_emb, top_k=args.top_k)
        for i, chunk in enumerate(results, 1):
            print(f"\n--- Result {i} (score: {chunk.metadata.get('similarity_score', 0):.4f}) ---")
            print(f"Title: {chunk.title}")
            print(f"URL: {chunk.url}")
            if chunk.section:
                print(f"Section: {chunk.section}")
            # Round-trip through UTF-8 with errors='replace' so characters
            # that cannot be encoded do not abort the listing.
            print(f"Content: {chunk.content[:500]}...".encode('utf-8', errors='replace').decode('utf-8'))
        return 0

    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
    if not api_key:
        print("ERROR: Set OPENAI_API_KEY or LLM_API_KEY in environment or .env file")
        print("       Or use --retrieve-only to test retrieval without LLM")
        return 1

    llm_base_url = os.environ.get("LLM_BASE_URL") or args.llm_base_url
    if not llm_base_url:
        # Message names the real flag spelling (--llm-base-url).
        print("ERROR: Set LLM_BASE_URL in environment or .env file, or pass --llm-base-url")
        print("       Or use --retrieve-only to test retrieval without LLM")
        return 1

    llm_model = os.environ.get("LLM_MODEL") or args.llm_model
    if not llm_model:
        # Message names the real flag spelling (--llm-model).
        print("ERROR: Set LLM_MODEL in environment or .env file, or pass --llm-model")
        print("       Or use --retrieve-only to test retrieval without LLM")
        return 1

    llm = LLMClient(api_key, llm_base_url, llm_model)
    pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, llm, args.top_k)

    pipeline.load_index()

    if args.query:
        print(pipeline.query(args.query))
        return 0

    if args.interactive:
        print("OrangePi RAG - Interactive mode (Ctrl+C to exit)")
        print("=" * 50)
        while True:
            try:
                question = input("\n❓ Câu hỏi: ").strip()
                if not question:
                    continue
                answer = pipeline.query(question)
                print(f"\n🤖 Trả lời: {answer}")
            except KeyboardInterrupt:
                print("\nTạm biệt!")
                break
            except EOFError:
                break
        return 0

    # Reached only when --query was given as an empty string together with
    # --retrieve-only unset and --interactive unset; handled by the early
    # check above in practice, kept as a safe fallback.
    parser.print_help()
    return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    sys.exit(main())