| layout | title | parent | nav_order |
|---|---|---|---|
| default | RAG | Design Pattern | 4 |
RAG (Retrieval Augmented Generation)
For certain LLM tasks like answering questions, providing context is essential. The most common way to retrieve text-based context is through embeddings:
- Given texts, you first chunk them.
- Next, you embed each chunk.
- Then you store the chunks and their embeddings in a vector database.
- Finally, given a query, you embed the query and find the closest chunk in the vector database.
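The four steps above can be sketched end to end in plain Python. This is only an illustration under loud assumptions: `embed` is a toy word-count vector standing in for a real embedding model, the "vector database" is an ordinary list, and all function names here are hypothetical rather than part of any framework.

```python
import math
from collections import Counter


def chunk_text(text, size=40):
    """Step 1: split text into fixed-size character chunks."""
    return [text[i:i + size] for i in range(0, len(text), size)]


def embed(text):
    """Step 2 (toy): lowercase word counts, a stand-in for a real embedding model."""
    return Counter(text.lower().split())


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def build_store(texts, size=40):
    """Step 3: chunk every text, embed each chunk, keep (chunk, vector) pairs."""
    store = []
    for text in texts:
        for chunk in chunk_text(text, size):
            store.append((chunk, embed(chunk)))
    return store


def closest_chunk(store, query):
    """Step 4: embed the query and return the most similar stored chunk."""
    q_vec = embed(query)
    return max(store, key=lambda item: cosine(q_vec, item[1]))[0]


store = build_store(["cats purr when happy", "dogs bark at strangers"])
print(closest_chunk(store, "why do cats purr"))  # prints: cats purr when happy
```

A real setup would replace `embed` with a learned embedding model and the list with a proper vector index, but the data flow stays the same.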
One common architecture for supplying relevant context is a two-stage RAG pipeline:
- Offline stage: Preprocess and index documents ("building the index").
- Online stage: Given a question, generate answers by retrieving the most relevant context from the index.
Stage 1: Offline Indexing
We create three Nodes:
- ChunkDocs – chunks raw text.
- EmbedDocs – embeds each chunk.
- StoreIndex – stores embeddings into a vector database.
class ChunkDocs(BatchNode):
    def prep(self, shared):
        # A list of file paths in shared["files"]. We process each file.
        return shared["files"]

    def exec(self, filepath):
        # Read the file content. In real usage, do error handling.
        with open(filepath, "r", encoding="utf-8") as f:
            text = f.read()
        # Chunk by 100 chars each.
        chunks = []
        size = 100
        for i in range(0, len(text), size):
            chunks.append(text[i : i + size])
        return chunks

    def post(self, shared, prep_res, exec_res_list):
        # exec_res_list is a list of chunk-lists, one per file.
        # Flatten them all into a single list of chunks.
        all_chunks = []
        for chunk_list in exec_res_list:
            all_chunks.extend(chunk_list)
        shared["all_chunks"] = all_chunks

class EmbedDocs(BatchNode):
    def prep(self, shared):
        return shared["all_chunks"]

    def exec(self, chunk):
        return get_embedding(chunk)

    def post(self, shared, prep_res, exec_res_list):
        # Store the list of embeddings.
        shared["all_embeds"] = exec_res_list
        print(f"Total embeddings: {len(exec_res_list)}")

class StoreIndex(Node):
    def prep(self, shared):
        # We'll read all embeds from shared.
        return shared["all_embeds"]

    def exec(self, all_embeds):
        # Create a vector index (faiss or other DB in real usage).
        index = create_index(all_embeds)
        return index

    def post(self, shared, prep_res, index):
        shared["index"] = index

# Wire them in sequence
chunk_node = ChunkDocs()
embed_node = EmbedDocs()
store_node = StoreIndex()
chunk_node >> embed_node >> store_node
OfflineFlow = Flow(start=chunk_node)
Usage example:
shared = {
    "files": ["doc1.txt", "doc2.txt"],  # any text files
}
OfflineFlow.run(shared)
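The ChunkDocs node above cuts text every 100 characters, so a sentence or word that straddles a boundary ends up split across two chunks. One possible mitigation, sketched here as an assumption rather than as part of the original pipeline, is to let consecutive chunks overlap. The helper name `chunk_with_overlap` and its parameters are hypothetical.

```python
def chunk_with_overlap(text, size=100, overlap=20):
    """Split text into chunks of up to `size` chars that share `overlap` chars.

    Hypothetical helper: each chunk starts `size - overlap` characters after
    the previous one, so text near a boundary appears in both neighbors.
    """
    if size <= 0:
        raise ValueError("size must be positive")
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")

    step = size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + size])
        # Stop once this chunk already reaches the end of the text,
        # otherwise the tail would be emitted again as a shorter chunk.
        if start + size >= len(text):
            break
    return chunks


print(chunk_with_overlap("abcdefghij", size=4, overlap=2))
# prints: ['abcd', 'cdef', 'efgh', 'ghij']
```

With `overlap=0` this reduces to the same fixed-size chunking used in ChunkDocs. Larger overlaps keep more boundary context at the cost of more chunks to embed and store.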
Stage 2: Online Query & Answer
We have three Nodes:
- EmbedQuery – embeds the user’s question.
- RetrieveDocs – retrieves the top chunk from the index.
- GenerateAnswer – calls the LLM with the question + chunk to produce the final answer.
class EmbedQuery(Node):
    def prep(self, shared):
        return shared["question"]

    def exec(self, question):
        return get_embedding(question)

    def post(self, shared, prep_res, q_emb):
        shared["q_emb"] = q_emb

class RetrieveDocs(Node):
    def prep(self, shared):
        # We'll need the query embedding, plus the offline index/chunks
        return shared["q_emb"], shared["index"], shared["all_chunks"]

    def exec(self, inputs):
        q_emb, index, chunks = inputs
        I, D = search_index(index, q_emb, top_k=1)
        best_id = I[0][0]
        relevant_chunk = chunks[best_id]
        return relevant_chunk

    def post(self, shared, prep_res, relevant_chunk):
        shared["retrieved_chunk"] = relevant_chunk
        print("Retrieved chunk:", relevant_chunk[:60], "...")

class GenerateAnswer(Node):
    def prep(self, shared):
        return shared["question"], shared["retrieved_chunk"]

    def exec(self, inputs):
        question, chunk = inputs
        prompt = f"Question: {question}\nContext: {chunk}\nAnswer:"
        return call_llm(prompt)

    def post(self, shared, prep_res, answer):
        shared["answer"] = answer
        print("Answer:", answer)

embed_qnode = EmbedQuery()
retrieve_node = RetrieveDocs()
generate_node = GenerateAnswer()
embed_qnode >> retrieve_node >> generate_node
OnlineFlow = Flow(start=embed_qnode)
Usage example:
# Suppose we already ran OfflineFlow and have:
# shared["all_chunks"], shared["index"], etc.
shared["question"] = "Why do people like cats?"
OnlineFlow.run(shared)
# final answer in shared["answer"]
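The helpers `create_index` and `search_index` are used above but not defined in this section. To make the data shapes concrete, here is a minimal brute-force sketch of what they might look like. Everything in it is an assumption: the real helpers may wrap a vector library instead, and the return order (ids first, distances second, one row per query) is inferred only from the `I, D = search_index(...)` unpacking and the `I[0][0]` lookup in RetrieveDocs.

```python
import math


def create_index(all_embeds):
    """Hypothetical stand-in: the 'index' is just a copied list of vectors."""
    return [list(vec) for vec in all_embeds]


def search_index(index, q_emb, top_k=1):
    """Brute-force nearest-neighbor search by Euclidean distance.

    Returns (I, D), each holding a single row for the single query:
    I[0] lists the ids of the top_k closest vectors, nearest first,
    and D[0] lists the matching distances.
    """
    scored = sorted(
        (math.dist(vec, q_emb), i) for i, vec in enumerate(index)
    )[:top_k]
    I = [[i for _, i in scored]]
    D = [[d for d, _ in scored]]
    return I, D


index = create_index([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
I, D = search_index(index, [0.9, 1.2], top_k=2)
print(I)  # prints: [[1, 0]]
```

Because the ids in `I` are positions in the embedding list, they line up with `shared["all_chunks"]` only as long as chunks and embeddings are kept in the same order, which is what the offline flow above does.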