add rag test
This commit is contained in:
@@ -0,0 +1,14 @@
|
|||||||
|
# LLM API Configuration
|
||||||
|
# Get your API key from:
|
||||||
|
# - OpenAI: https://platform.openai.com/api-keys
|
||||||
|
# - Together.ai: https://api.together.xyz/settings/api-keys
|
||||||
|
# - Groq: https://console.groq.com/keys
|
||||||
|
# - Or any OpenAI-compatible API
|
||||||
|
|
||||||
|
OPENAI_API_KEY=your_api_key_here
|
||||||
|
|
||||||
|
# Optional: Custom base URL for OpenAI-compatible APIs
|
||||||
|
# LLM_BASE_URL=https://api.openai.com/v1
|
||||||
|
|
||||||
|
# Optional: Model name (default: gpt-4o-mini)
|
||||||
|
# LLM_MODEL=gpt-4o-mini
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# LLM API Configuration
|
||||||
|
# Get your API key from https://platform.openai.com/api-keys
|
||||||
|
# Or use any OpenAI-compatible API (e.g., Together.ai, Groq, etc.)
|
||||||
|
OPENAI_API_KEY=your_api_key_here
|
||||||
|
|
||||||
|
# Optional: Custom base URL for OpenAI-compatible APIs
|
||||||
|
# LLM_BASE_URL=https://api.openai.com/v1
|
||||||
|
|
||||||
|
# Optional: Model name (default: gpt-4o-mini)
|
||||||
|
# LLM_MODEL=gpt-4o-mini
|
||||||
+320
@@ -0,0 +1,320 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
OrangePi RAG Application
|
||||||
|
A Vietnamese-language RAG system for querying Orange Pi blog articles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
import json
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

