---
layout: default
title: "RAG"
parent: "Design Pattern"
nav_order: 4
---
# RAG (Retrieval Augmented Generation)
For certain LLM tasks like answering questions, providing context is essential.
The most common way to retrieve text-based context is through embeddings:

1. Given a set of texts, you first [chunk](../utility_function/chunking.md) them.
2. Next, you [embed](../utility_function/embedding.md) each chunk.
3. Then you store the chunk embeddings in a [vector database](../utility_function/vector.md).
4. Finally, given a query, you embed the query and find the closest chunks in the vector database.

<div align="center">
<img src="https://github.com/the-pocket/PocketFlow/raw/main/assets/rag.png?raw=true" width="250"/>
</div>
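
The four steps above can be sketched in plain Python. This is a toy built on stated assumptions: `toy_embed` is a hypothetical bag-of-words stand-in for a real embedding model, cosine similarity is used as the closeness measure, and a Python list stands in for the vector database.

```python
import math
import re
from collections import Counter


def chunk(text, size=100):
    # Step 1: split text into fixed-size character chunks.
    return [text[i:i + size] for i in range(0, len(text), size)]


def toy_embed(text):
    # Hypothetical stand-in for a real embedding model:
    # a sparse bag-of-words vector (word -> count).
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a, b):
    # Cosine similarity between two sparse vectors (missing words count as 0).
    dot = sum(count * b[word] for word, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


# Steps 1-3: chunk and embed each text, then "store" (chunk, vector) pairs.
# A plain list stands in for the vector database here.
texts = [
    "Cats purr and nap in the sun all day.",
    "Python lists support append and extend.",
]
store = [(c, toy_embed(c)) for t in texts for c in chunk(t)]


def retrieve(query):
    # Step 4: embed the query and return the stored chunk with the highest similarity.
    q = toy_embed(query)
    return max(store, key=lambda pair: cosine(q, pair[1]))[0]


print(retrieve("Why do cats nap in the sun?"))  # Cats purr and nap in the sun all day.
```

With a real embedding model the vectors are dense and a vector database replaces the linear scan in `retrieve`, but the chunk, embed, store, search sequence is the same.
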
One common way to organize a RAG system is a **two-stage** pipeline:

1. **Offline stage**: Preprocess and index documents ("building the index").
2. **Online stage**: Given a question, generate answers by retrieving the most relevant context from the index.
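
The flows on this page hand the index from one stage to the other through the in-memory `shared` dict. When the two stages run in separate processes, the offline stage typically writes its index somewhere the online stage can load it. A minimal sketch, assuming a JSON file is good enough for a tiny index (a real system would more likely use a vector database or an index file); `fake_embed`, `build_index`, and `load_index` are hypothetical names, not PocketFlow APIs.

```python
import json
import os
import tempfile


def fake_embed(text):
    # Hypothetical embedding: [length, number of vowels]. Replace with a real model.
    return [float(len(text)), float(sum(ch in "aeiou" for ch in text.lower()))]


def build_index(chunks, path):
    # Offline stage: embed every chunk once and persist chunks and vectors together,
    # so row i of "embeddings" always belongs to chunks[i].
    data = {"chunks": chunks, "embeddings": [fake_embed(c) for c in chunks]}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f)


def load_index(path):
    # Online stage: load the prebuilt index without re-embedding anything.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data["chunks"], data["embeddings"]


path = os.path.join(tempfile.mkdtemp(), "index.json")
build_index(["first chunk", "second chunk"], path)  # run once, offline
chunks, embeddings = load_index(path)               # run at query time, online
print(len(chunks), len(embeddings))  # 2 2
```
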

---

## Stage 1: Offline Indexing

We create three Nodes:

1. `ChunkDocs` [chunks](../utility_function/chunking.md) raw text.
2. `EmbedDocs` [embeds](../utility_function/embedding.md) each chunk.
3. `StoreIndex` stores embeddings into a [vector database](../utility_function/vector.md).

```python
from pocketflow import Node, BatchNode, Flow

# get_embedding() and create_index() are utility functions you provide
# (see the embedding and vector database utility pages).


class ChunkDocs(BatchNode):
    def prep(self, shared):
        # A list of file paths in shared["files"]. We process each file.
        return shared["files"]

    def exec(self, filepath):
        # Read the file content. In real usage, add error handling.
        with open(filepath, "r", encoding="utf-8") as f:
            text = f.read()
        # Split into fixed-size chunks of 100 characters each.
        chunks = []
        size = 100
        for i in range(0, len(text), size):
            chunks.append(text[i : i + size])
        return chunks

    def post(self, shared, prep_res, exec_res_list):
        # exec_res_list is a list of chunk lists, one per file.
        # Flatten them into a single list of chunks.
        all_chunks = []
        for chunk_list in exec_res_list:
            all_chunks.extend(chunk_list)
        shared["all_chunks"] = all_chunks


class EmbedDocs(BatchNode):
    def prep(self, shared):
        return shared["all_chunks"]

    def exec(self, chunk):
        return get_embedding(chunk)

    def post(self, shared, prep_res, exec_res_list):
        # Store the list of embeddings (same order as shared["all_chunks"]).
        shared["all_embeds"] = exec_res_list
        print(f"Total embeddings: {len(exec_res_list)}")


class StoreIndex(Node):
    def prep(self, shared):
        # Read all embeddings from shared.
        return shared["all_embeds"]

    def exec(self, all_embeds):
        # Create a vector index (faiss or another vector DB in real usage).
        index = create_index(all_embeds)
        return index

    def post(self, shared, prep_res, index):
        shared["index"] = index


# Wire them in sequence
chunk_node = ChunkDocs()
embed_node = EmbedDocs()
store_node = StoreIndex()

chunk_node >> embed_node >> store_node
OfflineFlow = Flow(start=chunk_node)
```

Usage example:

```python
shared = {
    "files": ["doc1.txt", "doc2.txt"],  # any text files
}
OfflineFlow.run(shared)
```
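
`OfflineFlow` relies on `get_embedding` and `create_index`, and the online stage below also uses `search_index` and `call_llm`. This page leaves them to you (see the utility function pages). To smoke-test the flows without an API key or a vector database, you could define toy stand-ins like these. All four are hypothetical: the embedding buckets characters by code point, the "index" is a plain list searched by brute force, and `call_llm` returns a canned string. `search_index` returns `(ids, distances)` as nested lists with one row per query, matching how `RetrieveDocs` reads `I[0][0]`.

```python
import math


def get_embedding(text, dim=8):
    # Hypothetical toy embedding: bucket characters by code point, then L2-normalize.
    vec = [0.0] * dim
    for ch in text.lower():
        vec[ord(ch) % dim] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def create_index(embeddings):
    # Hypothetical "index": keep the vectors in a list for brute-force search.
    return list(embeddings)


def search_index(index, query_emb, top_k=1):
    # Brute-force nearest neighbours by squared Euclidean distance.
    # Returns (ids, distances) shaped [1][top_k], so ids[0][0] is the best match.
    dists = [sum((a - b) ** 2 for a, b in zip(vec, query_emb)) for vec in index]
    order = sorted(range(len(index)), key=lambda i: dists[i])[:top_k]
    return [order], [[dists[i] for i in order]]


def call_llm(prompt):
    # Hypothetical LLM stand-in: returns a canned string instead of calling a model.
    return f"(stub answer for a {len(prompt)}-character prompt)"
```

Define them in the same module as the nodes before calling `OfflineFlow.run(shared)`, and swap in real implementations for actual use.
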

---

## Stage 2: Online Query & Answer

We have three Nodes:

1. `EmbedQuery` embeds the user's question.
2. `RetrieveDocs` retrieves the most relevant chunk from the index.
3. `GenerateAnswer` calls the LLM with the question and the retrieved chunk to produce the final answer.

```python
from pocketflow import Node, Flow

# get_embedding(), search_index() and call_llm() are utility functions you provide.
# shared["index"] and shared["all_chunks"] come from OfflineFlow.


class EmbedQuery(Node):
    def prep(self, shared):
        return shared["question"]

    def exec(self, question):
        return get_embedding(question)

    def post(self, shared, prep_res, q_emb):
        shared["q_emb"] = q_emb


class RetrieveDocs(Node):
    def prep(self, shared):
        # We need the query embedding, plus the offline index and chunks.
        return shared["q_emb"], shared["index"], shared["all_chunks"]

    def exec(self, inputs):
        q_emb, index, chunks = inputs
        # search_index is assumed to return (ids, distances), one row per query.
        I, D = search_index(index, q_emb, top_k=1)
        best_id = I[0][0]
        relevant_chunk = chunks[best_id]
        return relevant_chunk

    def post(self, shared, prep_res, relevant_chunk):
        shared["retrieved_chunk"] = relevant_chunk
        print("Retrieved chunk:", relevant_chunk[:60], "...")


class GenerateAnswer(Node):
    def prep(self, shared):
        return shared["question"], shared["retrieved_chunk"]

    def exec(self, inputs):
        question, chunk = inputs
        prompt = f"Question: {question}\nContext: {chunk}\nAnswer:"
        return call_llm(prompt)

    def post(self, shared, prep_res, answer):
        shared["answer"] = answer
        print("Answer:", answer)


# Wire them in sequence
embed_qnode = EmbedQuery()
retrieve_node = RetrieveDocs()
generate_node = GenerateAnswer()

embed_qnode >> retrieve_node >> generate_node
OnlineFlow = Flow(start=embed_qnode)
```

Usage example:
```python
# Suppose we already ran OfflineFlow and have:
# shared["all_chunks"], shared["index"], etc.
shared["question"] = "Why do people like cats?"
OnlineFlow.run(shared)
# final answer in shared["answer"]
```
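
`RetrieveDocs` keeps only the single closest chunk (`top_k=1`). If one chunk is too little context, one option is to take the top few ids and join their chunks into the prompt. A sketch, assuming `search_index` returns `(ids, distances)` with one row per query, as `RetrieveDocs` unpacks it; `chunks_for_ids`, `build_prompt`, and the sample values are hypothetical.

```python
def chunks_for_ids(ids, chunks):
    # ids is shaped [num_queries][top_k]; we issued one query, so use row 0.
    return [chunks[i] for i in ids[0]]


def build_prompt(question, context_chunks):
    # Same layout as GenerateAnswer, but with several chunks separated by blank lines.
    context = "\n\n".join(context_chunks)
    return f"Question: {question}\nContext: {context}\nAnswer:"


# Hypothetical search result for top_k=2: chunk 2 is closest, then chunk 0.
ids = [[2, 0]]
chunks = ["Cats are independent.", "Dogs are loyal.", "Cats purr when content."]
prompt = build_prompt("Why do people like cats?", chunks_for_ids(ids, chunks))
print(prompt)
```

Inside the node, this would mean calling `search_index(index, q_emb, top_k=k)` and storing a list of chunks in `shared` instead of a single one.
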