diff --git a/docs/viz.md b/docs/viz.md new file mode 100644 index 0000000..d3235bf --- /dev/null +++ b/docs/viz.md @@ -0,0 +1,129 @@ +--- +layout: default +title: "Visualization" +parent: "Details" +nav_order: 3 +--- + +# Visualization + +Visualizing the flow and node structure can greatly help understanding. While we do **not** include built-in visualization tools, we provide an example of how to create one using **Mermaid**. + +### Example: Visualization of Node with Mermaid + +```python +# Generate Mermaid diagram code for a Flow or Node structure +def build_mermaid(start): + node_ids, visited = {}, set() + lines = ["graph LR"] + counter = [1] # Counter for unique node IDs + + def get_id(node): + if node not in node_ids: + node_ids[node] = f"N{counter[0]}" + counter[0] += 1 + return node_ids[node] + + def connect(src, tgt): + lines.append(f" {src} --> {tgt}") + + def walk(node, parent_id=None): + if node in visited: + if parent_id: + connect(parent_id, get_id(node)) + return + visited.add(node) + + if isinstance(node, Flow): + # Handle Flow nodes + if parent_id and node.start: + connect(parent_id, get_id(node.start)) + + lines.append(f"\n subgraph sub_flow_{get_id(node)}[{type(node).__name__}]") + + if node.start: + walk(node.start) + + for succ_node in node.successors.values(): + if node.start: + walk(succ_node, get_id(node.start)) + else: + if parent_id: + connect(parent_id, get_id(succ_node)) + walk(succ_node) + + lines.append(" end\n") + + else: + # Handle simple Nodes + curr_id = get_id(node) + if isinstance(node, BatchNode): + label = f"{curr_id}@{{shape: procs, label: \"{type(node).__name__}\"}}" + else: + label = f"{curr_id}@{{label: \"{type(node).__name__}\"}}" + + if parent_id: + lines.append(f" {label}") + connect(parent_id, curr_id) + else: + lines.append(f" {label}") + + for succ_node in node.successors.values(): + walk(succ_node, curr_id) + + walk(start) + return "\n".join(lines) +``` + +### Usage Example + +Here, we define some example Nodes and Flows to generate a Mermaid diagram: + +```python +class DataPrepBatchNode(BatchNode): + pass + +class ValidateDataNode(Node): + pass + +class FeatureExtractionNode(Node): + pass + +class TrainModelNode(Node): + pass + +class EvaluateModelNode(Node): + pass + +class ModelFlow(Flow): + pass + +feature_node = FeatureExtractionNode() +train_node = TrainModelNode() +evaluate_node = EvaluateModelNode() +feature_node >> train_node >> evaluate_node +model_flow = ModelFlow(start=feature_node) + +data_prep_node = DataPrepBatchNode() +validate_node = ValidateDataNode() +data_prep_node >> validate_node >> model_flow +build_mermaid(start=data_prep_node) +``` + +The above code produces a Mermaid diagram (e.g., use the [Mermaid Live Editor](https://mermaid.live/) to render it): + +```mermaid +graph LR + N1@{shape: procs, label: "DataPrepBatchNode"} + N2@{label: "ValidateDataNode"} + N1 --> N2 + N2 --> N3 + + subgraph sub_flow_N4[ModelFlow] + N3@{label: "FeatureExtractionNode"} + N5@{label: "TrainModelNode"} + N3 --> N5 + N6@{label: "EvaluateModelNode"} + N5 --> N6 + end +``` diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index 00506ae..32ed3dc 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -24,7 +24,7 @@ class _ConditionalTransition: def __rshift__(self,tgt): return self.src.add_successor(tgt,self.action) class Node(BaseNode): - def __init__(self,max_retries=1,wait=0): super().__init__();self.max_retries,self.wait =max_retries,wait + def __init__(self,max_retries=1,wait=0): super().__init__();self.max_retries,self.wait=max_retries,wait def exec_fallback(self,prep_res,exc): raise exc def _exec(self,prep_res): for i in range(self.max_retries): @@ -75,10 +75,10 @@ class AsyncNode(Node): return await self._run_async(shared) async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e) -class AsyncBatchNode(AsyncNode): +class AsyncBatchNode(AsyncNode,BatchNode): async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] -class AsyncParallelBatchNode(AsyncNode): +class AsyncParallelBatchNode(AsyncNode,BatchNode): async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) class AsyncFlow(Flow,AsyncNode): @@ -87,13 +87,13 @@ class AsyncFlow(Flow,AsyncNode): while curr:curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared);curr=copy.copy(self.get_next_node(curr,c)) async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None) -class AsyncBatchFlow(AsyncFlow): +class AsyncBatchFlow(AsyncFlow,BatchFlow): async def _run_async(self,shared): pr=await self.prep_async(shared) or [] for bp in pr: await self._orch_async(shared,{**self.params,**bp}) return await self.post_async(shared,pr,None) -class AsyncParallelBatchFlow(AsyncFlow): +class AsyncParallelBatchFlow(AsyncFlow,BatchFlow): async def _run_async(self,shared): pr=await self.prep_async(shared) or [] await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr))