first lib
This commit is contained in:
parent
0b8680bf6a
commit
7f107f60e6
|
|
@ -0,0 +1,117 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
class BaseNode:
|
||||||
|
def __init__(self):
|
||||||
|
self.set_parameters({}) # immutable during processing; could be overwritten as node can be reused
|
||||||
|
self.successors = {}
|
||||||
|
|
||||||
|
def set_parameters(self, parameters):
|
||||||
|
self.parameters = parameters.copy() if parameters else {}
|
||||||
|
|
||||||
|
def add_successor(self, node, condition="default"):
|
||||||
|
self.successors[condition] = node
|
||||||
|
return node
|
||||||
|
|
||||||
|
async def preprocess(self, shared_storage):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def process_one(self, shared_storage, item):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def process(self, shared_storage, preprocess_result):
|
||||||
|
return await self.process_one(shared_storage, preprocess_result)
|
||||||
|
|
||||||
|
async def postprocess(self, shared_storage, preprocess_result, process_result):
|
||||||
|
return "default"
|
||||||
|
|
||||||
|
async def run_one(self, shared_storage):
|
||||||
|
preprocess_result = await self.preprocess(shared_storage)
|
||||||
|
process_result = await self.process(shared_storage, preprocess_result)
|
||||||
|
condition = await self.postprocess(shared_storage, preprocess_result, process_result)
|
||||||
|
|
||||||
|
if not self.successors:
|
||||||
|
return None
|
||||||
|
elif len(self.successors) == 1:
|
||||||
|
return next(iter(self.successors.values()))
|
||||||
|
return self.successors.get(condition)
|
||||||
|
|
||||||
|
def run(self, shared_storage=None):
|
||||||
|
return asyncio.run(self.run_async(shared_storage))
|
||||||
|
|
||||||
|
async def run_async(self, shared_storage=None):
|
||||||
|
shared_storage = shared_storage or {}
|
||||||
|
current_node = self
|
||||||
|
while current_node:
|
||||||
|
current_node = await current_node.run_one(shared_storage)
|
||||||
|
|
||||||
|
def __rshift__(self, other):
|
||||||
|
return self.add_successor(other)
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
if isinstance(other, str):
|
||||||
|
return _ConditionalTransition(self, other)
|
||||||
|
elif isinstance(other, BaseNode):
|
||||||
|
return self.add_successor(other)
|
||||||
|
raise TypeError("Unsupported operand type")
|
||||||
|
|
||||||
|
def __call__(self, condition):
|
||||||
|
return _ConditionalTransition(self, condition)
|
||||||
|
|
||||||
|
class _ConditionalTransition:
|
||||||
|
def __init__(self, source_node, condition):
|
||||||
|
self.source_node = source_node
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def __gt__(self, target_node):
|
||||||
|
if not isinstance(target_node, BaseNode):
|
||||||
|
raise TypeError("Target must be a BaseNode")
|
||||||
|
return self.source_node.add_successor(target_node, self.condition)
|
||||||
|
|
||||||
|
class BaseSuperNode(BaseNode):
|
||||||
|
def __init__(self, start_node=None):
|
||||||
|
super().__init__()
|
||||||
|
self.start_node = start_node
|
||||||
|
|
||||||
|
async def process_one(self, shared_storage, item):
|
||||||
|
if self.start_node:
|
||||||
|
current_node = self.start_node
|
||||||
|
while current_node:
|
||||||
|
current_node.set_parameters(self.parameters)
|
||||||
|
current_node = await current_node.run_one(shared_storage or {})
|
||||||
|
|
||||||
|
class BatchMixin:
|
||||||
|
async def process(self, shared_storage, items):
|
||||||
|
partial_results = []
|
||||||
|
for item in items:
|
||||||
|
r = await self.process_one(shared_storage, item)
|
||||||
|
partial_results.append(r)
|
||||||
|
|
||||||
|
return self.merge(shared_storage, partial_results)
|
||||||
|
|
||||||
|
def merge(self, shared_storage, partial_results):
|
||||||
|
return partial_results
|
||||||
|
|
||||||
|
async def preprocess(self, shared_storage):
|
||||||
|
return []
|
||||||
|
|
||||||
|
class BatchBaseNode(BatchMixin, BaseNode):
|
||||||
|
async def preprocess(self, shared_storage):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def process_one(self, shared_storage, item):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class BatchSuperNode(BatchMixin, BaseSuperNode):
|
||||||
|
async def preprocess(self, shared_storage):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def process_one(self, shared_storage, param_dict):
|
||||||
|
node_parameters = self.parameters.copy()
|
||||||
|
node_parameters.update(param_dict)
|
||||||
|
|
||||||
|
if self.start_node:
|
||||||
|
current_node = self.start_node
|
||||||
|
while current_node:
|
||||||
|
current_node.set_parameters(node_parameters)
|
||||||
|
current_node = await current_node.run_one(shared_storage or {})
|
||||||
|
|
||||||
Binary file not shown.
Loading…
Reference in New Issue