diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index af0286a..2a7ea3d 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -33,12 +33,26 @@ class NodeMeta(type): attrs[attr_name] = _wrap_async(old_fn) return super().__new__(mcs, name, bases, attrs) + + +async def hello(text): + print("Start") + await asyncio.sleep(1) # Simulate some async work + print(text) class BaseNode(metaclass=NodeMeta): def __init__(self): self.parameters = {} self.successors = {} + # Syntactic sugar for chaining + def add_successor(self, node, condition="default"): + # warn if we're overwriting an existing successor + if condition in self.successors: + print(f"Warning: overwriting existing successor for condition '{condition}'") + self.successors[condition] = node + return node + # By default these are already async. If a subclass overrides them # with non-async definitions, they'll get wrapped automatically. def preprocess(self, shared_storage): @@ -56,28 +70,20 @@ class BaseNode(metaclass=NodeMeta): async def _run_one(self, shared_storage): preprocess_result = await self.preprocess(shared_storage) process_result = await self._process(shared_storage, preprocess_result) - condition = await self.postprocess(shared_storage, preprocess_result, process_result) - - if not self.successors: - return None - if len(self.successors) == 1: - return next(iter(self.successors.values())) + condition = await self.postprocess(shared_storage, preprocess_result, process_result) or "default" return self.successors.get(condition) def run(self, shared_storage=None): - return asyncio.run(self.run_async(shared_storage)) + asyncio.run(self._run_async(shared_storage)) - async def run_async(self, shared_storage=None): - shared_storage = shared_storage or {} + async def run_in_jupyter(self, shared_storage=None): + await self._run_async(shared_storage) + + async def _run_async(self, shared_storage): current_node = self while current_node: current_node = await current_node._run_one(shared_storage) - # Syntactic sugar for chaining - def add_successor(self, node, condition="default"): - self.successors[condition] = node - return node - def __rshift__(self, other): return self.add_successor(other) @@ -91,7 +97,6 @@ class BaseNode(metaclass=NodeMeta): def __call__(self, condition): return _ConditionalTransition(self, condition) - class _ConditionalTransition: def __init__(self, source_node, condition): self.source_node = source_node @@ -102,7 +107,6 @@ class _ConditionalTransition: raise TypeError("Target must be a BaseNode") return self.source_node.add_successor(target_node, self.condition) - class Node(BaseNode): def __init__(self, max_retries=5, delay_s=0.1): super().__init__() @@ -110,8 +114,8 @@ class Node(BaseNode): self.delay_s = delay_s def process_after_fail(self, shared_storage, data, exc): - print(f"[FAIL_ITEM] data={data}, error={exc}") - return "fail" + raise exc + # return "fail" async def _process(self, shared_storage, data): for attempt in range(self.max_retries): @@ -121,6 +125,18 @@ class Node(BaseNode): if attempt == self.max_retries - 1: return await self.process_after_fail(shared_storage, data, e) await asyncio.sleep(self.delay_s) + +class Flow(BaseNode): + def __init__(self, start_node=None): + super().__init__() + self.start_node = start_node + + async def _process(self, shared_storage, _): + if self.start_node: + current_node = self.start_node + while current_node: + current_node = await current_node._run_one(shared_storage or {}) + return "Flow done" class BatchNode(BaseNode): def __init__(self, max_retries=5, delay_s=0.1): @@ -156,18 +172,6 @@ class BatchNode(BaseNode): results.append(r) return results -class Flow(BaseNode): - def __init__(self, start_node=None): - super().__init__() - self.start_node = start_node - - async def _process(self, shared_storage, _): - if self.start_node: - current_node = self.start_node - while current_node: - current_node = await current_node._run_one(shared_storage or {}) - return "Flow done" - class BatchFlow(BaseNode): def __init__(self, start_node=None): super().__init__() diff --git a/minillmflow/__pycache__/__init__.cpython-39.pyc b/minillmflow/__pycache__/__init__.cpython-39.pyc index 549c726..57d50b0 100644 Binary files a/minillmflow/__pycache__/__init__.cpython-39.pyc and b/minillmflow/__pycache__/__init__.cpython-39.pyc differ