From c415b553d0d6f9682e2c600e6d4dd36c56cff6c2 Mon Sep 17 00:00:00 2001 From: rahoogan <16303974+rahoogan@users.noreply.github.com> Date: Fri, 16 May 2025 22:25:45 +1000 Subject: [PATCH] pocketflow-visualization: adds action names to links --- .../pocketflow-visualization/visualize.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/cookbook/pocketflow-visualization/visualize.py b/cookbook/pocketflow-visualization/visualize.py index 937a36b..6caf215 100644 --- a/cookbook/pocketflow-visualization/visualize.py +++ b/cookbook/pocketflow-visualization/visualize.py @@ -28,28 +28,31 @@ def build_mermaid(start): ids[n] if n in ids else (ids.setdefault(n, f"N{ctr}"), (ctr := ctr + 1))[0] ) - def link(a, b): - lines.append(f" {a} --> {b}") + def link(a, b, action=None): + if action: + lines.append(f" {a} -->|{action}| {b}") + else: + lines.append(f" {a} --> {b}") - def walk(node, parent=None): + def walk(node, parent=None, action=None): if node in visited: - return parent and link(parent, get_id(node)) + return parent and link(parent, get_id(node), action) visited.add(node) if isinstance(node, Flow): - node.start_node and parent and link(parent, get_id(node.start_node)) + node.start_node and parent and link(parent, get_id(node.start_node), action) lines.append( f"\n subgraph sub_flow_{get_id(node)}[{type(node).__name__}]" ) node.start_node and walk(node.start_node) - for nxt in node.successors.values(): - node.start_node and walk(nxt, get_id(node.start_node)) or ( - parent and link(parent, get_id(nxt)) - ) or walk(nxt) + for act, nxt in node.successors.items(): + node.start_node and walk(nxt, get_id(node.start_node), act) or ( + parent and link(parent, get_id(nxt), action) + ) or walk(nxt, None, act) lines.append(" end\n") else: lines.append(f" {(nid := get_id(node))}['{type(node).__name__}']") - parent and link(parent, nid) - [walk(nxt, nid) for nxt in node.successors.values()] + parent and link(parent, nid, action) + [walk(nxt, nid, act) for act, nxt in node.successors.items()] walk(start) return "\n".join(lines)