update rag tutorial
parent 337dcc0d66 · commit 48630239b7

README.md
@@ -1,98 +1,66 @@
# Retrieval Augmented Generation (RAG)
This project demonstrates a simplified RAG system that retrieves relevant documents based on user queries and generates answers using an LLM.
## Features

- Document chunking for processing long texts
- FAISS-powered vector-based document retrieval
- LLM-powered answer generation
## How to Run

1. Set your API key:

```bash
export OPENAI_API_KEY="your-api-key-here"
```

Or update it directly in `utils.py`.

2. Install and run:

```bash
pip install -r requirements.txt
python main.py
```

To run with your own question instead of the default query, pass it after `--`:

```bash
python main.py --"How to install PocketFlow?"
```
By default, `get_embedding` in `utils.py` calls the OpenAI embeddings API, so an API key is required. A simple character-frequency embedding (`get_simple_embedding`) is also included if you want to experiment without one.
## How It Works

The magic happens through a two-phase pipeline implemented with PocketFlow:
```mermaid
graph TD
    subgraph OfflineFlow[Offline Document Indexing]
        ChunkDocs[ChunkDocumentsNode] --> EmbedDocs[EmbedDocumentsNode] --> CreateIndex[CreateIndexNode]
    end

    subgraph OnlineFlow[Online Processing]
        EmbedQuery[EmbedQueryNode] --> RetrieveDoc[RetrieveDocumentNode] --> GenerateAnswer[GenerateAnswerNode]
    end
```
Here's what each part does:

1. **ChunkDocumentsNode**: Breaks documents into smaller chunks for better retrieval
2. **EmbedDocumentsNode**: Converts document chunks into vector representations
3. **CreateIndexNode**: Creates a searchable FAISS index from the embeddings (see the sketch below)
4. **EmbedQueryNode**: Converts the user query into the same vector space
5. **RetrieveDocumentNode**: Finds the most similar document chunk using vector search
6. **GenerateAnswerNode**: Uses an LLM to generate an answer based on the retrieved content
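In case the FAISS steps are unfamiliar, here is roughly what `CreateIndexNode` and `RetrieveDocumentNode` do internally. This is a minimal sketch with placeholder vectors; the real nodes read their embeddings from the shared store, and the exact index type is an assumption:

```python
import faiss
import numpy as np

# Offline: build the index (what CreateIndexNode does)
embeddings = np.random.rand(5, 128).astype("float32")  # 5 chunk vectors (placeholder)
index = faiss.IndexFlatL2(embeddings.shape[1])         # exact L2-distance index (assumed type)
index.add(embeddings)

# Online: retrieve the closest chunk (what RetrieveDocumentNode does)
query_embedding = np.random.rand(1, 128).astype("float32")
distances, indices = index.search(query_embedding, k=1)
print(f"index: {indices[0][0]}, distance: {distances[0][0]:.4f}")
```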
## Example Output
```
==================================================
PocketFlow RAG Document Retrieval
==================================================
✅ Created 5 chunks from 5 documents
✅ Created 5 document embeddings
🔍 Creating search index...
✅ Index created with 5 vectors
🔍 Embedding query: How to install PocketFlow?
🔎 Searching for relevant documents...
📄 Retrieved document (index: 0, distance: 0.3427)
📄 Most relevant text: "Pocket Flow is a 100-line minimalist LLM framework
Lightweight: Just 100 lines. Zero bloat, zero dependencies, zero vendor lock-in.
Expressive: Everything you love—(Multi-)Agents, Workflow, RAG, and more.
Agentic Coding: Let AI Agents (e.g., Cursor AI) build Agents—10x productivity boost!
To install, pip install pocketflow or just copy the source code (only 100 lines)."

🤖 Generated Answer:
To install PocketFlow, use the command `pip install pocketflow` or simply copy its 100 lines of source code.
```
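Here the query about installing PocketFlow correctly retrieves the PocketFlow chunk (index 0). The reported distance is FAISS's similarity metric, so lower values mean a closer match (assuming the default L2 metric).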
## Files

- [`main.py`](./main.py): Main entry point for running the RAG demonstration
- [`flow.py`](./flow.py): Configures the flows that connect the nodes
- [`nodes.py`](./nodes.py): Defines the nodes for document processing, retrieval, and answer generation
- [`utils.py`](./utils.py): Utility functions including chunking, embedding, and LLM calls
flow.py
@@ -1,20 +1,27 @@
```python
from pocketflow import Flow
from nodes import EmbedDocumentsNode, CreateIndexNode, EmbedQueryNode, RetrieveDocumentNode, ChunkDocumentsNode, GenerateAnswerNode


def get_offline_flow():
    # Create offline flow for document indexing
    chunk_docs_node = ChunkDocumentsNode()
    embed_docs_node = EmbedDocumentsNode()
    create_index_node = CreateIndexNode()

    # Connect the nodes
    chunk_docs_node >> embed_docs_node >> create_index_node

    offline_flow = Flow(start=chunk_docs_node)
    return offline_flow


def get_online_flow():
    # Create online flow for document retrieval and answer generation
    embed_query_node = EmbedQueryNode()
    retrieve_doc_node = RetrieveDocumentNode()
    generate_answer_node = GenerateAnswerNode()

    # Connect the nodes
    embed_query_node >> retrieve_doc_node >> generate_answer_node

    online_flow = Flow(start=embed_query_node)
    return online_flow
```
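To see the two flows cooperate end to end, something like the following works. This is a usage sketch: the shared-store key the offline nodes expect for documents (`"texts"` here) is an assumption, since it sits outside the diff hunks; `main.py` below shows the real setup:

```python
from flow import get_offline_flow, get_online_flow

# Hypothetical minimal shared store; main.py defines the full set of keys
shared = {
    "texts": ["PocketFlow is a 100-line LLM framework."],  # assumed key name
    "query": "What is PocketFlow?",
}

get_offline_flow().run(shared)  # chunk, embed, and index the documents
get_online_flow().run(shared)   # embed the query, retrieve, and generate an answer
print(shared["generated_answer"])
```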
main.py
@@ -9,23 +9,53 @@ def run_rag_demo():
```python
    1. Indexes a set of sample documents (offline flow)
    2. Takes a query from the command line
    3. Retrieves the most relevant document (online flow)
    4. Generates an answer using an LLM
    """

    # Sample texts - specialized/fictional content that benefits from RAG
    texts = [
        # PocketFlow framework
        """Pocket Flow is a 100-line minimalist LLM framework
Lightweight: Just 100 lines. Zero bloat, zero dependencies, zero vendor lock-in.
Expressive: Everything you love—(Multi-)Agents, Workflow, RAG, and more.
Agentic Coding: Let AI Agents (e.g., Cursor AI) build Agents—10x productivity boost!
To install, pip install pocketflow or just copy the source code (only 100 lines).""",

        # Fictional medical device
        """NeurAlign M7 is a revolutionary non-invasive neural alignment device.
Targeted magnetic resonance technology increases neuroplasticity in specific brain regions.
Clinical trials showed 72% improvement in PTSD treatment outcomes.
Developed by Cortex Medical in 2024 as an adjunct to standard cognitive therapy.
Portable design allows for in-home use with remote practitioner monitoring.""",

        # Made-up historical event
        """The Velvet Revolution of Caldonia (1967-1968) ended Generalissimo Verak's 40-year rule.
Led by poet Eliza Markovian through underground literary societies.
Culminated in the Great Silence Protest with 300,000 silent protesters.
First democratic elections held in March 1968 with 94% voter turnout.
Became a model for non-violent political transitions in neighboring regions.""",

        # Fictional technology
        """Q-Mesh is QuantumLeap Technologies' instantaneous data synchronization protocol.
Utilizes directed acyclic graph consensus for 500,000 transactions per second.
Consumes 95% less energy than traditional blockchain systems.
Adopted by three central banks for secure financial data transfer.
Released in February 2024 after five years of development in stealth mode.""",

        # Made-up scientific research
        """Harlow Institute's Mycelium Strain HI-271 removes 99.7% of PFAS from contaminated soil.
Engineered fungi create symbiotic relationships with native soil bacteria.
Breaks down "forever chemicals" into non-toxic compounds within 60 days.
Field tests successfully remediated previously permanently contaminated industrial sites.
Deployment costs 80% less than traditional chemical extraction methods."""
    ]

    print("=" * 50)
    print("PocketFlow RAG Document Retrieval")
    print("=" * 50)

    # Default query
    default_query = "How to install PocketFlow?"

    # Get query from command line if provided with --
    query = default_query
```
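The corpus is deliberately fictional or highly specific: an LLM cannot answer questions about NeurAlign M7 or Q-Mesh from its training data, so any correct answer has to come from retrieval. This is exactly the situation where RAG pays off.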
@@ -41,13 +71,14 @@ def run_rag_demo():
"index": None,
|
"index": None,
|
||||||
"query": query,
|
"query": query,
|
||||||
"query_embedding": None,
|
"query_embedding": None,
|
||||||
"retrieved_document": None
|
"retrieved_document": None,
|
||||||
|
"generated_answer": None
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize and run the offline flow (document indexing)
|
# Initialize and run the offline flow (document indexing)
|
||||||
offline_flow.run(shared)
|
offline_flow.run(shared)
|
||||||
|
|
||||||
# Run the online flow to retrieve the most relevant document
|
# Run the online flow to retrieve the most relevant document and generate an answer
|
||||||
online_flow.run(shared)
|
online_flow.run(shared)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
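Note the hand-off between the two flows: `offline_flow` writes the FAISS index into the shared store, and `online_flow` reads it back during retrieval. After both runs the results can be inspected directly (a sketch):

```python
# After both flows have run:
print(shared["retrieved_document"]["text"])  # the chunk FAISS returned
print(shared["generated_answer"])            # the LLM's answer
```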
nodes.py
@@ -1,7 +1,7 @@
```python
from pocketflow import Node, Flow, BatchNode
import numpy as np
import faiss
from utils import call_llm, get_embedding, get_simple_embedding, fixed_size_chunk


# Nodes for the offline flow
class ChunkDocumentsNode(BatchNode):
```
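The body of `ChunkDocumentsNode` lies outside this hunk. Under PocketFlow's usual `BatchNode` pattern (`prep` returns a list, `exec` runs once per item, `post` collects the results), it plausibly looks like the sketch below; the shared-store key name is an assumption:

```python
class ChunkDocumentsNode(BatchNode):
    def prep(self, shared):
        # One exec() call per document
        return shared["texts"]  # assumed key name

    def exec(self, text):
        # Split a single document into fixed-size chunks
        return fixed_size_chunk(text)

    def post(self, shared, prep_res, exec_res_list):
        # Flatten the per-document chunk lists back into the shared store
        shared["texts"] = [chunk for chunks in exec_res_list for chunk in chunks]
        return "default"
```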
@@ -115,3 +115,29 @@ class RetrieveDocumentNode(Node):
print(f"📄 Retrieved document (index: {exec_res['index']}, distance: {exec_res['distance']:.4f})")
|
print(f"📄 Retrieved document (index: {exec_res['index']}, distance: {exec_res['distance']:.4f})")
|
||||||
print(f"📄 Most relevant text: \"{exec_res['text']}\"")
|
print(f"📄 Most relevant text: \"{exec_res['text']}\"")
|
||||||
return "default"
|
return "default"
|
||||||
|
|
||||||
|
class GenerateAnswerNode(Node):
|
||||||
|
def prep(self, shared):
|
||||||
|
"""Get query, retrieved document, and any other context needed"""
|
||||||
|
return shared["query"], shared["retrieved_document"]
|
||||||
|
|
||||||
|
def exec(self, inputs):
|
||||||
|
"""Generate an answer using the LLM"""
|
||||||
|
query, retrieved_doc = inputs
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
Briefly answer the following question based on the context provided:
|
||||||
|
Question: {query}
|
||||||
|
Context: {retrieved_doc['text']}
|
||||||
|
Answer:
|
||||||
|
"""
|
||||||
|
|
||||||
|
answer = call_llm(prompt)
|
||||||
|
return answer
|
||||||
|
|
||||||
|
def post(self, shared, prep_res, exec_res):
|
||||||
|
"""Store generated answer in shared store"""
|
||||||
|
shared["generated_answer"] = exec_res
|
||||||
|
print("\n🤖 Generated Answer:")
|
||||||
|
print(exec_res)
|
||||||
|
return "default"
|
||||||
|
|
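Like the other nodes, `GenerateAnswerNode` splits its work across PocketFlow's three-step lifecycle: `prep` pulls its inputs out of the shared store, `exec` does the actual work (here, one LLM call), and `post` writes the result back and returns the action string that tells the flow where to go next.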
utils.py
@@ -2,7 +2,15 @@ import os
```python
import numpy as np
from openai import OpenAI


def call_llm(prompt):
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
    r = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}]
    )
    return r.choices[0].message.content


def get_simple_embedding(text):
    """
    A simple embedding function that converts text to vector.
```
@@ -27,8 +35,8 @@ def get_embedding(text):
```python
    return embedding


def get_embedding(text):
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))

    response = client.embeddings.create(
        model="text-embedding-ada-002",
```
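The hunk cuts off inside `get_embedding`. For reference, with the standard OpenAI embeddings response the remainder plausibly looks like this sketch:

```python
        input=text
    )
    # Extract the vector from the response and return it as a NumPy array
    return np.array(response.data[0].embedding, dtype=np.float32)
```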
@@ -48,37 +56,45 @@ def fixed_size_chunk(text, chunk_size=2000):
```python
    return chunks


if __name__ == "__main__":
    print("=== Testing call_llm ===")
    prompt = "In a few words, what is the meaning of life?"
    print(f"Prompt: {prompt}")
    response = call_llm(prompt)
    print(f"Response: {response}")

    print("=== Testing embedding function ===")

    text1 = "The quick brown fox jumps over the lazy dog."
    text2 = "Python is a popular programming language for data science."

    # Test the simple embedding function
    # emb1 = get_simple_embedding(text1)
    # emb2 = get_simple_embedding(text2)
    # print(f"Embedding 1 shape: {emb1.shape}")
    # print(f"Embedding 2 shape: {emb2.shape}")

    # # Calculate similarity (dot product)
    # similarity = np.dot(emb1, emb2)
    # print(f"Similarity between texts: {similarity:.4f}")

    # # Compare with a different text
    # text3 = "Machine learning is a subset of artificial intelligence."
    # emb3 = get_simple_embedding(text3)
    # similarity13 = np.dot(emb1, emb3)
    # similarity23 = np.dot(emb2, emb3)
    # print(f"Similarity between text1 and text3: {similarity13:.4f}")
    # print(f"Similarity between text2 and text3: {similarity23:.4f}")

    # # These simple comparisons should show higher similarity
    # # between related concepts (text2 and text3) than between
    # # unrelated texts (text1 and text3)

    # Test OpenAI embeddings (requires API key)
    print("\nTesting OpenAI embeddings (requires API key):")
    oai_emb1 = get_embedding(text1)
    oai_emb2 = get_embedding(text2)
    print(f"OpenAI Embedding 1 shape: {oai_emb1.shape}")
    oai_similarity = np.dot(oai_emb1, oai_emb2)
    print(f"OpenAI similarity between texts: {oai_similarity:.4f}")
```
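`fixed_size_chunk` itself sits outside these hunks. Given its name and signature, a minimal fixed-width implementation would be the following sketch (the actual version may differ, e.g. by adding overlap between chunks):

```python
def fixed_size_chunk(text, chunk_size=2000):
    # Slice the text into consecutive chunks of at most chunk_size characters
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
```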