import asyncio


class BaseNode:
    """
    A base node that provides:
      - preprocess()
      - process()
      - postprocess()
      - run() -- just runs itself (no chaining)
    """

    def __init__(self):
        self.parameters = {}
        self.successors = {}

    def set_parameters(self, params):
        self.parameters.update(params)

    def add_successor(self, node, condition="default"):
        if condition in self.successors:
            print(f"Warning: overwriting existing successor for condition '{condition}'")
        self.successors[condition] = node
        return node

    def preprocess(self, shared_storage):
        return None

    def process(self, shared_storage, prep_result):
        return None

    def _process(self, shared_storage, prep_result):
        # Subclasses can wrap process() here, e.g. with retry logic.
        return self.process(shared_storage, prep_result)

    def postprocess(self, shared_storage, prep_result, proc_result):
        return "default"

    def run(self, shared_storage=None):
        if shared_storage is None:
            shared_storage = {}

        prep = self.preprocess(shared_storage)
        proc = self._process(shared_storage, prep)
        return self.postprocess(shared_storage, prep, proc)

    def __rshift__(self, other):
        """
        For chaining with the >> operator, e.g. node1 >> node2
        """
        return self.add_successor(other)

    def __gt__(self, other):
        """
        For conditional chaining with the > operator, e.g.
        (node1 > "some_condition") >> node2
        (parentheses are needed because >> binds more tightly than >).
        """
        if isinstance(other, str):
            return _ConditionalTransition(self, other)
        elif isinstance(other, BaseNode):
            return self.add_successor(other)
        raise TypeError("Unsupported operand type")

    def __call__(self, condition):
        """
        For the node("condition") >> next_node syntax
        """
        return _ConditionalTransition(self, condition)


class _ConditionalTransition:
    """
    Helper for the (Node > 'condition') >> AnotherNode style
    """

    def __init__(self, source_node, condition):
        self.source_node = source_node
        self.condition = condition

    def __rshift__(self, target_node):
        return self.source_node.add_successor(target_node, self.condition)
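

# Example (illustrative sketch, not part of the original module): HelloNode and
# WorldNode are hypothetical BaseNode subclasses used only to show the
# preprocess/process/postprocess lifecycle and the three chaining spellings.
#
#   class HelloNode(BaseNode):
#       def process(self, shared_storage, prep_result):
#           shared_storage["messages"] = ["hello"]
#
#   class WorldNode(BaseNode):
#       def process(self, shared_storage, prep_result):
#           shared_storage["messages"].append("world")
#
#   hello, world = HelloNode(), WorldNode()
#   hello >> world                      # default transition (__rshift__)
#   (hello > "retry") >> hello          # conditional transition (__gt__ then __rshift__)
#   hello("skip") >> world              # same idea via __call__
#   hello.run()                         # runs only HelloNode; a Flow follows successors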


# A node with simple retry logic, for more robust running.
class Node(BaseNode):
    def __init__(self, max_retries=1):
        super().__init__()
        self.max_retries = max_retries

    def process_after_fail(self, shared_storage, data, exc):
        # Called once the retry budget is exhausted; by default, re-raise.
        # Subclasses can instead return a value such as "fail".
        raise exc

    def _process(self, shared_storage, data):
        for attempt in range(self.max_retries):
            try:
                return super()._process(shared_storage, data)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    return self.process_after_fail(shared_storage, data, e)
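

# Example (illustrative sketch): a hypothetical flaky node whose process() fails
# on the first two attempts. process_after_fail() is overridden to return a
# sentinel instead of raising, and postprocess() maps it to a "fail" condition.
#
#   class FlakyNode(Node):
#       def __init__(self):
#           super().__init__(max_retries=3)
#           self.calls = 0
#
#       def process(self, shared_storage, prep_result):
#           self.calls += 1
#           if self.calls < 3:
#               raise RuntimeError("transient error")
#           return "ok"
#
#       def process_after_fail(self, shared_storage, data, exc):
#           return "fail"
#
#       def postprocess(self, shared_storage, prep_result, proc_result):
#           return "fail" if proc_result == "fail" else "default"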


class InteractiveNode(BaseNode):
    """
    Interactive node. Instead of returning a condition,
    we 'signal' the condition via a callback provided by the Flow.
    """

    def postprocess(self, shared_storage, prep_result, proc_result, next_node_callback):
        """
        We do NOT return anything. We call next_node_callback("some_condition")
        to tell the Flow which successor to pick.
        """
        # e.g. here we pick "default"; in real usage you'd apply logic or rely on user input
        next_node_callback("default")

    def run(self, shared_storage=None):
        """
        Run just THIS node (no chain).
        """
        if shared_storage is None:
            shared_storage = {}

        # 1) Preprocess
        prep = self.preprocess(shared_storage)

        # 2) Process
        proc = self._process(shared_storage, prep)

        # 3) Postprocess with a dummy callback
        def dummy_callback(condition="default"):
            print("[Dummy callback] To run the flow, pass this node into a Flow instance.")

        self.postprocess(shared_storage, prep, proc, dummy_callback)

    def is_interactive(self):
        return True
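

# Example (illustrative sketch): a hypothetical interactive node whose
# postprocess() chooses the branch through the callback instead of returning it.
#
#   class ChoiceNode(InteractiveNode):
#       def postprocess(self, shared_storage, prep_result, proc_result, next_node_callback):
#           # e.g. read a choice that some UI or user prompt placed in shared storage
#           next_node_callback(shared_storage.get("choice", "default"))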


class Flow:
    """
    A Flow that runs through a chain of nodes, from a start node onward.
    Each iteration:
      - preprocess
      - process
      - postprocess
    For interactive nodes, postprocess is handed a callback that selects the
    next node; for regular nodes, postprocess returns the condition string.
    """

    def __init__(self, start_node=None):
        self.start_node = start_node

    def run(self, shared_storage=None):
        if shared_storage is None:
            shared_storage = {}

        current_node = self.start_node

        while current_node:
            # 1) Preprocess
            prep_result = current_node.preprocess(shared_storage)

            # 2) Process
            proc_result = current_node._process(shared_storage, prep_result)

            # Holder for the next node; set by the callback or looked up below.
            next_node = [None]

            # The callback sets next_node[0] based on the condition it receives.
            def next_node_callback(condition="default"):
                next_node[0] = current_node.successors.get(condition)

            # 3) Check whether this is an interactive node
            is_interactive = (
                hasattr(current_node, 'is_interactive')
                and current_node.is_interactive()
            )

            # 4) Postprocess
            if is_interactive:
                # ---- INTERACTIVE CASE ----
                # postprocess is expected to call next_node_callback(condition),
                # which fills in next_node[0] with the chosen successor.
                current_node.postprocess(
                    shared_storage,
                    prep_result,
                    proc_result,
                    next_node_callback
                )
            else:
                # ---- NON-INTERACTIVE CASE ----
                # postprocess returns the condition string; look up the successor.
                condition = current_node.postprocess(
                    shared_storage,
                    prep_result,
                    proc_result
                )
                next_node[0] = current_node.successors.get(condition)

            # 5) Move on to the next node
            current_node = next_node[0]
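

# Example (illustrative sketch): wiring the hypothetical HelloNode / WorldNode
# from the earlier example into a Flow and running the whole chain.
#
#   start = HelloNode()
#   start >> WorldNode()
#   storage = {}
#   Flow(start_node=start).run(storage)
#   # storage == {"messages": ["hello", "world"]}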


class BatchNode(BaseNode):
    def __init__(self, max_retries=5, delay_s=0.1):
        super().__init__()
        self.max_retries = max_retries
        self.delay_s = delay_s

    def preprocess(self, shared_storage):
        return []

    async def process_one(self, shared_storage, item):
        # Overridden by subclasses; must be async because _process_one awaits it.
        return None

    async def process_one_after_fail(self, shared_storage, item, exc):
        print(f"[FAIL_ITEM] item={item}, error={exc}")
        # By default, just return a "fail" marker. Could be anything you want.
        return "fail"

    async def _process_one(self, shared_storage, item):
        for attempt in range(self.max_retries):
            try:
                return await self.process_one(shared_storage, item)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    # Out of retries; let a subclass decide what to do next
                    return await self.process_one_after_fail(shared_storage, item, e)
                await asyncio.sleep(self.delay_s)

    async def _process(self, shared_storage, items):
        results = []
        for item in items:
            r = await self._process_one(shared_storage, item)
            results.append(r)
        return results
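

# Example (illustrative sketch): a hypothetical BatchNode that squares numbers.
# preprocess() supplies the items; process_one() handles one item and must be
# async because _process_one() awaits it.
#
#   class SquareBatch(BatchNode):
#       def preprocess(self, shared_storage):
#           return shared_storage.get("numbers", [])
#
#       async def process_one(self, shared_storage, item):
#           return item * item
#
#   node = SquareBatch()
#   storage = {"numbers": [1, 2, 3]}
#   results = asyncio.run(node._process(storage, node.preprocess(storage)))
#   # results == [1, 4, 9]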


class BatchFlow(BaseNode):
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    def preprocess(self, shared_storage):
        return []

    async def _process_one(self, shared_storage, param_dict):
        node_parameters = self.parameters.copy()
        node_parameters.update(param_dict)

        if shared_storage is None:
            shared_storage = {}

        current_node = self.start_node
        while current_node:
            # Set the combined parameters, run the node once, then follow
            # the successor selected by its postprocess condition.
            current_node.set_parameters(node_parameters)
            prep = current_node.preprocess(shared_storage)
            proc = current_node._process(shared_storage, prep)
            if asyncio.iscoroutine(proc):
                # BatchNode subclasses define an async _process
                proc = await proc
            condition = current_node.postprocess(shared_storage, prep, proc)
            current_node = current_node.successors.get(condition)

    async def _process(self, shared_storage, items):
        results = []
        for param_dict in items:
            await self._process_one(shared_storage, param_dict)
            results.append(f"Ran sub-flow for param_dict={param_dict}")
        return results
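

# Example (illustrative sketch): running the same hypothetical sub-flow once per
# parameter dict; each dict is merged into the nodes via set_parameters().
#
#   flow = BatchFlow(start_node=HelloNode())
#   param_sets = [{"user": "a"}, {"user": "b"}]
#   summaries = asyncio.run(flow._process({}, param_sets))
#   # summaries == ["Ran sub-flow for param_dict={'user': 'a'}",
#   #               "Ran sub-flow for param_dict={'user': 'b'}"]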