---
layout: default
title: "RAG"
parent: "Design Pattern"
nav_order: 3
---

# RAG (Retrieval Augmented Generation)

For certain LLM tasks like answering questions, providing relevant context is essential. One common architecture is a **two-stage** RAG pipeline:

<div align="center">
  <img src="https://github.com/the-pocket/PocketFlow/raw/main/assets/rag.png?raw=true" width="400"/>
</div>

1. **Offline stage**: Preprocess and index documents ("building the index").
2. **Online stage**: Given a question, generate answers by retrieving the most relevant context.

---

## Stage 1: Offline Indexing

We create three Nodes:

1. `ChunkDocs` – [chunks](../utility_function/chunking.md) raw text.
2. `EmbedDocs` – [embeds](../utility_function/embedding.md) each chunk.
3. `StoreIndex` – stores embeddings into a [vector database](../utility_function/vector.md).

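`EmbedDocs` (and later `EmbedQuery`) call a `get_embedding` helper that this page leaves to the embedding utility. To try the flow without an embedding API, a hypothetical stand-in such as the one below could be used. It is a toy hashed bag-of-words vector, so it only captures exact word overlap, not meaning; the name matches the helper used in the nodes, but the implementation and the dimension (`DIM = 256`) are assumptions for illustration.

```python
import hashlib
import math
import re

DIM = 256  # hypothetical embedding size for this toy stand-in


def get_embedding(text):
    """Hypothetical stand-in for a real embedding API (not semantic).

    Hashes each lowercase word into one of DIM buckets, counts occurrences,
    and L2-normalizes the result so a dot product acts as cosine similarity.
    """
    vec = [0.0] * DIM
    for word in re.findall(r"[a-z0-9]+", text.lower()):
        # md5 gives a bucket that is stable across runs (unlike built-in hash()).
        bucket = int(hashlib.md5(word.encode("utf-8")).hexdigest(), 16) % DIM
        vec[bucket] += 1.0
    norm = math.sqrt(sum(x * x for x in vec))
    # Empty or word-free text yields an all-zero vector.
    return [x / norm for x in vec] if norm else vec


emb = get_embedding("Cats like naps. Cats like fish.")
print(len(emb))  # 256
```

Swapping in a real embedding model later only requires keeping the same signature: text in, list of floats out.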
```python
from pocketflow import Node, BatchNode, Flow

# get_embedding() and create_index() are utility functions you supply;
# see the embedding and vector database utility pages linked above.


class ChunkDocs(BatchNode):
    def prep(self, shared):
        # A list of file paths in shared["files"]. We process each file.
        return shared["files"]

    def exec(self, filepath):
        # Read the file content. In real usage, add error handling.
        with open(filepath, "r", encoding="utf-8") as f:
            text = f.read()
        # Split into fixed-size chunks of 100 characters each.
        chunks = []
        size = 100
        for i in range(0, len(text), size):
            chunks.append(text[i : i + size])
        return chunks

    def post(self, shared, prep_res, exec_res_list):
        # exec_res_list is a list of chunk-lists, one per file.
        # Flatten them all into a single list of chunks.
        all_chunks = []
        for chunk_list in exec_res_list:
            all_chunks.extend(chunk_list)
        shared["all_chunks"] = all_chunks


class EmbedDocs(BatchNode):
    def prep(self, shared):
        return shared["all_chunks"]

    def exec(self, chunk):
        return get_embedding(chunk)

    def post(self, shared, prep_res, exec_res_list):
        # Store the list of embeddings, in the same order as shared["all_chunks"].
        shared["all_embeds"] = exec_res_list
        print(f"Total embeddings: {len(exec_res_list)}")


class StoreIndex(Node):
    def prep(self, shared):
        # We'll read all embeds from shared.
        return shared["all_embeds"]

    def exec(self, all_embeds):
        # Create a vector index (faiss or other DB in real usage).
        index = create_index(all_embeds)
        return index

    def post(self, shared, prep_res, index):
        shared["index"] = index


# Wire them in sequence
chunk_node = ChunkDocs()
embed_node = EmbedDocs()
store_node = StoreIndex()

chunk_node >> embed_node >> store_node

OfflineFlow = Flow(start=chunk_node)
```

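Conceptually, a `BatchNode` runs `exec` once for each item that `prep` returns and passes the list of results to `post`. The plain-Python sketch below mirrors what `ChunkDocs` computes, with one simplifying assumption: it takes in-memory strings instead of file paths, so no files are needed. The function names `chunk_text` and `chunk_docs` are hypothetical.

```python
def chunk_text(text, size=100):
    # Same fixed-size slicing as ChunkDocs.exec; the last chunk may be shorter.
    return [text[i : i + size] for i in range(0, len(text), size)]


def chunk_docs(texts, size=100):
    # BatchNode-style: one exec call per item returned by prep ...
    exec_res_list = [chunk_text(t, size) for t in texts]
    # ... then post flattens the per-document lists into one list.
    all_chunks = []
    for chunk_list in exec_res_list:
        all_chunks.extend(chunk_list)
    return all_chunks


docs = ["a" * 250, "b" * 100]
chunks = chunk_docs(docs)
print([len(c) for c in chunks])  # [100, 100, 50, 100]
```

Because the slicing is purely by character count, a chunk boundary can fall in the middle of a word or sentence.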
Usage example:

```python
shared = {
    "files": ["doc1.txt", "doc2.txt"],  # any text files
}
OfflineFlow.run(shared)
# shared now also holds "all_chunks", "all_embeds", and "index".
```

---

## Stage 2: Online Query & Answer

We create three Nodes:

1. `EmbedQuery` – embeds the user’s question.
2. `RetrieveDocs` – retrieves the most similar chunk from the index.
3. `GenerateAnswer` – calls the LLM with the question and the retrieved chunk to produce the final answer.

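`StoreIndex` and `RetrieveDocs` rely on `create_index` and `search_index`, which in real usage wrap a vector database such as faiss. As a hypothetical in-memory stand-in, the sketch below does a brute-force dot-product search. It assumes the embeddings are already L2-normalized (so the dot product equals cosine similarity) and returns `(indices, scores)` as nested lists with one row per query, matching how `RetrieveDocs` unpacks `I, D` and reads `I[0][0]`. Unlike a distance-based index, the second value here is a similarity score, where higher means closer.

```python
def create_index(embeddings):
    # Hypothetical stand-in for a vector DB: keep vectors in insertion order,
    # so position i refers to the i-th chunk.
    return [list(e) for e in embeddings]


def search_index(index, query_embedding, top_k=1):
    # Brute-force search: score every stored vector by dot product
    # (cosine similarity when all vectors are L2-normalized).
    scores = [sum(a * b for a, b in zip(vec, query_embedding)) for vec in index]
    # Highest score first; Python's sort is stable, so ties keep index order.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top = order[:top_k]
    # One row per query, mirroring the I, D unpacking in RetrieveDocs.
    return [top], [[scores[i] for i in top]]


index = create_index([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])
I, D = search_index(index, [0.0, 1.0], top_k=2)
print(I, D)  # [[2, 1]] [[1.0, 0.8]]
```

Brute force is linear in the number of chunks, which is fine for experiments; a real vector database exists to avoid that cost at scale.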
```python
from pocketflow import Node, Flow

# get_embedding(), search_index() and call_llm() are utility functions you supply.


class EmbedQuery(Node):
    def prep(self, shared):
        return shared["question"]

    def exec(self, question):
        return get_embedding(question)

    def post(self, shared, prep_res, q_emb):
        shared["q_emb"] = q_emb


class RetrieveDocs(Node):
    def prep(self, shared):
        # We'll need the query embedding, plus the offline index/chunks
        return shared["q_emb"], shared["index"], shared["all_chunks"]

    def exec(self, inputs):
        q_emb, index, chunks = inputs
        # search_index is assumed to return (indices, distances),
        # each with one row per query.
        I, D = search_index(index, q_emb, top_k=1)
        best_id = I[0][0]
        # Assumes the index assigns IDs in insertion order, so IDs line up
        # with shared["all_chunks"].
        relevant_chunk = chunks[best_id]
        return relevant_chunk

    def post(self, shared, prep_res, relevant_chunk):
        shared["retrieved_chunk"] = relevant_chunk
        print("Retrieved chunk:", relevant_chunk[:60], "...")


class GenerateAnswer(Node):
    def prep(self, shared):
        return shared["question"], shared["retrieved_chunk"]

    def exec(self, inputs):
        question, chunk = inputs
        prompt = f"Question: {question}\nContext: {chunk}\nAnswer:"
        return call_llm(prompt)

    def post(self, shared, prep_res, answer):
        shared["answer"] = answer
        print("Answer:", answer)


embed_qnode = EmbedQuery()
retrieve_node = RetrieveDocs()
generate_node = GenerateAnswer()

embed_qnode >> retrieve_node >> generate_node
OnlineFlow = Flow(start=embed_qnode)
```

Usage example:

```python
# Suppose we already ran OfflineFlow and have:
# shared["all_chunks"], shared["index"], etc.
shared["question"] = "Why do people like cats?"

OnlineFlow.run(shared)
# final answer in shared["answer"]
```
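To see both stages end to end without PocketFlow, external APIs, or files, the sketch below chains the same steps as plain functions: chunk, embed, and index offline, then embed the question, retrieve the best chunk, and build the `GenerateAnswer` prompt online. Everything here is a hypothetical toy: `toy_embedding` is a sparse bag-of-words rather than a semantic embedding, the "index" is a plain list searched by brute force, and the LLM call is left out so the prompt can be inspected directly.

```python
import math
import re
from collections import Counter


def toy_embedding(text):
    # Hypothetical stand-in: sparse bag-of-words {word: weight}, L2-normalized.
    counts = Counter(re.findall(r"[a-z0-9]+", text.lower()))
    norm = math.sqrt(sum(c * c for c in counts.values()))
    return {w: c / norm for w, c in counts.items()} if norm else {}


def similarity(a, b):
    # Dot product of two normalized sparse vectors = cosine similarity.
    return sum(weight * b.get(word, 0.0) for word, weight in a.items())


def offline_stage(texts, size=100):
    # ChunkDocs -> EmbedDocs -> StoreIndex, as plain function calls.
    chunks = [t[i : i + size] for t in texts for i in range(0, len(t), size)]
    index = [toy_embedding(c) for c in chunks]  # index[i] belongs to chunks[i]
    return {"all_chunks": chunks, "index": index}


def online_stage(shared, question):
    # EmbedQuery -> RetrieveDocs (top_k=1) -> prompt that GenerateAnswer would send.
    q_emb = toy_embedding(question)
    scores = [similarity(q_emb, v) for v in shared["index"]]
    best_id = max(range(len(scores)), key=lambda i: scores[i])
    chunk = shared["all_chunks"][best_id]
    prompt = f"Question: {question}\nContext: {chunk}\nAnswer:"
    return chunk, prompt


shared = offline_stage([
    "Cats are popular pets because people like cats for their calm company.",
    "Rust is a systems programming language focused on memory safety.",
])
chunk, prompt = online_stage(shared, "Why do people like cats?")
print(chunk)  # Cats are popular pets because people like cats for their calm company.
```

In the PocketFlow version, each of these function bodies lives in a node's `exec`, and `shared` carries the intermediate results between nodes instead of return values.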