From f802251246a19ad8a484ef6f009840ae87198d63 Mon Sep 17 00:00:00 2001 From: zachary62 Date: Wed, 25 Dec 2024 23:02:24 +0000 Subject: [PATCH] batch node --- minillmflow/__init__.py | 94 +++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 60 deletions(-) diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index dee6f93..6c36be7 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -2,38 +2,33 @@ import asyncio import warnings class BaseNode: - """ - A base node that provides: - - preprocess() - - process() - - postprocess() - - run() -- just runs itself (no chaining) - """ + # preprocess(): this is for compute intensive preparation tasks, before the LLM call + # process(): this is for the LLM call, and should be idempotent for retries + # postprocess(): this is to summarize the result and return the condition for the successor node def __init__(self): - self.parameters = {} - self.successors = {} + self.parameters, self.successors = {}, {} - def set_parameters(self, params): - self.parameters.update(params) + def set_parameters(self, params): # make sure params is immutable + self.parameters = params # must be immutable during pre/post/process def add_successor(self, node, condition="default"): if condition in self.successors: - print(f"Warning: overwriting existing successor for condition '{condition}'") - self.successors[condition] = node + warnings.warn(f"Overwriting existing successor for condition '{condition}'") + self.successors[condition] = node # maps condition -> successor node return node def preprocess(self, shared_storage): - return None + return None # will be passed to process() and postprocess() def process(self, shared_storage, prep_result): - return None + return None # will be passed to postprocess() def _process(self, shared_storage, prep_result): # Could have retry logic or other wrap logic return self.process(shared_storage, prep_result) def postprocess(self, shared_storage, prep_result, 
proc_result): - return "default" + return "default" # condition for next node def run(self, shared_storage=None): prep = self.preprocess(shared_storage) @@ -41,42 +36,16 @@ class BaseNode: return self.postprocess(shared_storage, prep, proc) def __rshift__(self, other): - """ - For chaining with >> operator, e.g. node1 >> node2 - """ + # chaining: node1 >> node2 return self.add_successor(other) - def __gt__(self, other): - """ - For chaining with > operator, e.g. node1 > "some_condition" - then >> node2 - """ - if isinstance(other, str): - return _ConditionalTransition(self, other) - elif isinstance(other, BaseNode): - return self.add_successor(other) - raise TypeError("Unsupported operand type") - - def __call__(self, condition): - """ - For node("condition") >> next_node syntax - """ - return _ConditionalTransition(self, condition) - def __sub__(self, condition): - """ - For chaining with - operator, e.g. node - "some_condition" >> next_node - """ + # condition-based chaining: node - "some_condition" >> next_node if isinstance(condition, str): return _ConditionalTransition(self, condition) raise TypeError("Condition must be a string") - class _ConditionalTransition: - """ - Helper for Node > 'condition' >> AnotherNode style - (and also Node - 'condition' >> AnotherNode now). - """ def __init__(self, source_node, condition): self.source_node = source_node self.condition = condition @@ -100,6 +69,22 @@ class Node(BaseNode): if attempt == self.max_retries - 1: return self.process_after_fail(shared_storage, data, e) +class BatchNode(Node): + def preprocess(self, shared_storage): + # return an iterable of items, one for each run + return [] + + def process(self, shared_storage, item): # process() is called for each item + return None + + def _process(self, shared_storage, items): + results = [] + for item in items: + # Here, 'item' is passed in place of 'prep_result' from the BaseNode's perspective. 
+ r = super()._process(shared_storage, item) + results.append(r) + return results + class AsyncNode(Node): """ A Node whose postprocess step is async. @@ -199,7 +184,10 @@ class AsyncFlow(BaseFlow): return self.postprocess(shared_storage, prep_result, None) def run(self, shared_storage=None): - return asyncio.run(self.run_async(shared_storage)) + try: + return asyncio.run(self.run_async(shared_storage)) + except RuntimeError as e: + raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e class BaseBatchFlow(BaseFlow): """ @@ -213,12 +201,6 @@ class BaseBatchFlow(BaseFlow): """ return [] - def post_batch_run(self, all_results): - """ - Hook for after the entire batch is done, to combine results, etc. - """ - return all_results - class BatchFlow(BaseBatchFlow, Flow): """ Synchronous batch flow: calls the flow repeatedly @@ -243,10 +225,6 @@ class BatchFlow(BaseBatchFlow, Flow): # Reset the parameters if needed self.parameters = original_params - # Postprocess the entire batch - result = self.post_batch_run(all_results) - return self.postprocess(shared_storage, prep_result, result) - class BatchAsyncFlow(BaseBatchFlow, AsyncFlow): """ Asynchronous batch flow: calls the flow repeatedly in an async manner @@ -265,8 +243,4 @@ class BatchAsyncFlow(BaseBatchFlow, AsyncFlow): all_results.append(f"Finished async run with parameters: {param_dict}") # Reset back to original parameters if needed - self.parameters = original_params - - # Combine or process results at the end - result = self.post_batch_run(all_results) - return self.postprocess(shared_storage, prep_result, result) \ No newline at end of file + self.parameters = original_params \ No newline at end of file