From 6206eb3650985b49762f7a71382d1b4ae27547a1 Mon Sep 17 00:00:00 2001
From: zachary62
Date: Wed, 25 Dec 2024 02:14:29 +0000
Subject: [PATCH] refactor

---
 minillmflow/__init__.py | 278 ++++++++++++++++++++--------------------
 1 file changed, 142 insertions(+), 136 deletions(-)

diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py
index 3ef8ada..9ca9283 100644
--- a/minillmflow/__init__.py
+++ b/minillmflow/__init__.py
@@ -1,55 +1,57 @@
 import asyncio
 
 # ---------------------------------------------------------------------
-# Base Classes
+# BaseNode (no duplication of run_one)
 # ---------------------------------------------------------------------
 class BaseNode:
     def __init__(self):
-        self.set_parameters({})
+        self.parameters = {}
         self.successors = {}
 
-    def set_parameters(self, parameters):
-        self.parameters = parameters.copy() if parameters else {}
-
     def add_successor(self, node, condition="default"):
         self.successors[condition] = node
         return node
-
+
     async def preprocess(self, shared_storage):
-        return None
-
-    async def process_one(self, shared_storage, item):
         """
-        The main single-item processing method that end developers override.
-        Default does nothing.
+        Override if needed to load or prepare data.
         """
         return None
 
-    async def robust_process_one(self, shared_storage, item):
+    async def process(self, shared_storage, data):
         """
-        In BaseNode, this is just a pass-through to `process_one`.
-        Subclasses (like Node with retry) can override this to add extra logic.
+        Public method for user logic. Default does nothing.
         """
-        return await self.process_one(shared_storage, item)
+        return None
 
-    async def process(self, shared_storage, preprocess_result):
+    async def _process(self, shared_storage, data):
         """
-        Calls `robust_process_one` instead of `process_one` so that
-        any subclass overrides of robust_process_one will apply.
+        Internal hook that calls `process(...)`.
+        Subclasses override this to add extra logic (e.g. retries).
         """
-        return await self.robust_process_one(shared_storage, preprocess_result)
-
+        return await self.process(shared_storage, data)
+
     async def postprocess(self, shared_storage, preprocess_result, process_result):
+        """
+        By default, returns "default" to pick the default successor.
+ """ return "default" async def run_one(self, shared_storage): + """ + One cycle of the node: + 1) preprocess + 2) _process + 3) postprocess + 4) pick successor + """ preprocess_result = await self.preprocess(shared_storage) - process_result = await self.process(shared_storage, preprocess_result) + process_result = await self._process(shared_storage, preprocess_result) condition = await self.postprocess(shared_storage, preprocess_result, process_result) if not self.successors: return None - elif len(self.successors) == 1: + if len(self.successors) == 1: return next(iter(self.successors.values())) return self.successors.get(condition) @@ -62,10 +64,14 @@ class BaseNode: while current_node: current_node = await current_node.run_one(shared_storage) + # Syntactic sugar for chaining def __rshift__(self, other): return self.add_successor(other) def __gt__(self, other): + """ + For branching: node > "condition" > another_node + """ if isinstance(other, str): return _ConditionalTransition(self, other) elif isinstance(other, BaseNode): @@ -87,139 +93,139 @@ class _ConditionalTransition: return self.source_node.add_successor(target_node, self.condition) -# --------------------------------------------------------------------- -# Flow: allows you to define a "start_node" that is run in a sub-flow -# --------------------------------------------------------------------- -class Flow(BaseNode): - def __init__(self, start_node=None): - super().__init__() - self.start_node = start_node - - async def process_one(self, shared_storage, item): - # Instead of doing a single operation, we run a sub-flow - if self.start_node: - current_node = self.start_node - while current_node: - # Pass down the parameters - current_node.set_parameters(self.parameters) - current_node = await current_node.run_one(shared_storage or {}) - - -# --------------------------------------------------------------------- -# Node: adds robust retry logic on top of BaseNode -# --------------------------------------------------------------------- class Node(BaseNode): """ - Retries its single-item operation up to `max_retries` times, - waiting `delay_s` seconds between attempts. - By default: max_retries=5, delay_s=0.1 - End developers simply override `process_one` to define logic. + Single-item node with robust logic. + End devs override `process(...)`. + `_process(...)` adds the retry logic. + """ + def __init__(self, max_retries=5, delay_s=0.1): + super().__init__() + self.max_retries = max_retries + self.delay_s = delay_s + + async def fail_item(self, shared_storage, data, exc): + """ + Called if we exhaust all retries. + """ + print(f"[FAIL_ITEM] data={data}, error={exc}") + return "fail" + + async def _process(self, shared_storage, data): + """ + Wraps the user’s `process(...)` with retry logic. + """ + for attempt in range(self.max_retries): + try: + return await super()._process(shared_storage, data) + except Exception as e: + if attempt == self.max_retries - 1: + return await self.fail_item(shared_storage, data, e) + await asyncio.sleep(self.delay_s) + +# --------------------------------------------------------------------- +# BatchNode: processes multiple items +# --------------------------------------------------------------------- +class BatchNode(BaseNode): + """ + Processes a list of items in `process(...)`. + The user overrides `process_one(item)`. + `_process_one(...)` handles robust retries for each item. 
""" def __init__(self, max_retries=5, delay_s=0.1): super().__init__() - self.parameters.setdefault("max_retries", max_retries) - self.parameters.setdefault("delay_s", delay_s) - - async def fail_one(self, shared_storage, item, exc): - """ - Called if the final retry also fails. By default, - just returns a special string or could log an error. - End developers can override this to do something else - (e.g., store the failure in a separate list or - trigger alternative logic). - """ - # Example: log and return a special status - print(f"[FAIL_ONE] item={item}, error={exc}") - return "fail" - - async def robust_process_one(self, shared_storage, item): - max_retries = self.parameters.get("max_retries", 5) - delay_s = self.parameters.get("delay_s", 0.1) - - for attempt in range(max_retries): - try: - # Defer to the user's process_one logic - return await super().robust_process_one(shared_storage, item) - except Exception as e: - if attempt == max_retries - 1: - # Final attempt failed; call fail_one - return await self.fail_one(shared_storage, item, e) - # Otherwise, wait a bit and try again - await asyncio.sleep(delay_s) - -# --------------------------------------------------------------------- -# BatchMixin: processes a collection of items by calling robust_process_one for each -# --------------------------------------------------------------------- -class BatchMixin: - async def process(self, shared_storage, items): - """ - Processes a *collection* of items in a loop, calling robust_process_one per item. - """ - partial_results = [] - for item in items: - r = await self.robust_process_one(shared_storage, item) - partial_results.append(r) - return self.merge(shared_storage, partial_results) - - def merge(self, shared_storage, partial_results): - """ - Combines partial results into a single output. - By default, returns the list of partial results. - """ - return partial_results + self.max_retries = max_retries + self.delay_s = delay_s async def preprocess(self, shared_storage): """ - Typically, you'd return a list or collection of items to process here. - By default, returns an empty list. + Typically return a list of items to process. """ return [] - -# --------------------------------------------------------------------- -# BatchNode: combines Node (robust logic) + BatchMixin (batch logic) -# --------------------------------------------------------------------- -class BatchNode(BatchMixin, Node): - """ - A batch-processing node that: - - Inherits robust retry logic from Node - - Uses BatchMixin to process a list of items - """ - - async def preprocess(self, shared_storage): - # Gather or return the batch items. By default, no items. - return [] - async def process_one(self, shared_storage, item): """ - The per-item logic that the end developer will override. - By default, does nothing. + End developers override single-item logic here. """ return None - -# --------------------------------------------------------------------- -# BatchFlow: combines Flow (sub-flow logic) + batch processing + robust logic -# --------------------------------------------------------------------- -class BatchFlow(BatchMixin, Flow): - """ - This class runs a sub-flow (start_node) for each item in a batch. - If you also want robust retries, you can adapt or combine with `Node`. 
- """ - - async def preprocess(self, shared_storage): - # Return your batch items here - return [] - - async def process(self, shared_storage, items): + async def _process_one(self, shared_storage, item): """ - For each item, run the sub-flow (start_node). + Retry logic around process_one(item). + """ + for attempt in range(self.max_retries): + try: + return await self.process_one(shared_storage, item) + except Exception as e: + if attempt == self.max_retries - 1: + print(f"[FAIL_ITEM] item={item}, error={e}") + return "fail" + await asyncio.sleep(self.delay_s) + + async def _process(self, shared_storage, items): + """ + Loops over items, calling _process_one per item. """ results = [] for item in items: - # Here we re-run the sub-flow, which happens inside process_one of Flow - await self.process_one(shared_storage, item) - # Optionally collect results or do something after the sub-flow - results.append(f"Finished sub-flow for item: {item}") + r = await self._process_one(shared_storage, item) + results.append(r) + return results + + +class Flow(BaseNode): + """ + Runs a sub-flow from `start_node` once per call. + """ + def __init__(self, start_node=None): + super().__init__() + self.start_node = start_node + + async def _process(self, shared_storage, _): + if self.start_node: + current_node = self.start_node + while current_node: + current_node = await current_node.run_one(shared_storage or {}) + return "Flow done" + +class BatchFlow(BaseNode): + """ + For each param_dict in the batch, merges it into self.parameters, + then runs the sub-flow from `start_node`. + """ + + def __init__(self, start_node=None): + super().__init__() + self.start_node = start_node + + async def preprocess(self, shared_storage): + """ + Return a list of param_dict objects. + """ + return [] + + async def process_one(self, shared_storage, param_dict): + """ + Merge param_dict into the node's parameters, + then run the sub-flow. + """ + node_parameters = self.parameters.copy() + node_parameters.update(param_dict) + + if self.start_node: + current_node = self.start_node + while current_node: + # set the combined parameters + current_node.set_parameters(node_parameters) + current_node = await current_node.run_one(shared_storage or {}) + + async def _process(self, shared_storage, items): + """ + For each param_dict in items, run the sub-flow once. + """ + results = [] + for param_dict in items: + await self.process_one(shared_storage, param_dict) + results.append(f"Ran sub-flow for param_dict={param_dict}") return results \ No newline at end of file