""" Decorator for tracing PocketFlow workflows with Langfuse. """ import functools import inspect import uuid from typing import Any, Callable, Dict, Optional, Union from .config import TracingConfig from .core import LangfuseTracer def trace_flow( config: Optional[TracingConfig] = None, flow_name: Optional[str] = None, session_id: Optional[str] = None, user_id: Optional[str] = None ): """ Decorator to add Langfuse tracing to PocketFlow flows. This decorator automatically traces: - Flow execution start/end - Each node's prep, exec, and post phases - Input and output data for each phase - Errors and exceptions Args: config: TracingConfig instance. If None, loads from environment. flow_name: Custom name for the flow. If None, uses the flow class name. session_id: Session ID for grouping related traces. user_id: User ID for the trace. Returns: Decorated flow class or function. Example: ```python from tracing import trace_flow @trace_flow() class MyFlow(Flow): def __init__(self): super().__init__(start=MyNode()) # Or with custom configuration config = TracingConfig.from_env() @trace_flow(config=config, flow_name="CustomFlow") class MyFlow(Flow): pass ``` """ def decorator(flow_class_or_func): # Handle both class and function decoration if inspect.isclass(flow_class_or_func): return _trace_flow_class(flow_class_or_func, config, flow_name, session_id, user_id) else: return _trace_flow_function(flow_class_or_func, config, flow_name, session_id, user_id) return decorator def _trace_flow_class(flow_class, config, flow_name, session_id, user_id): """Trace a Flow class by wrapping its methods.""" # Get or create config if config is None: config = TracingConfig.from_env() # Override session/user if provided if session_id: config.session_id = session_id if user_id: config.user_id = user_id # Get flow name if flow_name is None: flow_name = flow_class.__name__ # Store original methods original_init = flow_class.__init__ original_run = getattr(flow_class, 'run', None) original_run_async = getattr(flow_class, 'run_async', None) def traced_init(self, *args, **kwargs): """Initialize the flow with tracing capabilities.""" # Call original init original_init(self, *args, **kwargs) # Add tracing attributes self._tracer = LangfuseTracer(config) self._flow_name = flow_name self._trace_id = None # Patch all nodes in the flow self._patch_nodes() def traced_run(self, shared): """Traced version of the run method.""" if not hasattr(self, '_tracer'): # Fallback if not properly initialized return original_run(self, shared) if original_run else None # Start trace self._trace_id = self._tracer.start_trace(self._flow_name, shared) try: # Run the original flow result = original_run(self, shared) if original_run else None # End trace successfully self._tracer.end_trace(shared, "success") return result except Exception as e: # End trace with error self._tracer.end_trace(shared, "error") raise finally: # Ensure cleanup self._tracer.flush() async def traced_run_async(self, shared): """Traced version of the async run method.""" if not hasattr(self, '_tracer'): # Fallback if not properly initialized return await original_run_async(self, shared) if original_run_async else None # Start trace self._trace_id = self._tracer.start_trace(self._flow_name, shared) try: # Run the original flow result = await original_run_async(self, shared) if original_run_async else None # End trace successfully self._tracer.end_trace(shared, "success") return result except Exception as e: # End trace with error self._tracer.end_trace(shared, "error") raise finally: # Ensure cleanup self._tracer.flush() def patch_nodes(self): """Patch all nodes in the flow to add tracing.""" if not hasattr(self, 'start_node') or not self.start_node: return visited = set() nodes_to_patch = [self.start_node] while nodes_to_patch: node = nodes_to_patch.pop(0) if id(node) in visited: continue visited.add(id(node)) # Patch this node self._patch_node(node) # Add successors to patch list if hasattr(node, 'successors'): for successor in node.successors.values(): if successor and id(successor) not in visited: nodes_to_patch.append(successor) def patch_node(self, node): """Patch a single node to add tracing.""" if hasattr(node, '_pocketflow_traced'): return # Already patched node_id = str(uuid.uuid4()) node_name = type(node).__name__ # Store original methods original_prep = getattr(node, 'prep', None) original_exec = getattr(node, 'exec', None) original_post = getattr(node, 'post', None) original_prep_async = getattr(node, 'prep_async', None) original_exec_async = getattr(node, 'exec_async', None) original_post_async = getattr(node, 'post_async', None) # Create traced versions if original_prep: node.prep = self._create_traced_method(original_prep, node_id, node_name, 'prep') if original_exec: node.exec = self._create_traced_method(original_exec, node_id, node_name, 'exec') if original_post: node.post = self._create_traced_method(original_post, node_id, node_name, 'post') if original_prep_async: node.prep_async = self._create_traced_async_method(original_prep_async, node_id, node_name, 'prep') if original_exec_async: node.exec_async = self._create_traced_async_method(original_exec_async, node_id, node_name, 'exec') if original_post_async: node.post_async = self._create_traced_async_method(original_post_async, node_id, node_name, 'post') # Mark as traced node._pocketflow_traced = True def create_traced_method(self, original_method, node_id, node_name, phase): """Create a traced version of a synchronous method.""" @functools.wraps(original_method) def traced_method(*args, **kwargs): span_id = self._tracer.start_node_span(node_name, node_id, phase) try: result = original_method(*args, **kwargs) self._tracer.end_node_span(span_id, input_data=args, output_data=result) return result except Exception as e: self._tracer.end_node_span(span_id, input_data=args, error=e) raise return traced_method def create_traced_async_method(self, original_method, node_id, node_name, phase): """Create a traced version of an asynchronous method.""" @functools.wraps(original_method) async def traced_async_method(*args, **kwargs): span_id = self._tracer.start_node_span(node_name, node_id, phase) try: result = await original_method(*args, **kwargs) self._tracer.end_node_span(span_id, input_data=args, output_data=result) return result except Exception as e: self._tracer.end_node_span(span_id, input_data=args, error=e) raise return traced_async_method # Replace methods on the class flow_class.__init__ = traced_init flow_class._patch_nodes = patch_nodes flow_class._patch_node = patch_node flow_class._create_traced_method = create_traced_method flow_class._create_traced_async_method = create_traced_async_method if original_run: flow_class.run = traced_run if original_run_async: flow_class.run_async = traced_run_async return flow_class def _trace_flow_function(flow_func, config, flow_name, session_id, user_id): """Trace a flow function (for functional-style flows).""" # Get or create config if config is None: config = TracingConfig.from_env() # Override session/user if provided if session_id: config.session_id = session_id if user_id: config.user_id = user_id # Get flow name if flow_name is None: flow_name = flow_func.__name__ tracer = LangfuseTracer(config) @functools.wraps(flow_func) def traced_flow_func(*args, **kwargs): # Assume first argument is shared data shared = args[0] if args else {} # Start trace trace_id = tracer.start_trace(flow_name, shared) try: result = flow_func(*args, **kwargs) tracer.end_trace(shared, "success") return result except Exception as e: tracer.end_trace(shared, "error") raise finally: tracer.flush() return traced_flow_func