RAG and Knowledge Bases
2026-05-20
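RAG (Retrieval-Augmented Generation) addresses two core weaknesses of large language models:
- Knowledge cutoff: the model knows nothing about events after its training data was collected
- Hallucination: the model confidently "makes up" facts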
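The idea behind RAG is simple and effective: retrieve first, then generate. Relevant content is found in an external knowledge base and supplied to the model as context.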
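To make "retrieve first, then generate" concrete, here is a toy sketch of the two steps that needs no external service. Word overlap stands in for vector similarity, and no LLM is called; the output is only the prompt that would be sent. `tokenize`, `retrieve` and `build_prompt` are hypothetical helpers, and the documents are made-up examples.

```python
import re

def tokenize(text: str) -> set[str]:
    # Lowercased word set: a crude stand-in for an embedding model
    return set(re.findall(r"\w+", text.lower()))

def retrieve(question: str, docs: list[str], k: int = 2) -> list[str]:
    # Retrieve: rank documents by how many words they share with the question
    q = tokenize(question)
    return sorted(docs, key=lambda d: len(q & tokenize(d)), reverse=True)[:k]

def build_prompt(question: str, context: list[str]) -> str:
    # Input for the generate step: retrieved passages first, then the question
    refs = "\n".join(f"[{i}] {c}" for i, c in enumerate(context, 1))
    return f"Answer using only the references below.\n{refs}\n\nQuestion: {question}"

docs = [
    "The 2024 annual report shows revenue grew 12% year over year.",
    "The office cafeteria opens at 8 am.",
    "Employees get 15 days of paid leave.",
]
question = "How much did revenue grow in 2024?"
prompt = build_prompt(question, retrieve(question, docs, k=1))
print(prompt)  # this string is what would be sent to the LLM
```

A real system replaces `tokenize` with an embedding model and passes `prompt` to an LLM. The overall shape of the pipeline stays the same.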
1. How RAG works
1.1 Overall pipeline
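Offline stage (building the knowledge base):
Documents → chunking → vectorization (embedding) → store in a vector database
Online stage (answering questions):
User question → vectorize → similarity search → take the Top-K chunks
→ chunks + question → LLM → answer
1.2 Why vectors?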
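Traditional keyword search (BM25): a search for "Apple phone" will not find documents that only contain "iPhone".
Vector search: text is mapped into a high-dimensional space where semantically similar texts get nearby vectors. "Apple phone" and "iPhone" end up close together, so each can retrieve the other.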
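The vocabulary mismatch is easy to see with a small from-scratch BM25 scorer. This sketch uses the common Lucene-style IDF with `+1` inside the logarithm; libraries such as `rank_bm25` differ in details. `bm25_scores` is a hypothetical helper, and the three documents are made up.

```python
import math
from collections import Counter

def bm25_scores(query: list[str], corpus: list[list[str]],
                k1: float = 1.5, b: float = 0.75) -> list[float]:
    """Okapi BM25 score of each tokenized document for a tokenized query."""
    n_docs = len(corpus)
    avgdl = sum(len(d) for d in corpus) / n_docs
    # Document frequency: how many documents contain each term
    df = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        tf = Counter(doc)
        score = 0.0
        for term in query:
            if tf[term] == 0:
                continue  # a query term absent from the document contributes nothing
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            # Term-frequency saturation (k1) and document-length normalization (b)
            denom = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / denom
        scores.append(score)
    return scores

corpus = [
    "apple phone battery life review".split(),
    "iphone battery life review".split(),
    "banana bread recipe".split(),
]
print(bm25_scores("apple phone".split(), corpus))
```

The "iphone" document scores exactly 0 because it shares no token with the query, no matter how relevant it is. An embedding model that has learned the two terms are related has no such blind spot. This is why the hybrid search in section 6.3 combines both signals.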
2. Text embedding
2.1 What is an embedding
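Embedding maps text to a point in a continuous, high-dimensional vector space (the numbers below are illustrative):
"Machine learning is the foundation of AI" → [0.23, -0.45, 0.11, ...] (a 1024-dimensional vector)
"Deep learning is a subset of ML" → [0.21, -0.43, 0.09, ...] (a nearby vector)
"The weather is nice today" → [-0.32, 0.67, -0.54, ...] (far away)
2.2 Mainstream embedding models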
| Model | Dimensions | Notes |
|---|---|---|
| text-embedding-3-small (OpenAI) | 1536 | Cheap, good quality |
| text-embedding-3-large (OpenAI) | 3072 | Better quality, more expensive |
| voyage-3 (Voyage AI) | 1024 | Recommended by Anthropic, excellent quality |
| bge-m3 (BAAI, open source) | 1024 | Multilingual, strong on Chinese, free |
| nomic-embed-text (local) | 768 | Runs fully locally |
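The dimension column translates directly into storage. Raw float32 vectors take 4 bytes per dimension. For $N$ chunks of dimension $d$, before any index overhead (structures such as HNSW add more on top):

$$\text{size} = N \times d \times 4\ \text{bytes}, \qquad 10^6 \times 1536 \times 4\ \text{B} = 6.144 \times 10^9\ \text{B} \approx 6.1\ \text{GB}$$

By the same formula, one million chunks take about 4.1 GB at 1024 dimensions and 12.3 GB at 3072. OpenAI's text-embedding-3 models also accept a `dimensions` parameter that returns shortened vectors when storage matters.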
```python
# OpenAI embeddings
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def embed_openai(texts: list[str]) -> list[list[float]]:
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input=texts
    )
    return [item.embedding for item in response.data]

# Voyage AI embeddings (recommended by Anthropic)
import voyageai

vo = voyageai.Client()  # reads VOYAGE_API_KEY from the environment

def embed_voyage(texts: list[str]) -> list[list[float]]:
    # Use input_type="document" when indexing and input_type="query" for search queries
    result = vo.embed(texts, model="voyage-3", input_type="document")
    return result.embeddings

# Local model (sentence-transformers)
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-m3")  # multilingual, strong on Chinese

def embed_local(texts: list[str]) -> list[list[float]]:
    # normalize_embeddings=True returns unit-length vectors, so dot product == cosine similarity
    return model.encode(texts, normalize_embeddings=True).tolist()
```
2.3 Similarity calculation
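Cosine similarity compares only the direction of two vectors, not their length. For $d$-dimensional vectors $\mathbf a$ and $\mathbf b$:

$$\cos(\mathbf a,\mathbf b)=\frac{\mathbf a\cdot\mathbf b}{\lVert\mathbf a\rVert\,\lVert\mathbf b\rVert}=\frac{\sum_{i=1}^{d}a_i b_i}{\sqrt{\sum_{i=1}^{d}a_i^2}\,\sqrt{\sum_{i=1}^{d}b_i^2}}$$

For example, $\mathbf a=(1,0)$ and $\mathbf b=(1,1)$ give $\cos = 1/(1\cdot\sqrt2)\approx 0.707$.

If both vectors are L2-normalized ($\lVert\mathbf a\rVert=\lVert\mathbf b\rVert=1$), the denominator is 1 and cosine similarity reduces to the dot product. In that case $\lVert\mathbf a-\mathbf b\rVert^2 = 2-2\cos(\mathbf a,\mathbf b)$, so ranking by Euclidean distance gives the same order as ranking by cosine similarity.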
```python
import numpy as np

def cosine_similarity(a: list[float], b: list[float]) -> float:
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# For vectors that are already L2-normalized, the dot product equals cosine similarity
def dot_product_similarity(a: list[float], b: list[float]) -> float:
    return float(np.dot(np.array(a), np.array(b)))
```
3. Vector databases
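3.1 Comparison of popular options
| Database | Deployment | Best for | Highlights |
|---|---|---|---|
| Chroma | Local / self-hosted | Development, small projects | Python-native, zero configuration |
| FAISS | Local library | Large-scale pure vector search | From Facebook (Meta) AI Research, very fast |
| Weaviate | Self-hosted / cloud | Medium to large projects | Feature-rich, GraphQL API |
| Pinecone | Fully managed cloud | Getting to production quickly | No ops, pay-as-you-go |
| Qdrant | Self-hosted / cloud | Production | Written in Rust, high performance |
| pgvector | PostgreSQL extension | Teams already running PostgreSQL | SQL-friendly, simpler architecture |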
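Whatever the product, the core query is k-nearest-neighbor search. The brute-force baseline scores the query against every stored vector. This is conceptually what FAISS's flat indexes do: exact, but linear in collection size. That cost is why most databases above add approximate indexes such as HNSW. A stdlib-only sketch; `FlatIndex` and `normalize` are hypothetical names.

```python
import heapq
import math

def normalize(v: list[float]) -> list[float]:
    # Scale to unit length so that a dot product equals cosine similarity
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm else list(v)

class FlatIndex:
    """Exact brute-force cosine search (hypothetical class, for illustration)."""

    def __init__(self) -> None:
        self.ids: list[str] = []
        self.vectors: list[list[float]] = []

    def add(self, doc_id: str, vector: list[float]) -> None:
        # Normalize once at insert time so each query only needs dot products
        self.ids.append(doc_id)
        self.vectors.append(normalize(vector))

    def search(self, query: list[float], k: int = 3) -> list[tuple[str, float]]:
        q = normalize(query)
        # Score every stored vector: O(n * d) work per query
        scored = [(sum(a * b for a, b in zip(q, v)), doc_id)
                  for doc_id, v in zip(self.ids, self.vectors)]
        # Keep only the k highest scores without sorting the whole list
        return [(doc_id, score) for score, doc_id in heapq.nlargest(k, scored)]

index = FlatIndex()
index.add("a", [1.0, 0.0])
index.add("b", [1.0, 1.0])
index.add("c", [0.0, 1.0])
print(index.search([1.0, 0.1], k=2))  # "a" first, then "b"
```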
3.2 Using Chroma in detail
```python
import os
from typing import Optional

import chromadb
from chromadb.utils import embedding_functions

# In-memory mode (development and testing)
# client = chromadb.Client()
# Persistent mode (production): data is stored under ./chroma_db
client = chromadb.PersistentClient(path="./chroma_db")

# Custom embedding function: Chroma uses it to embed both documents and queries
ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.environ["OPENAI_API_KEY"],
    model_name="text-embedding-3-small"
)

# Create (or reopen) a collection
collection = client.get_or_create_collection(
    name="my_knowledge_base",
    embedding_function=ef,
    metadata={"hnsw:space": "cosine"}  # distance metric: cosine distance = 1 - cosine similarity
)

# Add documents in batch
def add_documents(texts: list[str], metadatas: list[dict], ids: list[str]):
    collection.add(
        documents=texts,
        metadatas=metadatas,
        ids=ids
    )

# Search
def search(query: str, n_results: int = 5, filters: Optional[dict] = None):
    results = collection.query(
        query_texts=[query],
        n_results=n_results,
        where=filters,  # metadata filter, e.g. {"source": "report_2024"}
        include=["documents", "metadatas", "distances"]
    )
    # Results are nested per query; [0] selects the first (and only) query
    return [
        {
            "text": doc,
            "metadata": meta,
            "score": 1 - dist  # convert cosine distance back to a similarity score
        }
        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0]
        )
    ]
```
3.3 Qdrant (recommended for production)
```python
from qdrant_client import QdrantClient
from qdrant_client.models import (
    VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
)

# Connect to Qdrant
qdrant = QdrantClient(url="http://localhost:6333")

# Create the collection once; size must equal the embedding dimension (1024 for bge-m3 / voyage-3)
if not qdrant.collection_exists("knowledge"):
    qdrant.create_collection(
        collection_name="knowledge",
        vectors_config=VectorParams(size=1024, distance=Distance.COSINE)
    )

# Add vectors (Qdrant point IDs must be unsigned integers or UUIDs)
def upsert_documents(texts, embeddings, metadatas, ids):
    points = [
        # Keep the chunk text in the payload so search results can return it
        PointStruct(id=id_, vector=emb, payload={**meta, "text": text})
        for id_, text, emb, meta in zip(ids, texts, embeddings, metadatas)
    ]
    qdrant.upsert(collection_name="knowledge", points=points)

# Search with an optional payload filter
def search_with_filter(query_embedding, category=None, n=5):
    filter_ = None
    if category:
        filter_ = Filter(
            must=[FieldCondition(key="category", match=MatchValue(value=category))]
        )
    # query_points replaces the older, deprecated search() method (qdrant-client >= 1.10)
    response = qdrant.query_points(
        collection_name="knowledge",
        query=query_embedding,
        limit=n,
        query_filter=filter_,
        with_payload=True
    )
    return response.points
```
4. Document preprocessing
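4.1 Document chunking
Long documents are split into small chunks, each of which should express a complete idea on its own. Chunk quality directly determines RAG quality.
Strategy 1: fixed-size chunking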
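```python
def chunk_by_size(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split by character count with overlap, so text at a boundary appears in both neighboring chunks."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be >= 0 and smaller than chunk_size")
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]
        if chunk.strip():
            chunks.append(chunk)
        if end >= len(text):
            break  # reached the end; avoid a trailing chunk made only of overlap
        start = end - overlap  # step back by `overlap` characters
    return chunks
```
Strategy 2: split at paragraph and sentence boundaries (recommended)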
```python
import re

def chunk_by_structure(text: str, max_chunk_size: int = 800) -> list[str]:
    """Split on paragraphs first, then sentences, to keep each chunk semantically intact."""
    # Split into paragraphs
    paragraphs = text.split("\n\n")
    chunks, current_chunk = [], ""
    for para in paragraphs:
        para = para.strip()
        if not para:
            continue
        if len(current_chunk) + len(para) < max_chunk_size:
            # The paragraph still fits: append it to the current chunk
            current_chunk += ("\n\n" if current_chunk else "") + para
        else:
            if current_chunk:
                chunks.append(current_chunk)
            # A single paragraph that is too long is split further by sentence
            if len(para) > max_chunk_size:
                # Zero-width split after Chinese or English sentence-ending punctuation keeps all text
                sentences = re.split(r"(?<=[。!?.!?])", para)
                sub_chunk = ""
                for sent in sentences:
                    if len(sub_chunk) + len(sent) < max_chunk_size:
                        sub_chunk += sent
                    else:
                        if sub_chunk:
                            chunks.append(sub_chunk)
                        sub_chunk = sent
                if sub_chunk:
                    chunks.append(sub_chunk)
                current_chunk = ""  # the previous chunk was already saved; start fresh
            else:
                current_chunk = para
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
```
LangChain RecursiveCharacterTextSplitter (recommended for production use):
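```python
# Newer LangChain versions ship the splitters in the separate langchain-text-splitters package
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=100,
    length_function=len,  # length measured in characters
    # Tried in order: paragraph, line, Chinese then English sentence endings, space, single character
    separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
)
chunks = splitter.split_text(text)  # text: the raw document string
```
4.2 Document parsing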
```python
# PDF parsing
import pdfplumber

def parse_pdf(file_path: str) -> list[dict]:
    """Return one record per non-empty page, keeping the page number as metadata."""
    docs = []
    with pdfplumber.open(file_path) as pdf:
        for i, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text and text.strip():
                docs.append({
                    "text": text,
                    "metadata": {
                        "source": file_path,
                        "page": i + 1,
                        "total_pages": len(pdf.pages)
                    }
                })
    return docs

# Word document parsing (python-docx)
from docx import Document

def parse_docx(file_path: str) -> list[dict]:
    doc = Document(file_path)
    # Note: doc.paragraphs does not include text inside tables
    paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
    text = "\n\n".join(paragraphs)  # blank lines keep paragraph boundaries for chunking
    return [{"text": text, "metadata": {"source": file_path}}]

# Web page parsing
import requests
from bs4 import BeautifulSoup

def parse_webpage(url: str) -> dict:
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    soup = BeautifulSoup(response.content, "html.parser")
    # Remove navigation, ads and other noise
    for tag in soup(["nav", "footer", "aside", "script", "style", "header"]):
        tag.decompose()
    text = soup.get_text(separator="\n", strip=True)
    title = soup.title.string if soup.title else None  # some pages have no <title>
    return {"text": text, "metadata": {"source": url, "title": title}}
```
5. A complete RAG system
```python
import os
from typing import Optional

import anthropic
import chromadb
from chromadb.utils import embedding_functions
from langchain_text_splitters import RecursiveCharacterTextSplitter


class RAGSystem:
    def __init__(self, collection_name: str = "knowledge"):
        self.llm = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
        # Use Chroma as the vector store
        self.chroma = chromadb.PersistentClient(path=f"./rag_{collection_name}")
        ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.environ["OPENAI_API_KEY"],
            model_name="text-embedding-3-small"
        )
        self.collection = self.chroma.get_or_create_collection(
            name=collection_name,
            embedding_function=ef,
            # Cosine distance, so that 1 - distance below is a cosine similarity
            # (Chroma's default metric is squared L2, for which that conversion is wrong)
            metadata={"hnsw:space": "cosine"}
        )
        self.splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)

    def ingest_document(self, text: str, source: str,
                        extra_metadata: Optional[dict] = None) -> int:
        """Add a document to the knowledge base."""
        chunks = self.splitter.split_text(text)
        if not chunks:
            return 0
        # upsert makes re-ingesting the same source safe: existing IDs are overwritten
        self.collection.upsert(
            documents=chunks,
            metadatas=[{"source": source, "chunk_index": i, **(extra_metadata or {})}
                       for i in range(len(chunks))],
            ids=[f"{source}_{i}" for i in range(len(chunks))]
        )
        print(f"Indexed {len(chunks)} chunks from source: {source}")
        return len(chunks)

    def retrieve(self, query: str, n_results: int = 5,
                 source_filter: Optional[str] = None) -> list[dict]:
        """Retrieve relevant chunks from the knowledge base."""
        where = {"source": source_filter} if source_filter else None
        results = self.collection.query(
            query_texts=[query],
            n_results=n_results,
            where=where,
            include=["documents", "metadatas", "distances"]
        )
        return [
            {
                "text": doc,
                "source": meta.get("source", "unknown"),
                "score": round(1 - dist, 4)
            }
            for doc, meta, dist in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0]
            )
            if 1 - dist > 0.3  # drop weakly related results; tune the threshold per embedding model
        ]

    def answer(self, question: str, n_results: int = 5,
               source_filter: Optional[str] = None) -> dict:
        """Full question answering: retrieve, then generate."""
        # 1. Retrieve relevant chunks
        retrieved = self.retrieve(question, n_results, source_filter)
        if not retrieved:
            return {
                "answer": "Sorry, I couldn't find relevant information in the knowledge base.",
                "sources": [],
                "retrieved_docs": []
            }

        # 2. Build the context
        context_parts = []
        for i, doc in enumerate(retrieved, 1):
            context_parts.append(
                f"[Reference {i}] (source: {doc['source']}, relevance: {doc['score']})\n"
                f"{doc['text']}"
            )
        context = "\n\n---\n\n".join(context_parts)

        # 3. Build the prompt
        system = """You are a professional question-answering assistant. Answer strictly based on the provided references.
Rules:
- Use only information from the references
- If the references contain no relevant information, tell the user explicitly
- Do not invent or infer anything that is not in the references
- Cite sources using the [Reference N] format
- If the information comes from several references, cite all of them"""

        user_message = f"""References:
{context}
---
Question: {question}"""

        # 4. Call the LLM
        response = self.llm.messages.create(
            model="claude-sonnet-4-6",
            max_tokens=2048,
            system=system,
            messages=[{"role": "user", "content": user_message}]
        )
        return {
            "answer": response.content[0].text,
            "sources": list(dict.fromkeys(doc["source"] for doc in retrieved)),  # de-duplicate, keep order
            "retrieved_docs": retrieved,
            "tokens_used": {
                "input": response.usage.input_tokens,
                "output": response.usage.output_tokens
            }
        }


# Usage example
rag = RAGSystem("company_docs")

# Import a document
with open("annual_report.txt", encoding="utf-8") as f:
    rag.ingest_document(f.read(), source="annual_report_2024", extra_metadata={"year": 2024})

# Ask a question
result = rag.answer("How did the company's revenue grow in 2024?")
print(result["answer"])
print(f"\nSources: {result['sources']}")
```
6. Advanced RAG techniques
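6.1 HyDE (Hypothetical Document Embeddings)
First have an LLM write a hypothetical answer, then search with that answer instead of the raw question. A passage shaped like an answer usually lies closer in embedding space to the real answer passages than a short question does: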
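```python
import anthropic

llm = anthropic.Anthropic()

def hyde_retrieve(rag: "RAGSystem", question: str, n_results: int = 5) -> list[dict]:
    """rag: a RAGSystem instance from section 5."""
    # Generate a hypothetical document
    hypothetical_doc = llm.messages.create(
        model="claude-haiku-4-5",  # a cheaper model is enough here
        max_tokens=300,
        messages=[{
            "role": "user",
            "content": f"Write a passage that might contain the answer to the following question (it need not be accurate, only relevant):\n{question}"
        }]
    ).content[0].text
    # Search with the hypothetical document instead of the question itself
    return rag.retrieve(hypothetical_doc, n_results)
```
6.2 Re-ranking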
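After retrieval, a dedicated re-ranking model reorders the candidates more precisely. A cross-encoder reads the query and each candidate together. That is more accurate than comparing separately computed embeddings, but too slow to run over the whole corpus, so it is applied only to the retrieved candidates: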
```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")

def rerank(query: str, candidates: list[dict], top_k: int = 3) -> list[dict]:
    """Re-score candidates with a cross-encoder and keep the best top_k."""
    if not candidates:
        return []
    pairs = [(query, doc["text"]) for doc in candidates]
    scores = reranker.predict(pairs)  # one relevance score per (query, document) pair
    # Sort by reranker score, highest first
    sorted_docs = sorted(
        zip(candidates, scores),
        key=lambda x: x[1], reverse=True
    )
    return [doc for doc, score in sorted_docs[:top_k]]
```
6.3 Hybrid search
Combine vector search (semantic) with BM25 keyword search (exact matching) so that each covers the other's weaknesses:
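BM25 scores and cosine similarities live on different scales, so the two result lists have to be fused somehow. One scale-free option is Reciprocal Rank Fusion (RRF). It uses only each document's rank $r$ in each list and adds $1/(k + r)$ per list; $k = 60$ is the constant commonly used. A minimal sketch, where `rrf_fuse` is a hypothetical helper and the document IDs are made up:

```python
from collections import defaultdict

def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Reciprocal Rank Fusion: merge several ranked ID lists into one ranking."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            # A document ranked high in any list gets a large contribution
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by ID for a deterministic order
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))

vector_ranking = ["d3", "d1", "d2"]
bm25_ranking = ["d1", "d4", "d3"]
print(rrf_fuse([vector_ranking, bm25_ranking]))
```

Here "d1" comes out on top because it ranks well in both lists. The HybridRetriever class instead combines normalized scores linearly with a weight `alpha`. RRF needs neither score normalization nor an `alpha` to tune.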
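```python
import numpy as np
from rank_bm25 import BM25Okapi

class HybridRetriever:
    def __init__(self, alpha: float = 0.5, tokenize=str.split):
        """alpha: weight of the vector score (0 = keywords only, 1 = vectors only).
        tokenize: str.split only works for space-delimited languages; for Chinese,
        pass a word segmenter such as jieba.lcut instead."""
        self.alpha = alpha
        self.tokenize = tokenize
        self.corpus = []
        self.bm25 = None

    def fit(self, texts: list[str]):
        self.corpus = texts
        tokenized = [self.tokenize(text) for text in texts]
        self.bm25 = BM25Okapi(tokenized)

    def retrieve(self, query: str, vector_scores: list[float], n: int = 5) -> list[int]:
        """vector_scores: one similarity per document, in the same order as the corpus given to fit()."""
        # BM25 scores, scaled to [0, 1] by the maximum
        bm25_scores = self.bm25.get_scores(self.tokenize(query))
        bm25_scores_normalized = bm25_scores / (bm25_scores.max() + 1e-6)
        # Vector scores (assumed to be cosine similarities, already in a comparable range)
        vector_scores_normalized = np.array(vector_scores)
        # Linear (weighted-sum) fusion; RRF is a rank-based alternative
        combined = self.alpha * vector_scores_normalized + (1 - self.alpha) * bm25_scores_normalized
        return np.argsort(combined)[::-1][:n].tolist()
```
6.4 Contextual Compression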
A retrieved chunk can be long while only part of it is relevant. Ask an LLM to compress it down to the relevant part:
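```python
from typing import Optional

import anthropic

llm = anthropic.Anthropic()

def compress_document(query: str, document: str) -> Optional[str]:
    """Extract the parts of a document relevant to the question; return None if nothing is."""
    response = llm.messages.create(
        model="claude-haiku-4-5",
        max_tokens=500,
        messages=[{
            "role": "user",
            "content": f"""Extract the sentences or paragraphs from the document below that are directly relevant to the question.
If the whole document is irrelevant, output "NOT_RELEVANT".
Output only the relevant parts and nothing else.
Question: {query}
Document: {document}"""
        }]
    ).content[0].text
    if "NOT_RELEVANT" in response:
        return None
    return response
```
7. RAG evaluation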
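```python
# RAGAS: a framework built specifically for evaluating RAG systems
# Its metrics use an LLM as the judge (OpenAI models by default, so OPENAI_API_KEY must be set)
from ragas import evaluate
from ragas.metrics import (
    faithfulness,       # Faithfulness: is the answer grounded in the retrieved documents?
    answer_relevancy,   # Answer relevancy: does the answer address the question?
    context_recall,     # Context recall: was all the needed information retrieved?
    context_precision,  # Context precision: is what was retrieved actually relevant?
)
from datasets import Dataset

# Prepare the evaluation data
# (column names follow the ragas 0.1.x schema; later versions rename them, so check your installed version)
eval_data = {
    "question": ["..."],            # questions
    "answer": ["..."],              # answers produced by the RAG system
    "contexts": [["...", "..."]],   # retrieved contexts for each question
    "ground_truth": ["..."],        # reference answers
}

result = evaluate(
    Dataset.from_dict(eval_data),
    metrics=[faithfulness, answer_relevancy, context_recall, context_precision]
)
print(result)
```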
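RAGAS scores answers with an LLM judge. The retrieval step alone can be measured with classic information-retrieval metrics, provided you have labeled which chunk IDs are relevant to each test question. These metrics are cheap and deterministic. A minimal sketch; the function names are hypothetical, and the chunk IDs would be the `ids` written at ingestion time.

```python
def recall_at_k(retrieved: list[str], relevant: set[str], k: int) -> float:
    """Fraction of the relevant chunks that appear in the top-k results."""
    if not relevant:
        return 0.0
    return len(set(retrieved[:k]) & relevant) / len(relevant)

def reciprocal_rank(retrieved: list[str], relevant: set[str]) -> float:
    """1 / rank of the first relevant result (0 if none was retrieved)."""
    for rank, doc_id in enumerate(retrieved, start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0

def evaluate_retrieval(runs: list[tuple[list[str], set[str]]], k: int = 5) -> dict[str, float]:
    """Average recall@k and MRR over (retrieved_ids, relevant_ids) pairs."""
    n = len(runs)
    return {
        f"recall@{k}": sum(recall_at_k(r, rel, k) for r, rel in runs) / n,
        "mrr": sum(reciprocal_rank(r, rel) for r, rel in runs) / n,
    }

runs = [
    (["c7", "c2", "c9"], {"c2"}),        # the relevant chunk is at rank 2
    (["c1", "c4", "c5"], {"c8", "c1"}),  # one of two relevant chunks found, at rank 1
]
print(evaluate_retrieval(runs, k=3))
```

A low recall@k points to chunking or embedding problems that no amount of prompt tuning in the generation step will fix.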