add rag test

This commit is contained in:
2026-06-12 10:57:45 +07:00
parent 5c5e3333a5
commit 3ebf6f450d
6 changed files with 1251 additions and 0 deletions
+14
View File
@@ -0,0 +1,14 @@
# LLM API Configuration
# Get your API key from:
# - OpenAI: https://platform.openai.com/api-keys
# - Together.ai: https://api.together.xyz/settings/api-keys
# - Groq: https://console.groq.com/keys
# - Or any OpenAI-compatible API
OPENAI_API_KEY=your_api_key_here
# Optional: Custom base URL for OpenAI-compatible APIs
# LLM_BASE_URL=https://api.openai.com/v1
# Optional: Model name (default: gpt-4o-mini)
# LLM_MODEL=gpt-4o-mini
+10
View File
@@ -0,0 +1,10 @@
# LLM API Configuration
# Get your API key from https://platform.openai.com/api-keys
# Or use any OpenAI-compatible API (e.g., Together.ai, Groq, etc.)
OPENAI_API_KEY=your_api_key_here
# Optional: Custom base URL for OpenAI-compatible APIs
# LLM_BASE_URL=https://api.openai.com/v1
# Optional: Model name (default: gpt-4o-mini)
# LLM_MODEL=gpt-4o-mini
+320
View File
@@ -0,0 +1,320 @@
#!/usr/bin/env python3
"""
OrangePi RAG Application
A Vietnamese-language RAG system for querying Orange Pi blog articles.
"""
import os
import json
import argparse
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import requests
from dotenv import load_dotenv
from tqdm import tqdm
@dataclass
class Chunk:
    """One retrievable piece of an article together with its provenance.

    Instances are built from the records of a ``chunks.jsonl`` file; every
    field except ``embedding`` is read from (and written back to) those
    records by the code in this file.
    """

    # Identifier of this chunk, taken from the "chunk_id" key of the record.
    chunk_id: str
    # Identifier of the article the chunk belongs to ("article_id" key).
    article_id: str
    # Text of the chunk; this is what gets embedded and shown to the LLM.
    content: str
    # Section heading, if the record has one (loaded with ``dict.get``).
    section: Optional[str]
    # Article title, used when citing the source.
    title: str
    # Article URL, used when citing the source.
    url: str
    # Language tag from the record -- presumably an ISO code; TODO confirm.
    language: str
    # Free-form extra attributes; retrieval code in this file reads a
    # "similarity_score" entry from it.
    metadata: Dict[str, Any]
    # Optional vector for the chunk; it is not serialised by VectorStore.save.
    embedding: Optional[np.ndarray] = None
