make run robust
parent f631aaa4d3
commit 881a903e2f
@@ -30,11 +30,16 @@ class BaseNode:
     def postprocess(self, shared_storage, prep_result, proc_result):
         return "default" # condition for next node

-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         prep = self.preprocess(shared_storage)
         proc = self._process(shared_storage, prep)
         return self.postprocess(shared_storage, prep, proc)

+    def run(self, shared_storage=None):
+        if self.successors:
+            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
+        return self._run(shared_storage)
+
     def __rshift__(self, other):
         # chaining: node1 >> node2
         return self.add_successor(other)
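A minimal sketch of the new contract, assuming the classes are importable (the "from flow import ..." path below is hypothetical) and that BaseNode's default preprocess/_process hooks are no-ops as elsewhere in this file: calling run() on a chained node warns and executes only that node, while a parent Flow traverses the chain.

import warnings

from flow import BaseNode, Flow  # hypothetical import path; adjust to this repo's layout

class Hello(BaseNode):
    def postprocess(self, shared_storage, prep_result, proc_result):
        shared_storage.setdefault("trace", []).append("hello")
        return "default"

class World(BaseNode):
    def postprocess(self, shared_storage, prep_result, proc_result):
        shared_storage.setdefault("trace", []).append("world")
        return "default"

hello, world = Hello(), World()
hello >> world                        # chaining via __rshift__ / add_successor

shared = {}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hello.run(shared)                 # executes only `hello` and warns about successors
print(shared["trace"], bool(caught))  # ['hello'] True

shared = {}
Flow(start_node=hello).run(shared)    # the parent Flow traverses hello -> world
print(shared["trace"])                # ['hello', 'world']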
@@ -103,18 +108,19 @@ class AsyncNode(Node):
         return "default"

     async def run_async(self, shared_storage=None):
+        if self.successors:
+            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
+        return await self._run_async(shared_storage)
+
+    async def _run_async(self, shared_storage=None):
         prep = self.preprocess(shared_storage)
         proc = self._process(shared_storage, prep)
         return await self.postprocess_async(shared_storage, prep, proc)

+    def _run(self, shared_storage=None):
+        raise RuntimeError("AsyncNode requires run_async, and should be run in an AsyncFlow")
+
 class BaseFlow(BaseNode):
-    """
-    Abstract base flow that provides the main logic of:
-    - Starting from self.start_node
-    - Looping until no more successors
-    Subclasses must define how they *call* each node (sync or async).
-    """
     def __init__(self, start_node=None):
         super().__init__()
         self.start_node = start_node

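The async counterpart, under the same assumptions (hypothetical import path, no-op default hooks); postprocess_async is the override point shown above:

import asyncio

from flow import AsyncNode  # hypothetical import path; adjust to this repo's layout

class Fetch(AsyncNode):
    async def postprocess_async(self, shared_storage, prep_result, proc_result):
        await asyncio.sleep(0)           # stand-in for real async work
        shared_storage["fetched"] = True
        return "default"

async def main():
    shared = {}
    await Fetch().run_async(shared)      # public async entry point
    print(shared)                        # {'fetched': True}

asyncio.run(main())

# Driving an AsyncNode synchronously (its _run, or a plain Flow) now hits the
# RuntimeError added above; async nodes are meant to live inside an AsyncFlow.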
@@ -127,44 +133,23 @@ class BaseFlow(BaseNode):
         return next_node

-    def run(self, shared_storage=None):
-        """
-        By default, do nothing (or raise).
-        Subclasses (Flow, AsyncFlow) will implement.
-        """
-        raise NotImplementedError("BaseFlow.run must be implemented by subclasses")
-
-    async def run_async(self, shared_storage=None):
-        """
-        By default, do nothing (or raise).
-        Subclasses (Flow, AsyncFlow) will implement.
-        """
-        raise NotImplementedError("BaseFlow.run_async must be implemented by subclasses")
-
 class Flow(BaseFlow):
-    """
-    Synchronous flow: each node is called with .run(shared_storage).
-    """
     def _process_flow(self, shared_storage):
         current_node = self.start_node
         while current_node:
             # Pass down the Flow's parameters to the current node
             current_node.set_parameters(self.parameters)
             # Synchronous run
-            condition = current_node.run(shared_storage)
+            condition = current_node._run(shared_storage)
             # Decide next node
             current_node = self.get_next_node(current_node, condition)

-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
         self._process_flow(shared_storage)
         return self.postprocess(shared_storage, prep_result, None)

 class AsyncFlow(BaseFlow):
-    """
-    Asynchronous flow: if a node has .run_async, we await it.
-    Otherwise, we fallback to .run.
-    """
     async def _process_flow_async(self, shared_storage):
         current_node = self.start_node
         while current_node:
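One payoff of routing the orchestration loop through _run: a Flow is itself a BaseNode, so it can be chained as a single step inside a larger Flow without tripping the new successor warning. A sketch under the same assumptions as the earlier examples:

from flow import BaseNode, Flow  # hypothetical import path; adjust to this repo's layout

class Step(BaseNode):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def postprocess(self, shared_storage, prep_result, proc_result):
        shared_storage.setdefault("trace", []).append(self.name)
        return "default"

a, b, c = Step("a"), Step("b"), Step("c")
a >> b
inner = Flow(start_node=a)     # sub-flow running a -> b

inner >> c                     # the sub-flow is chained like any other node
outer = Flow(start_node=inner)

shared = {}
outer.run(shared)              # the outer loop calls inner._run(), so no warning fires
print(shared["trace"])         # ['a', 'b', 'c']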
@@ -172,41 +157,29 @@ class AsyncFlow(BaseFlow):

             # If node is async-capable, call run_async; otherwise run sync
             if hasattr(current_node, "run_async") and callable(current_node.run_async):
-                condition = await current_node.run_async(shared_storage)
+                condition = await current_node._run_async(shared_storage)
             else:
-                condition = current_node.run(shared_storage)
+                condition = current_node._run(shared_storage)

             current_node = self.get_next_node(current_node, condition)

-    async def run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
         await self._process_flow_async(shared_storage)
         return self.postprocess(shared_storage, prep_result, None)

-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         try:
-            return asyncio.run(self.run_async(shared_storage))
+            return asyncio.run(self._run_async(shared_storage))
         except RuntimeError as e:
             raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e

 class BaseBatchFlow(BaseFlow):
-    """
-    Abstract base for a flow that runs multiple times (a batch),
-    once for each set of parameters or items from preprocess().
-    """
     def preprocess(self, shared_storage):
-        """
-        By default, returns an iterable of parameter-dicts or items
-        for the flow to process in a batch.
-        """
         return []

 class BatchFlow(BaseBatchFlow, Flow):
-    """
-    Synchronous batch flow: calls the flow repeatedly
-    for each set of parameters/items in preprocess().
-    """
-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
         all_results = []

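Mixing synchronous and asynchronous nodes in an AsyncFlow, again under the same assumptions: the hasattr check above awaits _run_async for async-capable nodes and falls back to the synchronous _run for everything else.

import asyncio

from flow import BaseNode, AsyncNode, AsyncFlow  # hypothetical import path

class SyncStep(BaseNode):
    def postprocess(self, shared_storage, prep_result, proc_result):
        shared_storage.setdefault("trace", []).append("sync")
        return "default"

class AsyncStep(AsyncNode):
    async def postprocess_async(self, shared_storage, prep_result, proc_result):
        await asyncio.sleep(0)
        shared_storage.setdefault("trace", []).append("async")
        return "default"

first, second = SyncStep(), AsyncStep()
first >> second

shared = {}
AsyncFlow(start_node=first).run(shared)  # run() wraps asyncio.run(_run_async(...))
print(shared["trace"])                   # ['sync', 'async']

# Inside an already-running event loop (e.g. Jupyter), asyncio.run() raises,
# which is what the re-raised RuntimeError above is pointing at; await the
# flow's async entry point there instead of calling run().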
@@ -226,11 +199,7 @@ class BatchFlow(BaseBatchFlow, Flow):
             self.parameters = original_params

 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    """
-    Asynchronous batch flow: calls the flow repeatedly in an async manner
-    for each set of parameters/items in preprocess().
-    """
-    async def run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
         all_results = []

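Finally, a batch sketch. The loop body of BatchFlow._run is elided in this diff, so this assumes it installs each parameter dict from preprocess() via set_parameters/self.parameters and runs the wrapped flow once per item (the original_params restore above points that way); the import path is again hypothetical.

from flow import BaseNode, BatchFlow  # hypothetical import path

class Greet(BaseNode):
    def postprocess(self, shared_storage, prep_result, proc_result):
        # parameters are pushed down by the flow via set_parameters()
        name = self.parameters.get("name", "?")
        shared_storage.setdefault("greetings", []).append(f"hello {name}")
        return "default"

class GreetEveryone(BatchFlow):
    def preprocess(self, shared_storage):
        # one parameter dict per batch item
        return [{"name": n} for n in shared_storage.get("names", [])]

shared = {"names": ["ada", "grace"]}
GreetEveryone(start_node=Greet()).run(shared)  # public run() -> BatchFlow._run()
print(shared.get("greetings"))                 # expected: ['hello ada', 'hello grace']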