update streamlit app
This commit is contained in:
parent
442fd269fa
commit
966ced9dd1
|
|
@ -1,34 +1,51 @@
|
|||
# PocketFlow Streamlit Human-in-the-Loop (HITL) Application
|
||||
# PocketFlow Streamlit Image Generation HITL
|
||||
|
||||
Minimal Human-in-the-Loop (HITL) web application using PocketFlow and Streamlit. Submit text, review processed output, and approve/reject.
|
||||
Human-in-the-Loop (HITL) image generation application using PocketFlow and Streamlit. Enter text prompts, generate images with OpenAI, and approve/regenerate results.
|
||||
|
||||
## Features
|
||||
|
||||
- **Streamlit UI:** Simple, interactive interface for submitting tasks and providing feedback, built entirely in Python.
|
||||
- **PocketFlow Workflow:** Manages distinct processing stages (initial processing, finalization) using synchronous PocketFlow `Flow`s.
|
||||
- **Session State Management:** Utilizes Streamlit's `st.session_state` to manage the current stage of the workflow and to act as the `shared` data store for PocketFlow.
|
||||
- **Iterative Feedback Loop:** Allows users to reject processed output and resubmit, facilitating refinement.
|
||||
- **Image Generation:** Uses OpenAI's `gpt-image-1` model to generate images from text prompts
|
||||
- **Human Review:** Interactive interface to approve or regenerate images
|
||||
- **State Machine:** Clean state-based workflow (`initial_input` → `user_feedback` → `final`)
|
||||
- **PocketFlow Integration:** Uses PocketFlow `Node` and `Flow` for image generation with built-in retries
|
||||
- **Session State Management:** Streamlit session state acts as PocketFlow's shared store
|
||||
- **In-Memory Images:** Images stored as base64 strings, no disk storage required
|
||||
|
||||
## How to Run
|
||||
|
||||
1. **Install Dependencies:**
|
||||
1. **Set OpenAI API Key:**
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-openai-api-key"
|
||||
```
|
||||
|
||||
2. **Install Dependencies:**
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **Run the Streamlit Application:**
|
||||
3. **Run the Streamlit Application:**
|
||||
```bash
|
||||
streamlit run app.py
|
||||
```
|
||||
|
||||
3. **Access the Web UI:**
|
||||
4. **Access the Web UI:**
|
||||
Open the URL provided by Streamlit (usually `http://localhost:8501`).
|
||||
|
||||
## Usage
|
||||
|
||||
1. **Enter Prompt**: Describe the image you want to generate
|
||||
2. **Generate**: Click "Generate Image" to create the image
|
||||
3. **Review**: View the generated image and choose:
|
||||
- **Approve**: Accept the image and move to final result
|
||||
- **Regenerate**: Generate a new image with the same prompt
|
||||
4. **Final**: View approved image and optionally start over
|
||||
|
||||
## Files
|
||||
|
||||
- [`app.py`](./app.py): Main Streamlit application logic and UI.
|
||||
- [`nodes.py`](./nodes.py): PocketFlow `Node` definitions.
|
||||
- [`flows.py`](./flows.py): PocketFlow `Flow` construction.
|
||||
- [`utils/process_task.py`](./utils/process_task.py): Simulated task processing utility.
|
||||
- [`requirements.txt`](./requirements.txt): Project dependencies.
|
||||
- [`README.md`](./README.md): This file.
|
||||
- [`app.py`](./app.py): Main Streamlit application with state-based UI
|
||||
- [`nodes.py`](./nodes.py): PocketFlow `GenerateImageNode` definition
|
||||
- [`flow.py`](./flow.py): PocketFlow `Flow` for image generation
|
||||
- [`utils/generate_image.py`](./utils/generate_image.py): OpenAI image generation utility
|
||||
- [`requirements.txt`](./requirements.txt): Project dependencies
|
||||
- [`docs/design.md`](./docs/design.md): System design documentation
|
||||
- [`README.md`](./README.md): This file
|
||||
|
|
|
|||
|
|
@ -1,176 +1,85 @@
|
|||
import streamlit as st
|
||||
from flow import create_initial_processing_flow, create_finalization_flow
|
||||
import base64
|
||||
from flow import create_generation_flow
|
||||
|
||||
st.title("PocketFlow HITL with Streamlit")
|
||||
st.title("PocketFlow Image Generation HITL")
|
||||
|
||||
# Initialize session state variables if they don't exist
|
||||
# Initialize session state for shared store
|
||||
if 'stage' not in st.session_state:
|
||||
st.session_state.stage = "initial"
|
||||
st.session_state.error_message = None
|
||||
st.session_state.stage = "initial_input"
|
||||
st.session_state.task_input = ""
|
||||
# Flow-related data will be added directly as needed
|
||||
print("Initialized session state.")
|
||||
st.session_state.generated_image = ""
|
||||
st.session_state.final_result = ""
|
||||
st.session_state.error_message = ""
|
||||
|
||||
# --- Helper Function to Reset State ---
|
||||
def reset_state():
|
||||
# Keep essential Streamlit state keys if necessary, or clear selectively
|
||||
keys_to_clear = [k for k in st.session_state.keys() if k not in ['stage', 'error_message', 'task_input']]
|
||||
for key in keys_to_clear:
|
||||
del st.session_state[key]
|
||||
# Debug info
|
||||
with st.expander("Session State"):
|
||||
st.json({k: v for k, v in st.session_state.items() if not k.startswith("_")})
|
||||
|
||||
st.session_state.stage = "initial"
|
||||
st.session_state.error_message = None
|
||||
st.session_state.task_input = ""
|
||||
print("Reset session state (keeping core stage/error keys).")
|
||||
# State-based UI
|
||||
if st.session_state.stage == "initial_input":
|
||||
st.header("1. Generate Image")
|
||||
|
||||
# --- Display Area for Shared Data (now the entire session state) ---
|
||||
with st.expander("Show Session State (Shared Data)"):
|
||||
# Convert to dict for clean JSON display, excluding internal Streamlit keys if desired
|
||||
display_state = {k: v for k, v in st.session_state.items() if not k.startswith("_")}
|
||||
st.json(display_state)
|
||||
prompt = st.text_area("Enter image prompt:", value=st.session_state.task_input, height=100)
|
||||
|
||||
# --- Stage: Initial Input ---
|
||||
st.header("1. Submit Data for Processing")
|
||||
task_input_value = st.text_area(
|
||||
"Enter data to process:",
|
||||
value=st.session_state.task_input,
|
||||
height=150,
|
||||
disabled=(st.session_state.stage != "initial")
|
||||
)
|
||||
if st.button("Generate Image"):
|
||||
if prompt.strip():
|
||||
st.session_state.task_input = prompt
|
||||
st.session_state.error_message = ""
|
||||
|
||||
# Disable button if not in 'initial' stage
|
||||
submit_button_disabled = (st.session_state.stage != "initial")
|
||||
if st.button("Submit", disabled=submit_button_disabled):
|
||||
if not submit_button_disabled: # Process click only if button was not meant to be disabled
|
||||
if not task_input_value.strip():
|
||||
st.error("Please enter some data to process.")
|
||||
try:
|
||||
with st.spinner("Generating image..."):
|
||||
flow = create_generation_flow()
|
||||
flow.run(st.session_state)
|
||||
st.rerun()
|
||||
except Exception as e:
|
||||
st.session_state.error_message = str(e)
|
||||
else:
|
||||
print(f"Submit button clicked. Input: '{task_input_value[:50]}...'")
|
||||
# Store input directly in session state
|
||||
st.session_state.task_input = task_input_value
|
||||
st.session_state.error_message = None
|
||||
# Clear previous results if any
|
||||
if "processed_output" in st.session_state: del st.session_state.processed_output
|
||||
if "final_result" in st.session_state: del st.session_state.final_result
|
||||
if "input_used_by_process" in st.session_state: del st.session_state.input_used_by_process
|
||||
st.error("Please enter a prompt")
|
||||
|
||||
try:
|
||||
with st.spinner("Processing initial task..."):
|
||||
initial_flow = create_initial_processing_flow()
|
||||
# Pass the entire session state as shared data
|
||||
initial_flow.run(st.session_state)
|
||||
elif st.session_state.stage == "user_feedback":
|
||||
st.header("2. Review Generated Image")
|
||||
|
||||
# Check if processing was successful (output exists directly in session state)
|
||||
if "processed_output" in st.session_state:
|
||||
st.session_state.stage = "awaiting_review"
|
||||
print("Initial processing complete. Moving to 'awaiting_review' stage.")
|
||||
st.rerun()
|
||||
else:
|
||||
st.session_state.error_message = "Processing failed to produce an output."
|
||||
print("Error: Processing failed, no output found.")
|
||||
# Keep stage as initial to allow retry/correction
|
||||
if st.session_state.generated_image:
|
||||
# Display image
|
||||
image_bytes = base64.b64decode(st.session_state.generated_image)
|
||||
st.image(image_bytes, caption=f"Prompt: {st.session_state.task_input}")
|
||||
|
||||
except Exception as e:
|
||||
st.session_state.error_message = f"An error occurred during initial processing: {e}"
|
||||
print(f"Exception during initial processing: {e}")
|
||||
# Keep stage as initial
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
# --- Stage: Awaiting Review ---
|
||||
st.header("2. Review Processed Output")
|
||||
# Get processed output directly from session state
|
||||
processed_output = st.session_state.get("processed_output", "No output to review yet.")
|
||||
# Display placeholder if no output yet or not in review stage
|
||||
output_to_display = processed_output if st.session_state.stage == "awaiting_review" else "Output will appear here after submission."
|
||||
|
||||
st.subheader("Output to Review:")
|
||||
st.markdown(f"```\\n{str(output_to_display)}\\n```") # Display as markdown code block
|
||||
|
||||
col1, col2, _ = st.columns([1, 1, 5]) # Layout buttons
|
||||
with col1:
|
||||
# Disable button if not in 'awaiting_review' stage
|
||||
approve_button_disabled = (st.session_state.stage != "awaiting_review")
|
||||
if st.button("Approve", disabled=approve_button_disabled):
|
||||
if not approve_button_disabled: # Process click only if button was not meant to be disabled
|
||||
print("Approve button clicked.")
|
||||
st.session_state.error_message = None
|
||||
try:
|
||||
with st.spinner("Finalizing result..."):
|
||||
finalization_flow = create_finalization_flow()
|
||||
# Pass the entire session state
|
||||
finalization_flow.run(st.session_state)
|
||||
|
||||
# Check for final result directly in session state
|
||||
if "final_result" in st.session_state:
|
||||
st.session_state.stage = "completed"
|
||||
print("Approval processed. Moving to 'completed' stage.")
|
||||
st.rerun()
|
||||
else:
|
||||
st.session_state.error_message = "Finalization failed to produce a result."
|
||||
print("Error: Finalization failed, no final_result found.")
|
||||
# Stay in review stage and show error.
|
||||
|
||||
except Exception as e:
|
||||
st.session_state.error_message = f"An error occurred during finalization: {e}"
|
||||
print(f"Exception during finalization: {e}")
|
||||
# Stay in review stage
|
||||
st.rerun() # Rerun to show error message
|
||||
|
||||
with col2:
|
||||
# Disable button if not in 'awaiting_review' stage
|
||||
reject_button_disabled = (st.session_state.stage != "awaiting_review")
|
||||
if st.button("Reject", disabled=reject_button_disabled):
|
||||
if not reject_button_disabled: # Process click only if button was not meant to be disabled
|
||||
print("Reject button clicked.")
|
||||
st.session_state.error_message = None # Clear previous errors
|
||||
# Go back to initial stage to allow modification/resubmission
|
||||
st.session_state.stage = "initial"
|
||||
# Keep the rejected output visible in the input field for modification
|
||||
st.session_state.task_input = st.session_state.get("processed_output", st.session_state.task_input)
|
||||
# Clear the processed output so it doesn't linger
|
||||
if "processed_output" in st.session_state: del st.session_state.processed_output
|
||||
if "final_result" in st.session_state: del st.session_state.final_result
|
||||
st.info("Task rejected. Modify the input below and resubmit.")
|
||||
print("Task rejected. Moving back to 'initial' stage.")
|
||||
with col1:
|
||||
if st.button("Approve", use_container_width=True):
|
||||
st.session_state.final_result = st.session_state.generated_image
|
||||
st.session_state.stage = "final"
|
||||
st.rerun()
|
||||
|
||||
# --- Stage: Completed ---
|
||||
st.header("3. Task Completed")
|
||||
# Get final result directly from session state
|
||||
final_result = st.session_state.get("final_result", "Task not completed yet.")
|
||||
# Display placeholder if not completed
|
||||
result_to_display = final_result if st.session_state.stage == "completed" else "Final result will appear here upon completion."
|
||||
with col2:
|
||||
if st.button("Regenerate", use_container_width=True):
|
||||
try:
|
||||
with st.spinner("Regenerating image..."):
|
||||
flow = create_generation_flow()
|
||||
flow.run(st.session_state)
|
||||
st.rerun()
|
||||
except Exception as e:
|
||||
st.session_state.error_message = str(e)
|
||||
|
||||
st.subheader("Final Result:")
|
||||
if st.session_state.stage == "completed":
|
||||
st.success("Task approved and completed successfully!")
|
||||
st.text_area("", value=str(result_to_display), height=200, disabled=True) # Always disabled for display
|
||||
elif st.session_state.stage == "final":
|
||||
st.header("3. Final Result")
|
||||
st.success("Image approved!")
|
||||
|
||||
# Disable button if not in 'completed' or 'rejected_final' stage
|
||||
start_over_button_disabled = not (st.session_state.stage == "completed" or st.session_state.stage == "rejected_final")
|
||||
if st.button("Start Over", disabled=start_over_button_disabled):
|
||||
if not start_over_button_disabled: # Process click only if button was not meant to be disabled
|
||||
print("Start Over button clicked.")
|
||||
reset_state()
|
||||
if st.session_state.final_result:
|
||||
image_bytes = base64.b64decode(st.session_state.final_result)
|
||||
st.image(image_bytes, caption=f"Final approved image: {st.session_state.task_input}")
|
||||
|
||||
if st.button("Start Over", use_container_width=True):
|
||||
st.session_state.stage = "initial_input"
|
||||
st.session_state.task_input = ""
|
||||
st.session_state.generated_image = ""
|
||||
st.session_state.final_result = ""
|
||||
st.session_state.error_message = ""
|
||||
st.rerun()
|
||||
|
||||
# --- Stage: Rejected -- (This section appears to be for a final rejected state, let's adjust for visibility)
|
||||
# elif st.session_state.stage == "rejected_final": # Removed conditional rendering
|
||||
# We'll integrate the display of rejection into the "Task Completed" area or manage it distinctly
|
||||
# For now, this specific "rejected_final" header might be redundant if we always show "3. Task Completed" area
|
||||
# And handle the message within it.
|
||||
|
||||
if st.session_state.stage == "rejected_final":
|
||||
st.header("3. Task Rejected") # This can be shown when in this specific state.
|
||||
st.error("The processed output was rejected.")
|
||||
# Get rejected output directly from session state
|
||||
rejected_output = st.session_state.get("processed_output", "")
|
||||
if rejected_output:
|
||||
st.text_area("Rejected Output:", value=str(rejected_output), height=150, disabled=True)
|
||||
# The "Start Over" button for this state is handled by the one in "Task Completed" section due to shared disabling logic.
|
||||
|
||||
# --- Display Error Messages ---
|
||||
# Show errors
|
||||
if st.session_state.error_message:
|
||||
st.error(st.session_state.error_message)
|
||||
# --- Add a button to reset state anytime (for debugging) ---
|
||||
# st.sidebar.button("Reset State", on_click=reset_state) # Removed sidebar
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue