adapt for jupyter notebook
This commit is contained in:
parent
c8cae054bc
commit
4f60d32e68
@@ -33,12 +33,26 @@ class NodeMeta(type):
                 attrs[attr_name] = _wrap_async(old_fn)
 
         return super().__new__(mcs, name, bases, attrs)
 
+async def hello(text):
+    print("Start")
+    await asyncio.sleep(1)  # Simulate some async work
+    print(text)
+
+
 class BaseNode(metaclass=NodeMeta):
     def __init__(self):
         self.parameters = {}
         self.successors = {}
 
+    # Syntactic sugar for chaining
+    def add_successor(self, node, condition="default"):
+        # warn if we're overwriting an existing successor
+        if condition in self.successors:
+            print(f"Warning: overwriting existing successor for condition '{condition}'")
+        self.successors[condition] = node
+        return node
+
     # By default these are already async. If a subclass overrides them
     # with non-async definitions, they'll get wrapped automatically.
     def preprocess(self, shared_storage):
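Because NodeMeta wraps non-async overrides, a subclass can define hooks like `preprocess` with a plain `def` and the framework still awaits them. A minimal sketch of what the new `add_successor` gives, assuming the classes in this file are in scope (`LoadData` is illustrative, not part of the commit):

    class LoadData(Node):
        def preprocess(self, shared_storage):  # plain def; NodeMeta wraps it to async
            shared_storage["data"] = [1, 2, 3]

    a, b = LoadData(), LoadData()
    a.add_successor(b)             # returns b, so successor calls can chain
    a.add_successor(b, "default")  # same key again: prints the new overwrite warning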
@@ -56,28 +70,20 @@ class BaseNode(metaclass=NodeMeta):
     async def _run_one(self, shared_storage):
         preprocess_result = await self.preprocess(shared_storage)
         process_result = await self._process(shared_storage, preprocess_result)
-        condition = await self.postprocess(shared_storage, preprocess_result, process_result)
+        condition = await self.postprocess(shared_storage, preprocess_result, process_result) or "default"
 
-        if not self.successors:
-            return None
-        if len(self.successors) == 1:
-            return next(iter(self.successors.values()))
         return self.successors.get(condition)
 
     def run(self, shared_storage=None):
-        return asyncio.run(self.run_async(shared_storage))
+        asyncio.run(self._run_async(shared_storage))
 
-    async def run_async(self, shared_storage=None):
-        shared_storage = shared_storage or {}
+    async def run_in_jupyter(self, shared_storage=None):
+        await self._run_async(shared_storage)
 
+    async def _run_async(self, shared_storage):
         current_node = self
         while current_node:
             current_node = await current_node._run_one(shared_storage)
 
-    # Syntactic sugar for chaining
-    def add_successor(self, node, condition="default"):
-        self.successors[condition] = node
-        return node
-
     def __rshift__(self, other):
         return self.add_successor(other)
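The `run` / `run_in_jupyter` split is the point of this commit: `asyncio.run()` raises `RuntimeError: asyncio.run() cannot be called from a running event loop` inside a notebook, since the IPython kernel already drives a loop, so the notebook entry point is a plain coroutine to be awaited at the top level. Two smaller behavior changes ride along: `_run_one` now coerces a `None` from `postprocess` to `"default"` and routes strictly by key (the single-successor fallback is gone), and the `shared_storage or {}` defaulting was dropped, so callers should pass a dict explicitly. A usage sketch, with `node` standing in for any concrete node built on this module:

    node = Node()

    # Plain script: `run` starts and owns a fresh event loop.
    node.run({"input": 42})

    # Jupyter cell: top-level await reuses the kernel's running loop.
    await node.run_in_jupyter({"input": 42})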
@@ -91,7 +97,6 @@ class BaseNode(metaclass=NodeMeta):
     def __call__(self, condition):
         return _ConditionalTransition(self, condition)
 
-
 class _ConditionalTransition:
     def __init__(self, source_node, condition):
         self.source_node = source_node
@@ -102,7 +107,6 @@ class _ConditionalTransition:
             raise TypeError("Target must be a BaseNode")
         return self.source_node.add_successor(target_node, self.condition)
 
-
 class Node(BaseNode):
     def __init__(self, max_retries=5, delay_s=0.1):
         super().__init__()
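Together, `BaseNode.__call__` and `_ConditionalTransition.__rshift__` give the conditional arrow syntax: `node("cond") >> other` registers `other` under `"cond"`, while a plain `node >> other` registers it under `"default"`. Illustrative wiring (the node instances are stand-ins):

    check, success, failure = Node(), Node(), Node()

    check("ok") >> success     # followed when postprocess returns "ok"
    check("error") >> failure  # followed when postprocess returns "error"
    # check("ok") >> 42        # would raise TypeError: Target must be a BaseNode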
@@ -110,8 +114,8 @@ class Node(BaseNode):
         self.delay_s = delay_s
 
     def process_after_fail(self, shared_storage, data, exc):
-        print(f"[FAIL_ITEM] data={data}, error={exc}")
-        return "fail"
+        raise exc
+        # return "fail"
 
     async def _process(self, shared_storage, data):
         for attempt in range(self.max_retries):
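`process_after_fail` now re-raises after the last attempt instead of printing and returning `"fail"`, so exhausted retries propagate the exception unless a subclass opts back into routing. A sketch of opting back in; note that the `process` override point is inferred from the wrapped-method convention and is not visible in this hunk:

    import random

    class FlakyNode(Node):
        def process(self, shared_storage, data):  # assumed hook name
            if random.random() < 0.5:
                raise RuntimeError("transient failure")
            return "done"

        def process_after_fail(self, shared_storage, data, exc):
            print(f"giving up after {self.max_retries} attempts: {exc}")
            return "fail"  # route to a "fail" successor instead of raising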
@@ -121,6 +125,18 @@ class Node(BaseNode):
                 if attempt == self.max_retries - 1:
                     return await self.process_after_fail(shared_storage, data, e)
                 await asyncio.sleep(self.delay_s)
 
+class Flow(BaseNode):
+    def __init__(self, start_node=None):
+        super().__init__()
+        self.start_node = start_node
+
+    async def _process(self, shared_storage, _):
+        if self.start_node:
+            current_node = self.start_node
+            while current_node:
+                current_node = await current_node._run_one(shared_storage or {})
+        return "Flow done"
+
 class BatchNode(BaseNode):
     def __init__(self, max_retries=5, delay_s=0.1):
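`Flow` moves up beside `Node` and is itself a `BaseNode` whose `_process` drives a sub-chain, so pipelines nest: an entire flow can sit behind one edge of a larger graph. Sketch:

    step1, step2, cleanup = Node(), Node(), Node()
    step1 >> step2

    inner = Flow(start_node=step1)  # processing `inner` runs step1 -> step2
    inner >> cleanup                # a Flow chains like any other node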
@@ -156,18 +172,6 @@ class BatchNode(BaseNode):
                 results.append(r)
         return results
 
-class Flow(BaseNode):
-    def __init__(self, start_node=None):
-        super().__init__()
-        self.start_node = start_node
-
-    async def _process(self, shared_storage, _):
-        if self.start_node:
-            current_node = self.start_node
-            while current_node:
-                current_node = await current_node._run_one(shared_storage or {})
-        return "Flow done"
-
 class BatchFlow(BaseNode):
     def __init__(self, start_node=None):
         super().__init__()
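End to end, the notebook-facing surface added here looks like this (a sketch: `Greet` is illustrative, `preprocess` is used because it is the hook visible above, and the base classes are assumed to provide no-op defaults for the other hooks):

    class Greet(Node):
        def preprocess(self, shared_storage):  # plain def, auto-wrapped by NodeMeta
            print("hello from a node")

    flow = Flow(start_node=Greet())
    await flow.run_in_jupyter({})  # in Jupyter: await on the kernel's loop
    # flow.run({})                 # in a script: asyncio.run() under the hood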