class VectorStore:
    """FAISS inner-product index paired with the chunks it indexes.

    Vectors are L2-normalised before they are added or searched, so the
    inner-product scores returned by the flat index are cosine similarities.
    Row ``i`` of the index corresponds to ``self.chunks[i]``.
    """

    def __init__(self, dim: int):
        """Create an empty store for vectors of dimensionality ``dim``."""
        self.dim = dim
        self.index = faiss.IndexFlatIP(dim)
        self.chunks: List[Chunk] = []

    @staticmethod
    def _as_normalized_matrix(vectors: np.ndarray) -> np.ndarray:
        """Return a private, L2-normalised float32 2-D copy of ``vectors``.

        ``faiss.normalize_L2`` works in place on C-contiguous float32 data,
        so the conversion happens first and always on a copy; the caller's
        array is never modified. A 1-D input is treated as a single row.
        """
        matrix = np.array(vectors, dtype=np.float32, order="C")
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        faiss.normalize_L2(matrix)
        return matrix

    def add(self, chunks: List[Chunk], embeddings: np.ndarray):
        """Append ``chunks`` and their ``embeddings`` (one row per chunk).

        Raises:
            ValueError: if the number of rows differs from the number of
                chunks, which would silently misalign index rows and chunks.
        """
        if not chunks:
            return
        matrix = self._as_normalized_matrix(embeddings)
        if matrix.shape[0] != len(chunks):
            raise ValueError(
                f"Got {matrix.shape[0]} embeddings for {len(chunks)} chunks"
            )
        self.index.add(matrix)
        self.chunks.extend(chunks)

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Chunk]:
        """Return up to ``top_k`` chunks most similar to ``query_embedding``.

        Each returned chunk is a copy whose ``metadata`` carries an extra
        ``"similarity_score"`` entry; the stored chunks are left untouched so
        scores from one query never leak into later queries or into ``save``.
        An empty index or a non-positive ``top_k`` yields an empty list.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        query = self._as_normalized_matrix(query_embedding)
        scores, indices = self.index.search(query, top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than top_k vectors exist.
            if 0 <= idx < len(self.chunks):
                stored = self.chunks[idx]
                results.append(Chunk(
                    chunk_id=stored.chunk_id,
                    article_id=stored.article_id,
                    content=stored.content,
                    section=stored.section,
                    title=stored.title,
                    url=stored.url,
                    language=stored.language,
                    metadata={**stored.metadata, "similarity_score": float(score)},
                    embedding=stored.embedding,
                ))
        return results

    def save(self, path: Path):
        """Write ``faiss.index`` and ``chunks.jsonl`` into directory ``path``.

        The directory is created if it does not exist yet.
        """
        path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(path / "faiss.index"))
        with open(path / "chunks.jsonl", "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                data = {
                    "chunk_id": chunk.chunk_id,
                    "article_id": chunk.article_id,
                    "content": chunk.content,
                    "section": chunk.section,
                    "title": chunk.title,
                    "url": chunk.url,
                    "language": chunk.language,
                    "metadata": chunk.metadata,
                }
                f.write(json.dumps(data, ensure_ascii=False) + "\n")

    @classmethod
    def load(cls, path: Path, dim: int) -> "VectorStore":
        """Rebuild a store from the files written by :meth:`save`.

        Blank lines in ``chunks.jsonl`` are ignored.
        """
        store = cls(dim)
        store.index = faiss.read_index(str(path / "faiss.index"))
        with open(path / "chunks.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                store.chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))
        return store
class Embedder:
    """Thin wrapper around a SentenceTransformer model.

    Attributes:
        model: the loaded ``SentenceTransformer`` instance.
        dim: dimensionality of the vectors the model produces.
    """

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        """Load ``model_name`` and record its embedding dimensionality."""
        self.model = SentenceTransformer(model_name)
        # Prefer the accessor name this file was written against, but fall
        # back to ``get_sentence_embedding_dimension`` so the class also works
        # on library releases that only provide that name (requirements allow
        # sentence-transformers>=2.2.0).
        dimension_getter = getattr(self.model, "get_embedding_dimension", None)
        if dimension_getter is None:
            dimension_getter = self.model.get_sentence_embedding_dimension
        self.dim = dimension_getter()

    def embed(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` into a NumPy array, showing a progress bar."""
        return self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    def embed_query(self, query: str) -> np.ndarray:
        """Encode a single ``query`` string and return its vector."""
        return self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
