adapt for jupyter notebook

This commit is contained in:
zachary62 2024-12-25 06:02:59 +00:00
parent c8cae054bc
commit 4f60d32e68
2 changed files with 34 additions and 30 deletions

View File

@ -33,12 +33,26 @@ class NodeMeta(type):
attrs[attr_name] = _wrap_async(old_fn) attrs[attr_name] = _wrap_async(old_fn)
return super().__new__(mcs, name, bases, attrs) return super().__new__(mcs, name, bases, attrs)
async def hello(text):
print("Start")
await asyncio.sleep(1) # Simulate some async work
print(text)
class BaseNode(metaclass=NodeMeta): class BaseNode(metaclass=NodeMeta):
def __init__(self): def __init__(self):
self.parameters = {} self.parameters = {}
self.successors = {} self.successors = {}
# Syntactic sugar for chaining
def add_successor(self, node, condition="default"):
# warn if we're overwriting an existing successor
if condition in self.successors:
print(f"Warning: overwriting existing successor for condition '{condition}'")
self.successors[condition] = node
return node
# By default these are already async. If a subclass overrides them # By default these are already async. If a subclass overrides them
# with non-async definitions, they'll get wrapped automatically. # with non-async definitions, they'll get wrapped automatically.
def preprocess(self, shared_storage): def preprocess(self, shared_storage):
@ -56,28 +70,20 @@ class BaseNode(metaclass=NodeMeta):
async def _run_one(self, shared_storage): async def _run_one(self, shared_storage):
preprocess_result = await self.preprocess(shared_storage) preprocess_result = await self.preprocess(shared_storage)
process_result = await self._process(shared_storage, preprocess_result) process_result = await self._process(shared_storage, preprocess_result)
condition = await self.postprocess(shared_storage, preprocess_result, process_result) condition = await self.postprocess(shared_storage, preprocess_result, process_result) or "default"
if not self.successors:
return None
if len(self.successors) == 1:
return next(iter(self.successors.values()))
return self.successors.get(condition) return self.successors.get(condition)
def run(self, shared_storage=None): def run(self, shared_storage=None):
return asyncio.run(self.run_async(shared_storage)) asyncio.run(self._run_async(shared_storage))
async def run_async(self, shared_storage=None): async def run_in_jupyter(self, shared_storage=None):
shared_storage = shared_storage or {} await self._run_async(shared_storage)
async def _run_async(self, shared_storage):
current_node = self current_node = self
while current_node: while current_node:
current_node = await current_node._run_one(shared_storage) current_node = await current_node._run_one(shared_storage)
# Syntactic sugar for chaining
def add_successor(self, node, condition="default"):
self.successors[condition] = node
return node
def __rshift__(self, other): def __rshift__(self, other):
return self.add_successor(other) return self.add_successor(other)
@ -91,7 +97,6 @@ class BaseNode(metaclass=NodeMeta):
def __call__(self, condition): def __call__(self, condition):
return _ConditionalTransition(self, condition) return _ConditionalTransition(self, condition)
class _ConditionalTransition: class _ConditionalTransition:
def __init__(self, source_node, condition): def __init__(self, source_node, condition):
self.source_node = source_node self.source_node = source_node
@ -102,7 +107,6 @@ class _ConditionalTransition:
raise TypeError("Target must be a BaseNode") raise TypeError("Target must be a BaseNode")
return self.source_node.add_successor(target_node, self.condition) return self.source_node.add_successor(target_node, self.condition)
class Node(BaseNode): class Node(BaseNode):
def __init__(self, max_retries=5, delay_s=0.1): def __init__(self, max_retries=5, delay_s=0.1):
super().__init__() super().__init__()
@ -110,8 +114,8 @@ class Node(BaseNode):
self.delay_s = delay_s self.delay_s = delay_s
def process_after_fail(self, shared_storage, data, exc): def process_after_fail(self, shared_storage, data, exc):
print(f"[FAIL_ITEM] data={data}, error={exc}") raise exc
return "fail" # return "fail"
async def _process(self, shared_storage, data): async def _process(self, shared_storage, data):
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
@ -121,6 +125,18 @@ class Node(BaseNode):
if attempt == self.max_retries - 1: if attempt == self.max_retries - 1:
return await self.process_after_fail(shared_storage, data, e) return await self.process_after_fail(shared_storage, data, e)
await asyncio.sleep(self.delay_s) await asyncio.sleep(self.delay_s)
class Flow(BaseNode):
def __init__(self, start_node=None):
super().__init__()
self.start_node = start_node
async def _process(self, shared_storage, _):
if self.start_node:
current_node = self.start_node
while current_node:
current_node = await current_node._run_one(shared_storage or {})
return "Flow done"
class BatchNode(BaseNode): class BatchNode(BaseNode):
def __init__(self, max_retries=5, delay_s=0.1): def __init__(self, max_retries=5, delay_s=0.1):
@ -156,18 +172,6 @@ class BatchNode(BaseNode):
results.append(r) results.append(r)
return results return results
class Flow(BaseNode):
def __init__(self, start_node=None):
super().__init__()
self.start_node = start_node
async def _process(self, shared_storage, _):
if self.start_node:
current_node = self.start_node
while current_node:
current_node = await current_node._run_one(shared_storage or {})
return "Flow done"
class BatchFlow(BaseNode): class BatchFlow(BaseNode):
def __init__(self, start_node=None): def __init__(self, start_node=None):
super().__init__() super().__init__()