update rag tutorial
This commit is contained in:
parent
337dcc0d66
commit
48630239b7
|
|
@ -1,98 +1,66 @@
|
|||
# Retrieval Augmented Generation (RAG)
|
||||
|
||||
This project demonstrates a simplified RAG system that retrieves relevant documents based on user queries.
|
||||
This project demonstrates a simplified RAG system that retrieves relevant documents based on user queries and generates answers using an LLM.
|
||||
|
||||
## Features
|
||||
|
||||
- Document chunking for better retrieval granularity
|
||||
- Simple vector-based document retrieval
|
||||
- Two-stage pipeline (offline indexing, online querying)
|
||||
- FAISS-powered similarity search
|
||||
- Document chunking for processing long texts
|
||||
- FAISS-powered vector-based document retrieval
|
||||
- LLM-powered answer generation
|
||||
|
||||
## Getting Started
|
||||
## How to Run
|
||||
|
||||
1. Install the required dependencies:
|
||||
1. Set your API key:
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
Or update it directly in `utils.py`
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Run the application with a sample query:
|
||||
|
||||
```bash
|
||||
python main.py --"Large Language Model"
|
||||
```
|
||||
|
||||
3. Or run without arguments to use the default query:
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
## API Key
|
||||
|
||||
By default, demo uses dummy embedding based on character frequencies. To use real OpenAI embedding:
|
||||
|
||||
1. Edit nodes.py to replace the dummy `get_embedding` with `get_openai_embedding`:
|
||||
```python
|
||||
# Change this line:
|
||||
query_embedding = get_embedding(query)
|
||||
# To this:
|
||||
query_embedding = get_openai_embedding(query)
|
||||
|
||||
# And also change this line:
|
||||
return get_embedding(text)
|
||||
# To this:
|
||||
return get_openai_embedding(text)
|
||||
```
|
||||
|
||||
2. Make sure your OpenAI API key is set:
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
2. Install and run:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
python main.py
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
The magic happens through a two-stage pipeline implemented with PocketFlow:
|
||||
The magic happens through a two-phase pipeline implemented with PocketFlow:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph OfflineFlow[Offline Document Indexing]
|
||||
ChunkDocs[ChunkDocumentsNode] --> EmbedDocs[EmbedDocumentsNode]
|
||||
EmbedDocs[EmbedDocumentsNode] --> CreateIndex[CreateIndexNode]
|
||||
ChunkDocs[ChunkDocumentsNode] --> EmbedDocs[EmbedDocumentsNode] --> CreateIndex[CreateIndexNode]
|
||||
end
|
||||
|
||||
subgraph OnlineFlow[Online Query Processing]
|
||||
EmbedQuery[EmbedQueryNode] --> RetrieveDoc[RetrieveDocumentNode]
|
||||
subgraph OnlineFlow[Online Processing]
|
||||
EmbedQuery[EmbedQueryNode] --> RetrieveDoc[RetrieveDocumentNode] --> GenerateAnswer[GenerateAnswerNode]
|
||||
end
|
||||
```
|
||||
|
||||
Here's what each part does:
|
||||
1. **ChunkDocumentsNode**: Splits documents into smaller chunks for more granular retrieval
|
||||
1. **ChunkDocumentsNode**: Breaks documents into smaller chunks for better retrieval
|
||||
2. **EmbedDocumentsNode**: Converts document chunks into vector representations
|
||||
3. **CreateIndexNode**: Creates a searchable FAISS index from embeddings
|
||||
4. **EmbedQueryNode**: Converts user query into the same vector space
|
||||
5. **RetrieveDocumentNode**: Finds the most similar document chunk using vector search
|
||||
5. **RetrieveDocumentNode**: Finds the most similar document using vector search
|
||||
6. **GenerateAnswerNode**: Uses an LLM to generate an answer based on the retrieved content
|
||||
|
||||
## Example Output
|
||||
|
||||
```
|
||||
==================================================
|
||||
PocketFlow RAG Document Retrieval
|
||||
==================================================
|
||||
✅ Created 5 chunks from 5 documents
|
||||
✅ Created 5 document embeddings
|
||||
🔍 Creating search index...
|
||||
✅ Index created with 5 vectors
|
||||
🔍 Embedding query: Large Language Model
|
||||
🔍 Embedding query: How to install PocketFlow?
|
||||
🔎 Searching for relevant documents...
|
||||
📄 Retrieved document (index: 3, distance: 0.3296)
|
||||
📄 Most relevant text: "PocketFlow is a 100-line Large Language Model Framework."
|
||||
📄 Retrieved document (index: 0, distance: 0.3427)
|
||||
📄 Most relevant text: "Pocket Flow is a 100-line minimalist LLM framework
|
||||
Lightweight: Just 100 lines. Zero bloat, zero dependencies, zero vendor lock-in.
|
||||
Expressive: Everything you love—(Multi-)Agents, Workflow, RAG, and more.
|
||||
Agentic Coding: Let AI Agents (e.g., Cursor AI) build Agents—10x productivity boost!
|
||||
To install, pip install pocketflow or just copy the source code (only 100 lines)."
|
||||
|
||||
🤖 Generated Answer:
|
||||
To install PocketFlow, use the command `pip install pocketflow` or simply copy its 100 lines of source code.
|
||||
```
|
||||
|
||||
## Files
|
||||
|
||||
- [`main.py`](./main.py): Main entry point for running the RAG demonstration
|
||||
- [`flow.py`](./flow.py): Configures the flows that connect the nodes
|
||||
- [`nodes.py`](./nodes.py): Defines the nodes for document processing and retrieval
|
||||
- [`utils.py`](./utils.py): Utility functions including chunking and embedding functions
|
||||
|
|
@ -1,20 +1,27 @@
|
|||
from pocketflow import Flow
|
||||
from nodes import EmbedDocumentsNode, CreateIndexNode, EmbedQueryNode, RetrieveDocumentNode, ChunkDocumentsNode
|
||||
from nodes import EmbedDocumentsNode, CreateIndexNode, EmbedQueryNode, RetrieveDocumentNode, ChunkDocumentsNode, GenerateAnswerNode
|
||||
|
||||
def get_offline_flow():
|
||||
# Create offline flow for document indexing
|
||||
chunk_docs_node = ChunkDocumentsNode()
|
||||
embed_docs_node = EmbedDocumentsNode()
|
||||
create_index_node = CreateIndexNode()
|
||||
|
||||
# Connect the nodes
|
||||
chunk_docs_node >> embed_docs_node >> create_index_node
|
||||
|
||||
offline_flow = Flow(start=chunk_docs_node)
|
||||
return offline_flow
|
||||
|
||||
def get_online_flow():
|
||||
# Create online flow for document retrieval
|
||||
# Create online flow for document retrieval and answer generation
|
||||
embed_query_node = EmbedQueryNode()
|
||||
retrieve_doc_node = RetrieveDocumentNode()
|
||||
embed_query_node >> retrieve_doc_node
|
||||
generate_answer_node = GenerateAnswerNode()
|
||||
|
||||
# Connect the nodes
|
||||
embed_query_node >> retrieve_doc_node >> generate_answer_node
|
||||
|
||||
online_flow = Flow(start=embed_query_node)
|
||||
return online_flow
|
||||
|
||||
|
|
|
|||
|
|
@ -9,23 +9,53 @@ def run_rag_demo():
|
|||
1. Indexes a set of sample documents (offline flow)
|
||||
2. Takes a query from the command line
|
||||
3. Retrieves the most relevant document (online flow)
|
||||
4. Generates an answer using an LLM
|
||||
"""
|
||||
|
||||
# Sample texts - corpus of documents to search
|
||||
# Sample texts - specialized/fictional content that benefits from RAG
|
||||
texts = [
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Python is a popular programming language for data science.",
|
||||
"PocketFlow is a 100-line Large Language Model Framework.",
|
||||
"The weather is sunny and warm today.",
|
||||
# PocketFlow framework
|
||||
"""Pocket Flow is a 100-line minimalist LLM framework
|
||||
Lightweight: Just 100 lines. Zero bloat, zero dependencies, zero vendor lock-in.
|
||||
Expressive: Everything you love—(Multi-)Agents, Workflow, RAG, and more.
|
||||
Agentic Coding: Let AI Agents (e.g., Cursor AI) build Agents—10x productivity boost!
|
||||
To install, pip install pocketflow or just copy the source code (only 100 lines).""",
|
||||
|
||||
# Fictional medical device
|
||||
"""NeurAlign M7 is a revolutionary non-invasive neural alignment device.
|
||||
Targeted magnetic resonance technology increases neuroplasticity in specific brain regions.
|
||||
Clinical trials showed 72% improvement in PTSD treatment outcomes.
|
||||
Developed by Cortex Medical in 2024 as an adjunct to standard cognitive therapy.
|
||||
Portable design allows for in-home use with remote practitioner monitoring.""",
|
||||
|
||||
# Made-up historical event
|
||||
"""The Velvet Revolution of Caldonia (1967-1968) ended Generalissimo Verak's 40-year rule.
|
||||
Led by poet Eliza Markovian through underground literary societies.
|
||||
Culminated in the Great Silence Protest with 300,000 silent protesters.
|
||||
First democratic elections held in March 1968 with 94% voter turnout.
|
||||
Became a model for non-violent political transitions in neighboring regions.""",
|
||||
|
||||
# Fictional technology
|
||||
"""Q-Mesh is QuantumLeap Technologies' instantaneous data synchronization protocol.
|
||||
Utilizes directed acyclic graph consensus for 500,000 transactions per second.
|
||||
Consumes 95% less energy than traditional blockchain systems.
|
||||
Adopted by three central banks for secure financial data transfer.
|
||||
Released in February 2024 after five years of development in stealth mode.""",
|
||||
|
||||
# Made-up scientific research
|
||||
"""Harlow Institute's Mycelium Strain HI-271 removes 99.7% of PFAS from contaminated soil.
|
||||
Engineered fungi create symbiotic relationships with native soil bacteria.
|
||||
Breaks down "forever chemicals" into non-toxic compounds within 60 days.
|
||||
Field tests successfully remediated previously permanently contaminated industrial sites.
|
||||
Deployment costs 80% less than traditional chemical extraction methods."""
|
||||
]
|
||||
|
||||
print("=" * 50)
|
||||
print("PocketFlow RAG Document Retrieval")
|
||||
print("=" * 50)
|
||||
|
||||
# Default query
|
||||
default_query = "Large Language Model"
|
||||
# Default query about the fictional technology
|
||||
default_query = "How to install PocketFlow?"
|
||||
|
||||
# Get query from command line if provided with --
|
||||
query = default_query
|
||||
|
|
@ -41,13 +71,14 @@ def run_rag_demo():
|
|||
"index": None,
|
||||
"query": query,
|
||||
"query_embedding": None,
|
||||
"retrieved_document": None
|
||||
"retrieved_document": None,
|
||||
"generated_answer": None
|
||||
}
|
||||
|
||||
# Initialize and run the offline flow (document indexing)
|
||||
offline_flow.run(shared)
|
||||
|
||||
# Run the online flow to retrieve the most relevant document
|
||||
# Run the online flow to retrieve the most relevant document and generate an answer
|
||||
online_flow.run(shared)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from pocketflow import Node, Flow, BatchNode
|
||||
import numpy as np
|
||||
import faiss
|
||||
from utils import get_embedding, get_openai_embedding, fixed_size_chunk
|
||||
from utils import call_llm, get_embedding, get_simple_embedding, fixed_size_chunk
|
||||
|
||||
# Nodes for the offline flow
|
||||
class ChunkDocumentsNode(BatchNode):
|
||||
|
|
@ -115,3 +115,29 @@ class RetrieveDocumentNode(Node):
|
|||
print(f"📄 Retrieved document (index: {exec_res['index']}, distance: {exec_res['distance']:.4f})")
|
||||
print(f"📄 Most relevant text: \"{exec_res['text']}\"")
|
||||
return "default"
|
||||
|
||||
class GenerateAnswerNode(Node):
|
||||
def prep(self, shared):
|
||||
"""Get query, retrieved document, and any other context needed"""
|
||||
return shared["query"], shared["retrieved_document"]
|
||||
|
||||
def exec(self, inputs):
|
||||
"""Generate an answer using the LLM"""
|
||||
query, retrieved_doc = inputs
|
||||
|
||||
prompt = f"""
|
||||
Briefly answer the following question based on the context provided:
|
||||
Question: {query}
|
||||
Context: {retrieved_doc['text']}
|
||||
Answer:
|
||||
"""
|
||||
|
||||
answer = call_llm(prompt)
|
||||
return answer
|
||||
|
||||
def post(self, shared, prep_res, exec_res):
|
||||
"""Store generated answer in shared store"""
|
||||
shared["generated_answer"] = exec_res
|
||||
print("\n🤖 Generated Answer:")
|
||||
print(exec_res)
|
||||
return "default"
|
||||
|
|
@ -2,7 +2,15 @@ import os
|
|||
import numpy as np
|
||||
from openai import OpenAI
|
||||
|
||||
def get_embedding(text):
|
||||
def call_llm(prompt):
|
||||
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
|
||||
r = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
return r.choices[0].message.content
|
||||
|
||||
def get_simple_embedding(text):
|
||||
"""
|
||||
A simple embedding function that converts text to vector.
|
||||
|
||||
|
|
@ -27,8 +35,8 @@ def get_embedding(text):
|
|||
|
||||
return embedding
|
||||
|
||||
def get_openai_embedding(text):
|
||||
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY"))
|
||||
def get_embedding(text):
|
||||
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
|
||||
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-ada-002",
|
||||
|
|
@ -48,37 +56,45 @@ def fixed_size_chunk(text, chunk_size=2000):
|
|||
return chunks
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the embedding function
|
||||
print("=== Testing call_llm ===")
|
||||
prompt = "In a few words, what is the meaning of life?"
|
||||
print(f"Prompt: {prompt}")
|
||||
response = call_llm(prompt)
|
||||
print(f"Response: {response}")
|
||||
|
||||
print("=== Testing embedding function ===")
|
||||
|
||||
text1 = "The quick brown fox jumps over the lazy dog."
|
||||
text2 = "Python is a popular programming language for data science."
|
||||
|
||||
emb1 = get_embedding(text1)
|
||||
emb2 = get_embedding(text2)
|
||||
# Test the simple embedding function
|
||||
# emb1 = get_embedding(text1)
|
||||
# emb2 = get_embedding(text2)
|
||||
|
||||
print(f"Embedding 1 shape: {emb1.shape}")
|
||||
print(f"Embedding 2 shape: {emb2.shape}")
|
||||
# print(f"Embedding 1 shape: {emb1.shape}")
|
||||
# print(f"Embedding 2 shape: {emb2.shape}")
|
||||
|
||||
# Calculate similarity (dot product)
|
||||
similarity = np.dot(emb1, emb2)
|
||||
print(f"Similarity between texts: {similarity:.4f}")
|
||||
# # Calculate similarity (dot product)
|
||||
# similarity = np.dot(emb1, emb2)
|
||||
# print(f"Similarity between texts: {similarity:.4f}")
|
||||
|
||||
# Compare with a different text
|
||||
text3 = "Machine learning is a subset of artificial intelligence."
|
||||
emb3 = get_embedding(text3)
|
||||
similarity13 = np.dot(emb1, emb3)
|
||||
similarity23 = np.dot(emb2, emb3)
|
||||
# # Compare with a different text
|
||||
# text3 = "Machine learning is a subset of artificial intelligence."
|
||||
# emb3 = get_embedding(text3)
|
||||
# similarity13 = np.dot(emb1, emb3)
|
||||
# similarity23 = np.dot(emb2, emb3)
|
||||
|
||||
print(f"Similarity between text1 and text3: {similarity13:.4f}")
|
||||
print(f"Similarity between text2 and text3: {similarity23:.4f}")
|
||||
# print(f"Similarity between text1 and text3: {similarity13:.4f}")
|
||||
# print(f"Similarity between text2 and text3: {similarity23:.4f}")
|
||||
|
||||
# These simple comparisons should show higher similarity
|
||||
# between related concepts (text2 and text3) than between
|
||||
# unrelated texts (text1 and text3)
|
||||
# # These simple comparisons should show higher similarity
|
||||
# # between related concepts (text2 and text3) than between
|
||||
# # unrelated texts (text1 and text3)
|
||||
|
||||
# Uncomment to test OpenAI embeddings (requires API key)
|
||||
# Test OpenAI embeddings (requires API key)
|
||||
print("\nTesting OpenAI embeddings (requires API key):")
|
||||
oai_emb1 = get_openai_embedding(text1)
|
||||
oai_emb2 = get_openai_embedding(text2)
|
||||
oai_emb1 = get_embedding(text1)
|
||||
oai_emb2 = get_embedding(text2)
|
||||
print(f"OpenAI Embedding 1 shape: {oai_emb1.shape}")
|
||||
oai_similarity = np.dot(oai_emb1, oai_emb2)
|
||||
print(f"OpenAI similarity between texts: {oai_similarity:.4f}")
|
||||
Loading…
Reference in New Issue