class LLMClient:
    """Minimal client for an OpenAI-compatible ``/chat/completions`` endpoint."""

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        model: str = "gpt-4o-mini",
    ):
        """Store credentials and endpoint settings.

        Args:
            api_key: bearer token sent in the ``Authorization`` header.
            base_url: API root; a trailing slash is removed.
            model: model name placed in every request payload.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, temperature: float = 0.1, max_tokens: int = 1000) -> str:
        """Send ``prompt`` as the user message and return the reply text.

        Args:
            prompt: user message content.
            temperature: sampling temperature forwarded to the API.
            max_tokens: completion length limit forwarded to the API.

        Returns:
            The first choice's message content, stripped of surrounding
            whitespace; an empty string when the API returns no content.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._system_prompt()},
                {"role": "user", "content": prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        resp = requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
        # The content field may be null in a response; avoid None.strip().
        return (content or "").strip()

    def _system_prompt(self) -> str:
        """Return the fixed Vietnamese system instructions.

        The retrieved material itself travels in the user message under the
        "NGUYÊN LIỆU" heading, so this prompt carries no placeholder.
        """
        return """Bạn là một trợ lý AI chuyên về Orange Pi, sử dụng dữ liệu từ blog orangepi.vn (nhà phân phối chính thức Orange Pi tại Việt Nam).
NHIỆM VỤ:
- Trả lời câu hỏi của người dùng CHỈ DỰA TRÊN THÔNG TIN ĐƯỢC CUNG CẤP trong phần "NGUYÊN LIỆU".
- Nếu thông tin không có trong nguyên liệu, hãy trả lời: "Không có thông tin trong dữ liệu."
- KHÔNG được bịa đặt, suy diễn hoặc sử dụng kiến thức bên ngoài.
- Trả lời bằng tiếng Việt, ngắn gọn, chính xác.
- Trích dẫn nguồn (title + URL) khi có thể."""
class RAGPipeline:
    """Retrieval-augmented generation over a directory of article chunks.

    The pipeline embeds chunks from ``data_dir / "chunks.jsonl"``, stores the
    index under ``index_dir``, and answers questions by retrieving the
    ``top_k`` most similar chunks and passing them to the LLM.
    """

    def __init__(
        self,
        data_dir: Path,
        index_dir: Path,
        embedder: Embedder,
        llm: Optional[LLMClient],
        top_k: int = 5,
    ):
        """Store collaborators; no index is built or loaded yet.

        Args:
            data_dir: directory containing ``chunks.jsonl``.
            index_dir: directory the index is saved to / loaded from.
            embedder: produces vectors for chunks and queries.
            llm: answer generator; may be ``None`` when the pipeline is only
                used for building the index or for retrieval.
            top_k: number of chunks retrieved per question.
        """
        self.data_dir = data_dir
        self.index_dir = index_dir
        self.embedder = embedder
        self.llm = llm
        self.top_k = top_k
        self.store: Optional[VectorStore] = None

    def build_index(self):
        """Embed all chunks, build the index and save it to ``index_dir``.

        Raises:
            ValueError: if ``chunks.jsonl`` contains no chunks.
        """
        print("Loading chunks...")
        chunks = self._load_chunks()
        print(f"Loaded {len(chunks)} chunks")
        if not chunks:
            # Fail with a clear message instead of an obscure error later on.
            raise ValueError(f"No chunks found in {self.data_dir / 'chunks.jsonl'}")
        print("Generating embeddings...")
        texts = [c.content for c in chunks]
        embeddings = self.embedder.embed(texts)
        print("Building FAISS index...")
        self.store = VectorStore(self.embedder.dim)
        self.store.add(chunks, embeddings)
        print("Saving index...")
        self.index_dir.mkdir(parents=True, exist_ok=True)
        self.store.save(self.index_dir)
        print(f"Index saved to {self.index_dir}")

    def load_index(self):
        """Load a previously saved index from ``index_dir``."""
        print("Loading index...")
        self.store = VectorStore.load(self.index_dir, self.embedder.dim)
        print(f"Loaded {len(self.store.chunks)} chunks")

    def _load_chunks(self) -> List[Chunk]:
        """Parse ``data_dir / "chunks.jsonl"`` into chunks, skipping blank lines."""
        chunks = []
        chunks_path = self.data_dir / "chunks.jsonl"
        with open(chunks_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))
        return chunks

    def query(self, question: str) -> str:
        """Answer ``question`` from the retrieved chunks.

        Returns:
            The LLM's answer, or the fixed "no information" sentence when
            retrieval finds nothing.

        Raises:
            RuntimeError: if no index is loaded or no LLM client was given.
        """
        if self.store is None:
            raise RuntimeError("Index not loaded. Call build_index() or load_index() first.")
        if self.llm is None:
            # The pipeline can be constructed without an LLM; say so clearly
            # rather than failing with an AttributeError on None.
            raise RuntimeError("No LLM client configured; query() requires one.")
        query_emb = self.embedder.embed_query(question)
        results = self.store.search(query_emb, top_k=self.top_k)
        if not results:
            return "Không có thông tin trong dữ liệu."
        context_parts = []
        for i, chunk in enumerate(results, 1):
            source = f"[{i}] {chunk.title} ({chunk.url})"
            if chunk.section:
                source += f" - Section: {chunk.section}"
            context_parts.append(f"{source}\n{chunk.content}")
        context = "\n\n---\n\n".join(context_parts)
        prompt = f"NGUYÊN LIỆU:\n{context}\n\nCÂU HỎI: {question}"
        return self.llm.generate(prompt)
