add rag test
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
# LLM API Configuration
|
||||
# Get your API key from:
|
||||
# - OpenAI: https://platform.openai.com/api-keys
|
||||
# - Together.ai: https://api.together.xyz/settings/api-keys
|
||||
# - Groq: https://console.groq.com/keys
|
||||
# - Or any OpenAI-compatible API
|
||||
|
||||
OPENAI_API_KEY=your_api_key_here
|
||||
|
||||
# Optional: Custom base URL for OpenAI-compatible APIs
|
||||
# LLM_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Optional: Model name (default: gpt-4o-mini)
|
||||
# LLM_MODEL=gpt-4o-mini
|
||||
@@ -0,0 +1,10 @@
|
||||
# LLM API Configuration
|
||||
# Get your API key from https://platform.openai.com/api-keys
|
||||
# Or use any OpenAI-compatible API (e.g., Together.ai, Groq, etc.)
|
||||
OPENAI_API_KEY=your_api_key_here
|
||||
|
||||
# Optional: Custom base URL for OpenAI-compatible APIs
|
||||
# LLM_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Optional: Model name (default: gpt-4o-mini)
|
||||
# LLM_MODEL=gpt-4o-mini
|
||||
+320
@@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OrangePi RAG Application
|
||||
A Vietnamese-language RAG system for querying Orange Pi blog articles.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@dataclass
class Chunk:
    """One retrievable passage of an article together with its source details.

    Every field except ``embedding`` must be supplied by the caller;
    ``embedding`` defaults to ``None``.
    """

    # Identifier of this chunk.
    chunk_id: str
    # Identifier of the article the chunk was taken from.
    article_id: str
    # Text of the chunk.
    content: str
    # Section heading the chunk belongs to, or None when there is none.
    section: Optional[str]
    # Title of the source article.
    title: str
    # URL of the source article.
    url: str
    # Language tag exactly as stored in the chunk file.
    language: str
    # Free-form extra data; search results carry "similarity_score" in here.
    metadata: Dict[str, Any]
    # Optional embedding vector for the chunk.
    embedding: Optional[np.ndarray] = None
|
||||
|
||||
|
||||
class VectorStore:
    """FAISS inner-product index kept in lock-step with the indexed chunks.

    Row ``i`` of the index corresponds to ``self.chunks[i]``. Vectors are
    L2-normalised before being added or searched, so the inner-product scores
    returned by the index are cosine similarities.
    """

    def __init__(self, dim: int):
        """Create an empty store for vectors of dimension ``dim``."""
        self.dim = dim
        self.index = faiss.IndexFlatIP(dim)
        self.chunks: List[Chunk] = []

    @staticmethod
    def _unit_rows(vectors: np.ndarray) -> np.ndarray:
        """Return a private C-contiguous float32 copy of ``vectors`` with unit-length rows.

        ``faiss.normalize_L2`` works in place and needs contiguous float32
        data, so the copy is made first; the caller's array is never modified.
        """
        rows = np.array(vectors, dtype=np.float32, order="C")
        faiss.normalize_L2(rows)
        return rows

    def add(self, chunks: List[Chunk], embeddings: np.ndarray):
        """Append ``chunks`` and their ``embeddings`` (one row per chunk) to the store.

        Raises:
            ValueError: if the number of embedding rows differs from the
                number of chunks, which would misalign rows and chunks.
        """
        if len(embeddings) != len(chunks):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(chunks)} chunks"
            )
        if not chunks:
            return
        self.index.add(self._unit_rows(embeddings))
        self.chunks.extend(chunks)

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Chunk]:
        """Return up to ``top_k`` chunks most similar to ``query_embedding``.

        Each result is a copy of the stored chunk whose ``metadata`` carries a
        ``"similarity_score"`` entry for this query. The stored chunks are left
        untouched, so scores neither go stale between queries nor end up in
        files written by ``save``.
        """
        if top_k <= 0 or not self.chunks:
            return []

        query = self._unit_rows(np.asarray(query_embedding).reshape(1, -1))
        scores, indices = self.index.search(query, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than top_k vectors are available.
            if not 0 <= idx < len(self.chunks):
                continue
            stored = self.chunks[idx]
            results.append(Chunk(
                chunk_id=stored.chunk_id,
                article_id=stored.article_id,
                content=stored.content,
                section=stored.section,
                title=stored.title,
                url=stored.url,
                language=stored.language,
                metadata={**stored.metadata, "similarity_score": float(score)},
                embedding=stored.embedding,
            ))
        return results

    def save(self, path: Path):
        """Write ``faiss.index`` and ``chunks.jsonl`` into directory ``path``."""
        path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(path / "faiss.index"))
        with open(path / "chunks.jsonl", "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                record = {
                    "chunk_id": chunk.chunk_id,
                    "article_id": chunk.article_id,
                    "content": chunk.content,
                    "section": chunk.section,
                    "title": chunk.title,
                    "url": chunk.url,
                    "language": chunk.language,
                    "metadata": chunk.metadata,
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")

    @classmethod
    def load(cls, path: Path, dim: int) -> "VectorStore":
        """Load a store previously written by ``save`` from directory ``path``.

        Raises:
            ValueError: if the stored index has a different dimension than
                ``dim`` (e.g. built with another embedding model), or if the
                number of indexed vectors differs from the number of chunks.
        """
        store = cls(dim)
        store.index = faiss.read_index(str(path / "faiss.index"))
        if store.index.d != dim:
            raise ValueError(
                f"Index dimension {store.index.d} does not match expected dimension {dim}"
            )

        with open(path / "chunks.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline added by an editor).
                    continue
                data = json.loads(line)
                store.chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))

        if store.index.ntotal != len(store.chunks):
            raise ValueError(
                f"Index holds {store.index.ntotal} vectors but chunks.jsonl has {len(store.chunks)} chunks"
            )
        return store
