diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py
index f3a8146..dee6f93 100644
--- a/minillmflow/__init__.py
+++ b/minillmflow/__init__.py
@@ -1,5 +1,6 @@
 import asyncio
-
+import warnings
+
 class BaseNode:
     """
     A base node that provides:
@@ -62,10 +63,19 @@ class BaseNode:
         """
         return _ConditionalTransition(self, condition)
 
+    def __sub__(self, condition):
+        """
+        For chaining with the - operator, e.g. node - "some_condition" >> next_node
+        """
+        if isinstance(condition, str):
+            return _ConditionalTransition(self, condition)
+        raise TypeError("Condition must be a string")
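+
+    # Usage sketch for the new `-` chaining (node names are hypothetical;
+    # it mirrors the `>` style documented on _ConditionalTransition):
+    #     review - "approved" >> publish
+    # is shorthand for:
+    #     review.add_successor(publish, "approved")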
""" - async def _process(self, shared_storage, _): - current_node = self.start_node - while current_node: - if hasattr(current_node, "run_async") and callable(current_node.run_async): - # If it's an async node, await its run_async - condition = await current_node.run_async(shared_storage) - else: - # Otherwise, assume it's a sync node - condition = current_node.run(shared_storage) - - current_node = current_node.successors.get(condition, None) - - async def run_async(self, shared_storage=None): - """ - Kicks off the async flow. Similar to Flow.run, - but uses our async _process method. - """ - prep = self.preprocess(shared_storage) - # Note: flows typically don't need a meaningful process step - # because the "process" is the iteration through the nodes. - await self._process(shared_storage, prep) - return self.postprocess(shared_storage, prep, None) - -class BatchNode(BaseNode): - def __init__(self, max_retries=5, delay_s=0.1): - super().__init__() - self.max_retries = max_retries - self.delay_s = delay_s - - def preprocess(self, shared_storage): - return [] - - def process_one(self, shared_storage, item): - return None - - def process_one_after_fail(self, shared_storage, item, exc): - print(f"[FAIL_ITEM] item={item}, error={exc}") - # By default, just return a "fail" marker. Could be anything you want. - return "fail" - - async def _process_one(self, shared_storage, item): - for attempt in range(self.max_retries): - try: - return await self.process_one(shared_storage, item) - except Exception as e: - if attempt == self.max_retries - 1: - # If out of retries, let a subclass handle what to do next - return await self.process_one_after_fail(shared_storage, item, e) - await asyncio.sleep(self.delay_s) - - async def _process(self, shared_storage, items): - results = [] - for item in items: - r = await self._process_one(shared_storage, item) - results.append(r) - return results - -class BatchFlow(BaseNode): def __init__(self, start_node=None): super().__init__() self.start_node = start_node + def get_next_node(self, current_node, condition): + next_node = current_node.successors.get(condition, None) + + if next_node is None and current_node.successors: + warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}") + + return next_node + + def run(self, shared_storage=None): + """ + By default, do nothing (or raise). + Subclasses (Flow, AsyncFlow) will implement. + """ + raise NotImplementedError("BaseFlow.run must be implemented by subclasses") + + async def run_async(self, shared_storage=None): + """ + By default, do nothing (or raise). + Subclasses (Flow, AsyncFlow) will implement. + """ + raise NotImplementedError("BaseFlow.run_async must be implemented by subclasses") + +class Flow(BaseFlow): + """ + Synchronous flow: each node is called with .run(shared_storage). + """ + def _process_flow(self, shared_storage): + current_node = self.start_node + while current_node: + # Pass down the Flow's parameters to the current node + current_node.set_parameters(self.parameters) + # Synchronous run + condition = current_node.run(shared_storage) + # Decide next node + current_node = self.get_next_node(current_node, condition) + + def run(self, shared_storage=None): + prep_result = self.preprocess(shared_storage) + self._process_flow(shared_storage) + return self.postprocess(shared_storage, prep_result, None) + +class AsyncFlow(BaseFlow): + """ + Asynchronous flow: if a node has .run_async, we await it. 
+class AsyncFlow(BaseFlow):
+    """
+    Asynchronous flow: if a node has .run_async, we await it.
+    Otherwise, we fall back to .run.
+    """
+    async def _process_flow_async(self, shared_storage):
+        current_node = self.start_node
+        while current_node:
+            current_node.set_parameters(self.parameters)
+
+            # If the node is async-capable, await run_async; otherwise run it synchronously
+            if hasattr(current_node, "run_async") and callable(current_node.run_async):
+                condition = await current_node.run_async(shared_storage)
+            else:
+                condition = current_node.run(shared_storage)
+
+            current_node = self.get_next_node(current_node, condition)
+
+    async def run_async(self, shared_storage=None):
+        prep_result = self.preprocess(shared_storage)
+        await self._process_flow_async(shared_storage)
+        return self.postprocess(shared_storage, prep_result, None)
+
+    def run(self, shared_storage=None):
+        return asyncio.run(self.run_async(shared_storage))
+
+class BaseBatchFlow(BaseFlow):
+    """
+    Abstract base for a flow that runs multiple times (a batch),
+    once for each set of parameters or items from preprocess().
+    """
     def preprocess(self, shared_storage):
+        """
+        By default, returns an iterable of parameter dicts or items
+        for the flow to process as a batch.
+        """
         return []
 
-    async def _process_one(self, shared_storage, param_dict):
-        node_parameters = self.parameters.copy()
-        node_parameters.update(param_dict)
+    def post_batch_run(self, all_results):
+        """
+        Hook called after the entire batch is done, e.g. to combine results.
+        """
+        return all_results
 
-        if self.start_node:
-            current_node = self.start_node
-            while current_node:
-                # set the combined parameters
-                current_node.set_parameters(node_parameters)
-                current_node = await current_node._run_one(shared_storage or {})
+class BatchFlow(BaseBatchFlow, Flow):
+    """
+    Synchronous batch flow: calls the flow repeatedly,
+    once for each set of parameters/items from preprocess().
+    """
+    def run(self, shared_storage=None):
+        prep_result = self.preprocess(shared_storage)
+        all_results = []
 
-    async def _process(self, shared_storage, items):
-        results = []
-        for param_dict in items:
-            await self._process_one(shared_storage, param_dict)
-            results.append(f"Ran sub-flow for param_dict={param_dict}")
-        return results
\ No newline at end of file
+        # For each set of parameters (or items) we got from preprocess
+        for param_dict in prep_result:
+            # Merge param_dict into the Flow's parameters
+            original_params = self.parameters.copy()
+            self.parameters.update(param_dict)
+
+            # Run from the start node to the end
+            self._process_flow(shared_storage)
+
+            # Optionally collect results from shared_storage or a custom method
+            all_results.append(f"Finished run with parameters: {param_dict}")
+
+            # Restore the original parameters
+            self.parameters = original_params
+
+        # Postprocess the entire batch
+        result = self.post_batch_run(all_results)
+        return self.postprocess(shared_storage, prep_result, result)
+
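+# Batch sketch (hypothetical subclass and node; preprocess yields one
+# parameter dict per pass, which run() merges into self.parameters):
+#     class ResizeBatch(BatchFlow):
+#         def preprocess(self, shared_storage):
+#             return [{"width": w} for w in (128, 256, 512)]
+#     ResizeBatch(start_node=resize_node).run(shared_storage={})
+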
+ """ + async def run_async(self, shared_storage=None): + prep_result = self.preprocess(shared_storage) + all_results = [] + + for param_dict in prep_result: + original_params = self.parameters.copy() + self.parameters.update(param_dict) + + await self._process_flow_async(shared_storage) + + all_results.append(f"Finished async run with parameters: {param_dict}") + + # Reset back to original parameters if needed + self.parameters = original_params + + # Combine or process results at the end + result = self.post_batch_run(all_results) + return self.postprocess(shared_storage, prep_result, result) \ No newline at end of file diff --git a/minillmflow/__pycache__/__init__.cpython-39.pyc b/minillmflow/__pycache__/__init__.cpython-39.pyc index 912c60f..12b1cf9 100644 Binary files a/minillmflow/__pycache__/__init__.cpython-39.pyc and b/minillmflow/__pycache__/__init__.cpython-39.pyc differ