diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py
index 6c36be7..c37540c 100644
--- a/minillmflow/__init__.py
+++ b/minillmflow/__init__.py
@@ -29,11 +29,16 @@ class BaseNode:
 
     def postprocess(self, shared_storage, prep_result, proc_result):
         return "default" # condition for next node
-
-    def run(self, shared_storage=None):
+
+    def _run(self, shared_storage=None):
         prep = self.preprocess(shared_storage)
         proc = self._process(shared_storage, prep)
         return self.postprocess(shared_storage, prep, proc)
+
+    def run(self, shared_storage=None):
+        if self.successors:
+            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
+        return self._run(shared_storage)
 
     def __rshift__(self, other):
         # chaining: node1 >> node2
@@ -101,20 +106,21 @@ class AsyncNode(Node):
         """
         await asyncio.sleep(0) # trivial async pause (no-op)
         return "default"
-
+
     async def run_async(self, shared_storage=None):
+        if self.successors:
+            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
+        return await self._run_async(shared_storage)
+
+    async def _run_async(self, shared_storage=None):
         prep = self.preprocess(shared_storage)
         proc = self._process(shared_storage, prep)
         return await self.postprocess_async(shared_storage, prep, proc)
-
+
+    def _run(self, shared_storage=None):
+        raise RuntimeError("AsyncNode requires run_async, and should be run in an AsyncFlow")
 
 class BaseFlow(BaseNode):
-    """
-    Abstract base flow that provides the main logic of:
-      - Starting from self.start_node
-      - Looping until no more successors
-    Subclasses must define how they *call* each node (sync or async).
-    """
     def __init__(self, start_node=None):
         super().__init__()
         self.start_node = start_node
@@ -126,45 +132,24 @@ class BaseFlow(BaseNode):
 
             warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
         return next_node
-
-    def run(self, shared_storage=None):
-        """
-        By default, do nothing (or raise).
-        Subclasses (Flow, AsyncFlow) will implement.
-        """
-        raise NotImplementedError("BaseFlow.run must be implemented by subclasses")
-
-    async def run_async(self, shared_storage=None):
-        """
-        By default, do nothing (or raise).
-        Subclasses (Flow, AsyncFlow) will implement.
-        """
-        raise NotImplementedError("BaseFlow.run_async must be implemented by subclasses")
-
+
 class Flow(BaseFlow):
-    """
-    Synchronous flow: each node is called with .run(shared_storage).
-    """
     def _process_flow(self, shared_storage):
         current_node = self.start_node
         while current_node:
             # Pass down the Flow's parameters to the current node
             current_node.set_parameters(self.parameters)
             # Synchronous run
-            condition = current_node.run(shared_storage)
+            condition = current_node._run(shared_storage)
             # Decide next node
             current_node = self.get_next_node(current_node, condition)
 
-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
         self._process_flow(shared_storage)
         return self.postprocess(shared_storage, prep_result, None)
 
 class AsyncFlow(BaseFlow):
-    """
-    Asynchronous flow: if a node has .run_async, we await it.
-    Otherwise, we fallback to .run.
-    """
     async def _process_flow_async(self, shared_storage):
         current_node = self.start_node
         while current_node:
@@ -172,41 +157,29 @@ class AsyncFlow(BaseFlow):
 
             # If node is async-capable, call run_async; otherwise run sync
             if hasattr(current_node, "run_async") and callable(current_node.run_async):
-                condition = await current_node.run_async(shared_storage)
+                condition = await current_node._run_async(shared_storage)
             else:
-                condition = current_node.run(shared_storage)
+                condition = current_node._run(shared_storage)
             current_node = self.get_next_node(current_node, condition)
 
-    async def run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
         await self._process_flow_async(shared_storage)
         return self.postprocess(shared_storage, prep_result, None)
 
-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         try:
-            return asyncio.run(self.run_async(shared_storage))
+            return asyncio.run(self._run_async(shared_storage))
         except RuntimeError as e:
             raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e
 
 
 class BaseBatchFlow(BaseFlow):
-    """
-    Abstract base for a flow that runs multiple times (a batch),
-    once for each set of parameters or items from preprocess().
-    """
     def preprocess(self, shared_storage):
-        """
-        By default, returns an iterable of parameter-dicts or items
-        for the flow to process in a batch.
-        """
         return []
 
 class BatchFlow(BaseBatchFlow, Flow):
-    """
-    Synchronous batch flow: calls the flow repeatedly
-    for each set of parameters/items in preprocess().
-    """
-    def run(self, shared_storage=None):
+    def _run(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
 
         all_results = []
@@ -226,11 +199,7 @@ class BatchFlow(BaseBatchFlow, Flow):
         self.parameters = original_params
 
 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    """
-    Asynchronous batch flow: calls the flow repeatedly in an async manner
-    for each set of parameters/items in preprocess().
-    """
-    async def run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage=None):
         prep_result = self.preprocess(shared_storage)
 
         all_results = []