|
||||
|
||||
|
||||
class Embedder:
    """Wrapper around a SentenceTransformer model that produces numpy embeddings."""

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        """Load ``model_name`` and record its embedding dimension in ``self.dim``."""
        self.model = SentenceTransformer(model_name)
        # The dimension accessor is named differently across sentence-transformers
        # releases; use whichever one the installed version provides.
        dimension_of = getattr(self.model, "get_embedding_dimension", None)
        if dimension_of is None:
            dimension_of = self.model.get_sentence_embedding_dimension
        self.dim = dimension_of()

    def embed(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` into a numpy array, showing a progress bar."""
        return self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    def embed_query(self, query: str) -> np.ndarray:
        """Encode a single ``query`` and return its embedding vector."""
        return self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
|
||||
|
||||
|
||||
class LLMClient:
    """Minimal client for an OpenAI-compatible ``/chat/completions`` endpoint."""

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        model: str = "gpt-4o-mini",
    ):
        """Store the credentials and pre-build the request headers.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; trailing slashes are stripped.
            model: Model name sent with every request.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, temperature: float = 0.1, max_tokens: int = 1000) -> str:
        """Send ``prompt`` as the user message and return the stripped reply text.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._system_prompt()},
                {"role": "user", "content": prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        resp = requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
        # The message content may be null in the response; treat that as an
        # empty answer instead of failing on None.strip().
        return (content or "").strip()

    def _system_prompt(self) -> str:
        """Return the Vietnamese system instructions sent with every request.

        The retrieved context is supplied in the user message, so this prompt
        contains instructions only and no context placeholder.
        """
        return (
            "Bạn là một trợ lý AI chuyên về Orange Pi, sử dụng dữ liệu từ blog orangepi.vn (nhà phân phối chính thức Orange Pi tại Việt Nam).\n"
            "\n"
            "NHIỆM VỤ:\n"
            "- Trả lời câu hỏi của người dùng CHỈ DỰA TRÊN THÔNG TIN ĐƯỢC CUNG CẤP trong phần \"NGUYÊN LIỆU\".\n"
            "- Nếu thông tin không có trong nguyên liệu, hãy trả lời: \"Không có thông tin trong dữ liệu.\"\n"
            "- KHÔNG được bịa đặt, suy diễn hoặc sử dụng kiến thức bên ngoài.\n"
            "- Trả lời bằng tiếng Việt, ngắn gọn, chính xác.\n"
            "- Trích dẫn nguồn (title + URL) khi có thể."
        )
|
||||
|
||||
|
||||
class RAGPipeline:
    """Retrieval-augmented generation over the chunks stored in ``chunks.jsonl``.

    The pipeline embeds chunks, stores them in a ``VectorStore`` and answers
    questions by sending the best-matching chunks to the LLM as context.
    """

    def __init__(
        self,
        data_dir: Path,
        index_dir: Path,
        embedder: Embedder,
        llm: Optional[LLMClient],
        top_k: int = 5,
    ):
        """Configure the pipeline; no index is built or loaded yet.

        Args:
            data_dir: Directory containing ``chunks.jsonl``.
            index_dir: Directory the index is saved to and loaded from.
            embedder: Produces embeddings for chunks and queries.
            llm: Client used by ``query``; may be ``None`` when the pipeline
                is only used to build or load an index.
            top_k: Number of chunks retrieved per question.
        """
        self.data_dir = data_dir
        self.index_dir = index_dir
        self.embedder = embedder
        self.llm = llm
        self.top_k = top_k
        self.store: Optional[VectorStore] = None

    def build_index(self):
        """Embed every chunk from ``chunks.jsonl`` and save the index to ``index_dir``.

        Raises:
            ValueError: if the chunk file contains no chunks.
        """
        print("Loading chunks...")
        chunks = self._load_chunks()
        print(f"Loaded {len(chunks)} chunks")
        if not chunks:
            raise ValueError(f"No chunks found in {self.data_dir / 'chunks.jsonl'}")

        print("Generating embeddings...")
        texts = [c.content for c in chunks]
        embeddings = self.embedder.embed(texts)

        print("Building FAISS index...")
        self.store = VectorStore(self.embedder.dim)
        self.store.add(chunks, embeddings)

        print("Saving index...")
        self.index_dir.mkdir(parents=True, exist_ok=True)
        self.store.save(self.index_dir)
        print(f"Index saved to {self.index_dir}")

    def load_index(self):
        """Load a previously built index from ``index_dir``."""
        print("Loading index...")
        self.store = VectorStore.load(self.index_dir, self.embedder.dim)
        print(f"Loaded {len(self.store.chunks)} chunks")

    def _load_chunks(self) -> List[Chunk]:
        """Parse ``data_dir/chunks.jsonl`` into ``Chunk`` objects, one per non-blank line."""
        chunks = []
        chunks_path = self.data_dir / "chunks.jsonl"
        with open(chunks_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline added by an editor).
                    continue
                data = json.loads(line)
                chunks.append(Chunk(
                    chunk_id=data["chunk_id"],
                    article_id=data["article_id"],
                    content=data["content"],
                    section=data.get("section"),
                    title=data["title"],
                    url=data["url"],
                    language=data["language"],
                    metadata=data["metadata"],
                ))
        return chunks

    def query(self, question: str) -> str:
        """Answer ``question`` from the ``top_k`` most similar chunks.

        Returns the LLM's answer, or the fixed "no information" message when
        retrieval finds nothing.

        Raises:
            RuntimeError: if no index has been built/loaded, or if the
                pipeline was created without an LLM client.
        """
        if self.store is None:
            raise RuntimeError("Index not loaded. Call build_index() or load_index() first.")
        if self.llm is None:
            raise RuntimeError("No LLM client configured; query() needs an LLMClient.")

        query_emb = self.embedder.embed_query(question)
        results = self.store.search(query_emb, top_k=self.top_k)

        if not results:
            return "Không có thông tin trong dữ liệu."

        # Number each source so the model can cite it by title and URL.
        context_parts = []
        for i, chunk in enumerate(results, 1):
            source = f"[{i}] {chunk.title} ({chunk.url})"
            if chunk.section:
                source += f" - Section: {chunk.section}"
            context_parts.append(f"{source}\n{chunk.content}")

        context = "\n\n---\n\n".join(context_parts)
        prompt = f"NGUYÊN LIỆU:\n{context}\n\nCÂU HỎI: {question}"

        return self.llm.generate(prompt)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for building the index and asking questions.

    Modes, in order of precedence: ``--build``, ``--retrieve-only`` (needs
    ``--query``), ``--query`` and ``--interactive``.

    Returns:
        Process exit code: 0 on success, 1 when no mode was selected, the
        API key is missing, or a single ``--query`` request failed.
    """
    # Force UTF-8 output on Windows
    for stream in (sys.stdout, sys.stderr):
        if hasattr(stream, "reconfigure"):
            stream.reconfigure(encoding="utf-8")

    # Load .env before defining arguments so LLM_MODEL / LLM_BASE_URL from the
    # environment can act as defaults; explicit command-line flags still win.
    load_dotenv()

    parser = argparse.ArgumentParser(description="OrangePi RAG Application")
    parser.add_argument("--data-dir", type=Path, default=Path("."), help="Directory with chunks.jsonl")
    parser.add_argument("--index-dir", type=Path, default=Path("./rag_index"), help="FAISS index directory")
    parser.add_argument("--build", action="store_true", help="Build index from chunks")
    parser.add_argument("--query", type=str, help="Query to answer")
    parser.add_argument("--interactive", action="store_true", help="Interactive chat mode")
    parser.add_argument("--retrieve-only", action="store_true", help="Test retrieval without LLM")
    parser.add_argument("--top-k", type=int, default=5, help="Number of chunks to retrieve")
    parser.add_argument("--embed-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    parser.add_argument("--llm-model", default=os.environ.get("LLM_MODEL") or "gpt-4o-mini")
    parser.add_argument("--llm-base-url", default=os.environ.get("LLM_BASE_URL") or "https://api.openai.com/v1")
    args = parser.parse_args()

    # Decide what to do before loading the (slow) embedding model.
    if not (args.build or args.retrieve_only or args.query or args.interactive):
        parser.print_help()
        return 1

    if args.retrieve_only and not args.build and not args.query:
        parser.error("--retrieve-only requires --query")

    if args.build or args.retrieve_only:
        # Neither building nor a retrieval test needs the LLM.
        embedder = Embedder(args.embed_model)
        pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, None, args.top_k)

        if args.build:
            pipeline.build_index()
            return 0

        pipeline.load_index()
        query_emb = embedder.embed_query(args.query)
        results = pipeline.store.search(query_emb, top_k=args.top_k)
        for i, chunk in enumerate(results, 1):
            print(f"\n--- Result {i} (score: {chunk.metadata.get('similarity_score', 0):.4f}) ---")
            print(f"Title: {chunk.title}")
            print(f"URL: {chunk.url}")
            if chunk.section:
                print(f"Section: {chunk.section}")
            # Round-trip through UTF-8 so unencodable characters are replaced
            # instead of making print() fail.
            print(f"Content: {chunk.content[:500]}...".encode('utf-8', errors='replace').decode('utf-8'))
        return 0

    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
    if not api_key:
        print("ERROR: Set OPENAI_API_KEY or LLM_API_KEY in environment or .env file")
        print("       Or use --retrieve-only to test retrieval without LLM")
        return 1

    embedder = Embedder(args.embed_model)
    llm = LLMClient(api_key, args.llm_base_url, args.llm_model)
    pipeline = RAGPipeline(args.data_dir, args.index_dir, embedder, llm, args.top_k)
    pipeline.load_index()

    if args.query:
        try:
            print(pipeline.query(args.query))
        except requests.RequestException as exc:
            print(f"ERROR: LLM request failed: {exc}", file=sys.stderr)
            return 1
        return 0

    # Only --interactive can reach this point.
    print("OrangePi RAG - Interactive mode (Ctrl+C to exit)")
    print("=" * 50)
    while True:
        try:
            question = input("\n❓ Câu hỏi: ").strip()
            if not question:
                continue
            answer = pipeline.query(question)
            print(f"\n🤖 Trả lời: {answer}")
        except requests.RequestException as exc:
            # A failed request should not end the whole session.
            print(f"ERROR: LLM request failed: {exc}", file=sys.stderr)
        except KeyboardInterrupt:
            print("\nTạm biệt!")
            break
        except EOFError:
            break
    return 0
|
||||
|
||||
|
||||
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
||||
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1,6 @@
|
||||
sentence-transformers>=2.2.0
|
||||
faiss-cpu>=1.7.4
|
||||
numpy>=1.24.0
|
||||
requests>=2.31.0
|
||||
python-dotenv>=1.0.0
|
||||
tqdm>=4.65.0
|
||||
Reference in New Issue
Block a user