def main():
    """Command-line entry point.

    Modes, in order of precedence: ``--build`` (create the index),
    ``--retrieve-only`` (print retrieved chunks for ``--query``),
    ``--query`` (single answer) and ``--interactive`` (chat loop).

    Returns:
        0 on success, 1 when no mode was selected or the API key is missing.
    """
    # Force UTF-8 output on Windows
    if hasattr(sys.stdout, 'reconfigure'):
        sys.stdout.reconfigure(encoding='utf-8')
    if hasattr(sys.stderr, 'reconfigure'):
        sys.stderr.reconfigure(encoding='utf-8')

    # Load .env before the parser is built so that LLM_MODEL / LLM_BASE_URL
    # from the environment can serve as defaults for the matching options.
    load_dotenv()

    parser = argparse.ArgumentParser(description="OrangePi RAG Application")
    parser.add_argument("--data-dir", type=Path, default=Path("."), help="Directory with chunks.jsonl")
    parser.add_argument("--index-dir", type=Path, default=Path("./rag_index"), help="FAISS index directory")
    parser.add_argument("--build", action="store_true", help="Build index from chunks")
    parser.add_argument("--query", type=str, help="Query to answer")
    parser.add_argument("--interactive", action="store_true", help="Interactive chat mode")
    parser.add_argument("--retrieve-only", action="store_true", help="Test retrieval without LLM")
    parser.add_argument("--top-k", type=int, default=5, help="Number of chunks to retrieve")
    parser.add_argument("--embed-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    parser.add_argument("--llm-model", default=os.environ.get("LLM_MODEL") or "gpt-4o-mini")
    parser.add_argument("--llm-base-url", default=os.environ.get("LLM_BASE_URL") or "https://api.openai.com/v1")
    args = parser.parse_args()

    # Cheap validation first: loading the embedding model is slow, so do not
    # pay for it when the invocation cannot succeed anyway.
    if not (args.build or args.retrieve_only or args.query or args.interactive):
        parser.print_help()
        return 1
    if args.retrieve_only and not args.build and not args.query:
        parser.error("--retrieve-only requires --query")

    needs_llm = not (args.build or args.retrieve_only)
    api_key = None
    if needs_llm:
        api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
        if not api_key:
            print("ERROR: Set OPENAI_API_KEY or LLM_API_KEY in environment or .env file")
            print("       Or use --retrieve-only to test retrieval without LLM")
            return 1

    embedder = Embedder(args.embed_model)

    if args.build:
        # Build doesn't need LLM
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)
        pipeline.build_index()
        return 0

    if args.retrieve_only:
        # Retrieval test without LLM
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)
        pipeline.load_index()
        query_emb = embedder.embed_query(args.query)
        results = pipeline.store.search(query_emb, top_k=args.top_k)
        for i, chunk in enumerate(results, 1):
            print(f"\n--- Result {i} (score: {chunk.metadata.get('similarity_score', 0):.4f}) ---")
            print(f"Title: {chunk.title}")
            print(f"URL: {chunk.url}")
            if chunk.section:
                print(f"Section: {chunk.section}")
            print(f"Content: {chunk.content[:500]}...".encode('utf-8', errors='replace').decode('utf-8'))
        return 0

    llm = LLMClient(api_key, args.llm_base_url, args.llm_model)
    pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, llm, args.top_k)
    pipeline.load_index()

    if args.query:
        print(pipeline.query(args.query))
        return 0

    # Only --interactive can reach this point (see the validation above).
    print("OrangePi RAG - Interactive mode (Ctrl+C to exit)")
    print("=" * 50)
    while True:
        try:
            question = input("\n❓ Câu hỏi: ").strip()
            if not question:
                continue
            answer = pipeline.query(question)
            print(f"\n🤖 Trả lời: {answer}")
        except KeyboardInterrupt:
            print("\nTạm biệt!")
            break
        except EOFError:
            break
        except requests.RequestException as exc:
            # A failed API call should not end the whole chat session.
            print(f"\nERROR: {exc}", file=sys.stderr)
    return 0
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())
File diff suppressed because one or more lines are too long
Binary file not shown.
+6
View File
@@ -0,0 +1,6 @@
sentence-transformers>=2.2.0
faiss-cpu>=1.7.4
numpy>=1.24.0
requests>=2.31.0
python-dotenv>=1.0.0
tqdm>=4.65.0