import faiss
import numpy as np
import requests
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Chunk:
    """One retrievable passage of an article together with its attribution.

    Instances are built from the records of ``chunks.jsonl``; ``content`` is
    the text that gets embedded, while ``title``, ``url`` and ``section`` are
    used to cite the source in the prompt and in retrieval output.
    """

    chunk_id: str
    article_id: str
    content: str
    section: Optional[str]  # read with .get() when loading, so it may be absent
    title: str
    url: str
    language: str
    # Free-form record metadata; VectorStore.search stores the latest
    # "similarity_score" here.
    metadata: Dict[str, Any]
    # Kept out of ==/repr: comparing numpy arrays element-wise makes the
    # generated __eq__ raise "truth value of an array is ambiguous", and a
    # full vector in repr() is just noise.
    embedding: Optional[np.ndarray] = field(default=None, compare=False, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
    """FAISS inner-product index paired with the chunks it indexes.

    Vectors are L2-normalised before they are added or searched, so the
    inner-product scores returned by the index are cosine similarities.
    Row ``i`` of the index corresponds to ``self.chunks[i]``.
    """

    def __init__(self, dim: int):
        """Create an empty store for vectors of width ``dim``."""
        self.dim = dim
        self.index = faiss.IndexFlatIP(dim)
        self.chunks: List[Chunk] = []

    def add(self, chunks: List[Chunk], embeddings: np.ndarray):
        """Append ``chunks`` and their row-aligned ``embeddings`` to the store.

        The caller's array is left untouched; a normalised float32 copy is
        what gets indexed.

        Raises:
            ValueError: if the number of chunks and embedding rows differ, or
                the embeddings are not a 2-D array of width ``self.dim``.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings"
            )
        if not chunks:
            return

        # normalize_L2 works in place and expects float32, so normalise a
        # private C-contiguous float32 copy instead of the caller's data.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.ndim != 2 or vectors.shape[1] != self.dim:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.dim}), got {vectors.shape}"
            )
        faiss.normalize_L2(vectors)
        self.index.add(vectors)
        self.chunks.extend(chunks)

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Chunk]:
        """Return up to ``top_k`` stored chunks most similar to the query.

        Each returned chunk has ``metadata["similarity_score"]`` set to its
        score for this query. The score is written onto the stored chunk
        itself, so a later search overwrites it.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []

        # Private float32 copy: keeps the caller's vector unmodified by the
        # in-place normalisation.
        query = np.array(query_embedding, dtype=np.float32, order="C").reshape(1, -1)
        faiss.normalize_L2(query)
        scores, indices = self.index.search(query, min(top_k, self.index.ntotal))

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # Guard against ids with no matching chunk (including negative
            # "no result" ids).
            if 0 <= idx < len(self.chunks):
                chunk = self.chunks[idx]
                chunk.metadata["similarity_score"] = float(score)
                results.append(chunk)
        return results

    def save(self, path: Path):
        """Write ``faiss.index`` and ``chunks.jsonl`` into directory ``path``."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(path / "faiss.index"))
        with open(path / "chunks.jsonl", "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                data = {
                    "chunk_id": chunk.chunk_id,
                    "article_id": chunk.article_id,
                    "content": chunk.content,
                    "section": chunk.section,
                    "title": chunk.title,
                    "url": chunk.url,
                    "language": chunk.language,
                    "metadata": chunk.metadata,
                }
                # ensure_ascii=False keeps the Vietnamese text readable on disk.
                f.write(json.dumps(data, ensure_ascii=False) + "\n")

    @classmethod
    def load(cls, path: Path, dim: int) -> "VectorStore":
        """Restore a store previously written by :meth:`save`.

        Args:
            path: Directory containing ``faiss.index`` and ``chunks.jsonl``.
            dim: Vector width the caller will query with.

        Raises:
            ValueError: if the stored index has a different vector width than
                ``dim``, or the number of stored chunks does not match the
                number of indexed vectors.
        """
        path = Path(path)
        store = cls(dim)
        store.index = faiss.read_index(str(path / "faiss.index"))
        if store.index.d != dim:
            raise ValueError(
                f"Index at {path} has dimension {store.index.d}, expected {dim}; "
                "rebuild the index with the current embedding model"
            )

        with open(path / "chunks.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in the JSONL file
                data = json.loads(line)
                store.chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))

        # Index rows and chunks are matched by position, so the counts must agree.
        if len(store.chunks) != store.index.ntotal:
            raise ValueError(
                f"Index at {path} holds {store.index.ntotal} vectors but "
                f"{len(store.chunks)} chunks"
            )
        return store
|
||||||
|
|
||||||
|
|
||||||
|
class Embedder:
    """Thin wrapper around a SentenceTransformer model.

    Attributes:
        model: The loaded SentenceTransformer instance.
        dim: Width of the vectors the model produces.
    """

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        """Load ``model_name`` and record its embedding width."""
        self.model = SentenceTransformer(model_name)
        # The accessor name differs between sentence-transformers releases;
        # fall back to the long-standing name so older installs still work.
        dim_getter = getattr(self.model, "get_embedding_dimension", None)
        if dim_getter is None:
            dim_getter = self.model.get_sentence_embedding_dimension
        self.dim = dim_getter()

    def embed(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` into one row per text, showing a progress bar."""
        if not texts:
            # Guarantee a well-formed (0, dim) matrix instead of relying on
            # how encode() treats empty input.
            return np.empty((0, self.dim), dtype=np.float32)
        return self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    def embed_query(self, query: str) -> np.ndarray:
        """Encode a single query string and return its vector."""
        return self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
    """Minimal client for an OpenAI-compatible ``/chat/completions`` endpoint."""

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        model: str = "gpt-4o-mini",
    ):
        """Store connection settings and prepare the request headers.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; a trailing slash is stripped.
            model: Model name sent with every request.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, temperature: float = 0.1, max_tokens: int = 1000) -> str:
        """Send ``prompt`` as the user message and return the reply text.

        Args:
            prompt: User message; the system prompt is added automatically.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.

        Returns:
            The first choice's message content, stripped of surrounding
            whitespace (an empty string when the API returns null content).

        Raises:
            requests.HTTPError: on a non-2xx response.
            RuntimeError: if the response body does not have the expected
                chat-completions shape.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._system_prompt()},
                {"role": "user", "content": prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        resp = requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        try:
            content = resp.json()["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError, ValueError) as err:
            raise RuntimeError("Unexpected response format from LLM API") from err
        # Some responses carry null content; treat that as an empty answer.
        return (content or "").strip()

    def _system_prompt(self) -> str:
        """Return the Vietnamese system instructions for grounded answering.

        The retrieved material itself is supplied in the user message under
        the "NGUYÊN LIỆU" heading, so no placeholder is embedded here.
        """
        return """Bạn là một trợ lý AI chuyên về Orange Pi, sử dụng dữ liệu từ blog orangepi.vn (nhà phân phối chính thức Orange Pi tại Việt Nam).

NHIỆM VỤ:
- Trả lời câu hỏi của người dùng CHỈ DỰA TRÊN THÔNG TIN ĐƯỢC CUNG CẤP trong phần "NGUYÊN LIỆU".
- Nếu thông tin không có trong nguyên liệu, hãy trả lời: "Không có thông tin trong dữ liệu."
- KHÔNG được bịa đặt, suy diễn hoặc sử dụng kiến thức bên ngoài.
- Trả lời bằng tiếng Việt, ngắn gọn, chính xác.
- Trích dẫn nguồn (title + URL) khi có thể."""
|
||||||
|
|
||||||
|
|
||||||
|
class RAGPipeline:
    """Retrieval-augmented answering over the chunks in ``data_dir``.

    The pipeline embeds chunks, stores them in a ``VectorStore`` persisted
    under ``index_dir``, and answers questions by sending the retrieved
    chunks plus the question to the LLM.
    """

    def __init__(
        self,
        data_dir: Path,
        index_dir: Path,
        embedder: "Embedder",
        llm: Optional["LLMClient"],
        top_k: int = 5,
    ):
        """Store the collaborators; no index is loaded yet.

        Args:
            data_dir: Directory containing ``chunks.jsonl``.
            index_dir: Directory the index is saved to / loaded from.
            embedder: Produces chunk and query embeddings.
            llm: Answer generator; may be ``None`` when the pipeline is only
                used to build the index or to run retrieval.
            top_k: Number of chunks retrieved per question.
        """
        self.data_dir = data_dir
        self.index_dir = index_dir
        self.embedder = embedder
        self.llm = llm
        self.top_k = top_k
        self.store: Optional[VectorStore] = None

    def build_index(self):
        """Embed every chunk in ``data_dir`` and save the index to ``index_dir``.

        Raises:
            ValueError: if ``chunks.jsonl`` contains no chunks.
        """
        print("Loading chunks...")
        chunks = self._load_chunks()
        print(f"Loaded {len(chunks)} chunks")
        if not chunks:
            # Fail with a clear message instead of an opaque error deep
            # inside the embedding/indexing libraries.
            raise ValueError(f"No chunks found in {self.data_dir / 'chunks.jsonl'}")

        print("Generating embeddings...")
        texts = [c.content for c in chunks]
        embeddings = self.embedder.embed(texts)

        print("Building FAISS index...")
        self.store = VectorStore(self.embedder.dim)
        self.store.add(chunks, embeddings)

        print("Saving index...")
        self.index_dir.mkdir(parents=True, exist_ok=True)
        self.store.save(self.index_dir)
        print(f"Index saved to {self.index_dir}")

    def load_index(self):
        """Load a previously built index from ``index_dir``."""
        print("Loading index...")
        self.store = VectorStore.load(self.index_dir, self.embedder.dim)
        print(f"Loaded {len(self.store.chunks)} chunks")

    def _load_chunks(self) -> List["Chunk"]:
        """Read ``data_dir/chunks.jsonl`` into ``Chunk`` objects."""
        chunks = []
        chunks_path = self.data_dir / "chunks.jsonl"
        with open(chunks_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in the JSONL file
                data = json.loads(line)
                chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))
        return chunks

    @staticmethod
    def _format_context(results) -> str:
        """Render retrieved chunks as numbered, source-attributed passages."""
        context_parts = []
        for i, chunk in enumerate(results, 1):
            source = f"[{i}] {chunk.title} ({chunk.url})"
            if chunk.section:
                source += f" - Section: {chunk.section}"
            context_parts.append(f"{source}\n{chunk.content}")
        return "\n\n---\n\n".join(context_parts)

    def query(self, question: str) -> str:
        """Answer ``question`` from the indexed chunks.

        Returns:
            The LLM's answer, or the fixed "no information" message when
            retrieval finds nothing.

        Raises:
            RuntimeError: if no index has been built/loaded, or the pipeline
                was created without an LLM client.
        """
        if self.store is None:
            raise RuntimeError("Index not loaded. Call build_index() or load_index() first.")
        if self.llm is None:
            # Checked before retrieval so a misconfigured pipeline fails fast.
            raise RuntimeError("No LLM client configured; cannot generate an answer.")

        query_emb = self.embedder.embed_query(question)
        results = self.store.search(query_emb, top_k=self.top_k)

        if not results:
            return "Không có thông tin trong dữ liệu."

        context = self._format_context(results)
        prompt = f"NGUYÊN LIỆU:\n{context}\n\nCÂU HỎI: {question}"

        return self.llm.generate(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point.

    Modes (first match wins): ``--build`` builds the index;
    ``--retrieve-only --query Q`` prints the retrieved chunks without
    calling the LLM; ``--query Q`` prints one answer; ``--interactive``
    starts a chat loop.

    The LLM base URL and model come from the command line when given,
    otherwise from ``LLM_BASE_URL`` / ``LLM_MODEL`` (environment or ``.env``),
    otherwise from the built-in defaults.

    Returns:
        Process exit code: 0 on success, 1 when no mode was requested or the
        API key is missing.
    """
    # Force UTF-8 output on Windows
    for stream in (sys.stdout, sys.stderr):
        if hasattr(stream, "reconfigure"):
            stream.reconfigure(encoding="utf-8")

    parser = argparse.ArgumentParser(description="OrangePi RAG Application")
    parser.add_argument("--data-dir", type=Path, default=Path("."), help="Directory with chunks.jsonl")
    parser.add_argument("--index-dir", type=Path, default=Path("./rag_index"), help="FAISS index directory")
    parser.add_argument("--build", action="store_true", help="Build index from chunks")
    parser.add_argument("--query", type=str, help="Query to answer")
    parser.add_argument("--interactive", action="store_true", help="Interactive chat mode")
    parser.add_argument("--retrieve-only", action="store_true", help="Test retrieval without LLM")
    parser.add_argument("--top-k", type=int, default=5, help="Number of chunks to retrieve")
    parser.add_argument("--embed-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    # Defaults are resolved after load_dotenv() so .env values are honoured.
    parser.add_argument("--llm-model", default=None,
                        help="LLM model name (default: $LLM_MODEL or gpt-4o-mini)")
    parser.add_argument("--llm-base-url", default=None,
                        help="LLM API base URL (default: $LLM_BASE_URL or https://api.openai.com/v1)")
    args = parser.parse_args()

    load_dotenv()

    # Validate the request before paying for the embedding-model load.
    if not (args.build or args.retrieve_only or args.query or args.interactive):
        parser.print_help()
        return 1
    if args.retrieve_only and not args.build and not args.query:
        parser.error("--retrieve-only requires --query")

    embedder = Embedder(args.embed_model)

    if args.build:
        # Build doesn't need LLM
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)
        pipeline.build_index()
        return 0

    if args.retrieve_only:
        # Retrieval test without LLM
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)
        pipeline.load_index()
        query_emb = embedder.embed_query(args.query)
        results = pipeline.store.search(query_emb, top_k=args.top_k)
        for i, chunk in enumerate(results, 1):
            print(f"\n--- Result {i} (score: {chunk.metadata.get('similarity_score', 0):.4f}) ---")
            print(f"Title: {chunk.title}")
            print(f"URL: {chunk.url}")
            if chunk.section:
                print(f"Section: {chunk.section}")
            # Round-trip through UTF-8 with errors='replace' so unencodable
            # characters in scraped text cannot abort the print.
            print(f"Content: {chunk.content[:500]}...".encode('utf-8', errors='replace').decode('utf-8'))
        return 0

    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
    if not api_key:
        print("ERROR: Set OPENAI_API_KEY or LLM_API_KEY in environment or .env file")
        print("       Or use --retrieve-only to test retrieval without LLM")
        return 1

    # Precedence: command line, then environment/.env, then built-in default.
    base_url = args.llm_base_url or os.environ.get("LLM_BASE_URL") or "https://api.openai.com/v1"
    model = args.llm_model or os.environ.get("LLM_MODEL") or "gpt-4o-mini"

    llm = LLMClient(api_key, base_url, model)
    pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, llm, args.top_k)
    pipeline.load_index()

    if args.query:
        print(pipeline.query(args.query))
        return 0

    # Only --interactive can remain at this point.
    print("OrangePi RAG - Interactive mode (Ctrl+C to exit)")
    print("=" * 50)
    while True:
        try:
            question = input("\n❓ Câu hỏi: ").strip()
            if not question:
                continue
            answer = pipeline.query(question)
            print(f"\n🤖 Trả lời: {answer}")
        except KeyboardInterrupt:
            print("\nTạm biệt!")
            break
        except EOFError:
            break
        except requests.RequestException as err:
            # One failed API call should not end the whole chat session.
            print(f"\nLỗi khi gọi LLM: {err}")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
|
||||||
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1,6 @@
|
|||||||
|
sentence-transformers>=2.2.0
|
||||||
|
faiss-cpu>=1.7.4
|
||||||
|
numpy>=1.24.0
|
||||||
|
requests>=2.31.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
tqdm>=4.65.0
|
||||||
Reference in New Issue
Block a user