86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import streamlit as st
|
|
import base64
|
|
from flow import create_generation_flow
|
|
|
|
st.title("PocketFlow Image Generation HITL")
|
|
|
|
# Initialize session state for shared store
|
|
if 'stage' not in st.session_state:
|
|
st.session_state.stage = "initial_input"
|
|
st.session_state.task_input = ""
|
|
st.session_state.generated_image = ""
|
|
st.session_state.final_result = ""
|
|
st.session_state.error_message = ""
|
|
|
|
# Debug info
|
|
with st.expander("Session State"):
|
|
st.json({k: v for k, v in st.session_state.items() if not k.startswith("_")})
|
|
|
|
# State-based UI
|
|
if st.session_state.stage == "initial_input":
|
|
st.header("1. Generate Image")
|
|
|
|
prompt = st.text_area("Enter image prompt:", value=st.session_state.task_input, height=100)
|
|
|
|
if st.button("Generate Image"):
|
|
if prompt.strip():
|
|
st.session_state.task_input = prompt
|
|
st.session_state.error_message = ""
|
|
|
|
try:
|
|
with st.spinner("Generating image..."):
|
|
flow = create_generation_flow()
|
|
flow.run(st.session_state)
|
|
st.rerun()
|
|
except Exception as e:
|
|
st.session_state.error_message = str(e)
|
|
else:
|
|
st.error("Please enter a prompt")
|
|
|
|
elif st.session_state.stage == "user_feedback":
|
|
st.header("2. Review Generated Image")
|
|
|
|
if st.session_state.generated_image:
|
|
# Display image
|
|
image_bytes = base64.b64decode(st.session_state.generated_image)
|
|
st.image(image_bytes, caption=f"Prompt: {st.session_state.task_input}")
|
|
|
|
col1, col2 = st.columns(2)
|
|
|
|
with col1:
|
|
if st.button("Approve", use_container_width=True):
|
|
st.session_state.final_result = st.session_state.generated_image
|
|
st.session_state.stage = "final"
|
|
st.rerun()
|
|
|
|
with col2:
|
|
if st.button("Regenerate", use_container_width=True):
|
|
try:
|
|
with st.spinner("Regenerating image..."):
|
|
flow = create_generation_flow()
|
|
flow.run(st.session_state)
|
|
st.rerun()
|
|
except Exception as e:
|
|
st.session_state.error_message = str(e)
|
|
|
|
elif st.session_state.stage == "final":
|
|
st.header("3. Final Result")
|
|
st.success("Image approved!")
|
|
|
|
if st.session_state.final_result:
|
|
image_bytes = base64.b64decode(st.session_state.final_result)
|
|
st.image(image_bytes, caption=f"Final approved image: {st.session_state.task_input}")
|
|
|
|
if st.button("Start Over", use_container_width=True):
|
|
st.session_state.stage = "initial_input"
|
|
st.session_state.task_input = ""
|
|
st.session_state.generated_image = ""
|
|
st.session_state.final_result = ""
|
|
st.session_state.error_message = ""
|
|
st.rerun()
|
|
|
|
# Show errors
|
|
if st.session_state.error_message:
|
|
st.error(st.session_state.error_message)
|
|
|