This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch stefan/add-more-lifecycles in repository https://gitbox.apache.org/repos/asf/burr.git
commit 55fcd1e15d5d3ca26d3d5d18486cb2c88eb21e63 Author: Stefan Krawczyk <[email protected]> AuthorDate: Fri Nov 28 22:05:43 2025 -0800 Adds interceptor hooks This proposes a new hook to allow one to remotely push execution of a Burr Action. For example, if one wants to selectively push remote execution of an action to say Ray, then one can now build an interceptor. This also introduces companion hooks that would run pre & post on the remote to mirror what Burr has today - -I thought it would be simpler to not try to reuse those hooks, and instead make it explicit as to what is going on if you wanted to add something similar. See tests for examples, but open to feedback here. Note: we'd need to think through how this plays with parallelism functionality we have. E.g. currently interceptors aren't propagated to sub applications... --- burr/core/application.py | 162 +++++- burr/lifecycle/__init__.py | 24 + burr/lifecycle/base.py | 354 +++++++++++++ burr/lifecycle/internal.py | 130 ++++- tests/integration_tests/test_action_interceptor.py | 581 +++++++++++++++++++++ .../test_async_streaming_interceptor.py | 342 ++++++++++++ tests/lifecycle/__init__.py | 16 + tests/lifecycle/test_internal.py | 401 ++++++++++++++ 8 files changed, 1998 insertions(+), 12 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 55f98acf..ea39014e 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -158,7 +158,13 @@ def _remap_dunder_parameters( return inputs -def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: +def _run_function( + function: Function, + state: State, + inputs: Dict[str, Any], + name: str, + adapter_set: Optional["LifecycleAdapterSet"] = None, +) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the function reads. 
@@ -166,6 +172,8 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name :param function: Function to run :param state: State at time of execution :param inputs: Inputs to the function + :param name: Name of the action (for error messages) + :param adapter_set: Optional lifecycle adapter set for checking interceptors :return: """ if function.is_async(): @@ -174,6 +182,21 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name "in non-async context. Use astep()/aiterate()/arun() " "instead...)" ) + + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", lambda hook: hook.should_intercept(action=function) + ) + if interceptor: + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = interceptor.intercept_run( + action=function, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + _validate_result(result, name) + return result + + # Normal execution path state_to_use = state.subset(*function.reads) function.validate_inputs(inputs) if "__context" in inputs or "__tracer" in inputs: @@ -185,10 +208,30 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name async def _arun_function( - function: Function, state: State, inputs: Dict[str, Any], name: str + function: Function, + state: State, + inputs: Dict[str, Any], + name: str, + adapter_set: Optional["LifecycleAdapterSet"] = None, ) -> dict: """Runs a function, returning the result of running the function. 
Async version of the above.""" + + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", + lambda hook: hook.should_intercept(action=function) and hasattr(hook, "intercept_run"), + ) + if interceptor and inspect.iscoroutinefunction(interceptor.intercept_run): + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = await interceptor.intercept_run( + action=function, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + _validate_result(result, name) + return result + + # Normal execution path state_to_use = state.subset(*function.reads) function.validate_inputs(inputs) result = await function.run(state_to_use, **inputs) @@ -299,7 +342,10 @@ def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) def _run_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] + action: SingleStepAction, + state: State, + inputs: Optional[Dict[str, Any]], + adapter_set: Optional["LifecycleAdapterSet"] = None, ) -> Tuple[Dict[str, Any], State]: """Runs a single step action. This API is internal-facing and a bit in flux, but it corresponds to the SingleStepAction class. 
@@ -307,9 +353,33 @@ def _run_single_step_action( :param action: Action to run :param state: State to run with :param inputs: Inputs to pass directly to the action + :param adapter_set: Optional lifecycle adapter set for checking interceptors :return: The result of running the action, and the new state """ - # TODO -- guard all reads/writes with a subset of the state + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", lambda hook: hook.should_intercept(action=action) + ) + if interceptor: + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = interceptor.intercept_run( + action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + # Check if interceptor returned state via special key (for single-step actions) + if "__INTERCEPTOR_NEW_STATE__" in result: + new_state = result.pop("__INTERCEPTOR_NEW_STATE__") + else: + # For multi-step actions or if state wasn't provided + # we need to compute it + new_state = action.update(result, state) + + _validate_result(result, action.name, action.schema) + out = result, _state_update(state, new_state) + _validate_reducer_writes(action, new_state, action.name) + return out + + # Normal execution path action.validate_inputs(inputs) result, new_state = _adjust_single_step_output( action.run_and_update(state, **inputs), action.name, action.schema @@ -334,7 +404,18 @@ def _run_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + + # Check for streaming action interceptors + interceptor = lifecycle_adapters.get_first_matching_hook( + "intercept_streaming_action", lambda hook: hook.should_intercept(action=action) + ) + if interceptor: + worker_adapter_set = lifecycle_adapters.get_worker_adapter_set() + generator = interceptor.intercept_stream_run_and_update( + 
action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + else: + generator = action.stream_run_and_update(state, **inputs) result = None state_update = None count = 0 @@ -387,7 +468,20 @@ async def _arun_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + + # Check for streaming action interceptors + interceptor = lifecycle_adapters.get_first_matching_hook( + "intercept_streaming_action", + lambda hook: hook.should_intercept(action=action) + and hasattr(hook, "intercept_stream_run_and_update"), + ) + if interceptor and inspect.isasyncgenfunction(interceptor.intercept_stream_run_and_update): + worker_adapter_set = lifecycle_adapters.get_worker_adapter_set() + generator = interceptor.intercept_stream_run_and_update( + action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + else: + generator = action.stream_run_and_update(state, **inputs) result = None state_update = None count = 0 @@ -523,9 +617,35 @@ async def _arun_multi_step_streaming_action( async def _arun_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] + action: SingleStepAction, + state: State, + inputs: Optional[Dict[str, Any]], + adapter_set: Optional["LifecycleAdapterSet"] = None, ) -> Tuple[dict, State]: """Runs a single step action in async. 
See the synchronous version for more details.""" + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", + lambda hook: hook.should_intercept(action=action) and hasattr(hook, "intercept_run"), + ) + if interceptor and inspect.iscoroutinefunction(interceptor.intercept_run): + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = await interceptor.intercept_run( + action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + # Check if interceptor returned state via special key (for single-step actions) + if "__INTERCEPTOR_NEW_STATE__" in result: + new_state = result.pop("__INTERCEPTOR_NEW_STATE__") + else: + # For multi-step actions or if state wasn't provided + new_state = action.update(result, state) + + _validate_result(result, action.name, action.schema) + _validate_reducer_writes(action, new_state, action.name) + return result, _state_update(state, new_state) + + # Normal execution path state_to_use = state action.validate_inputs(inputs) result, new_state = _adjust_single_step_output( @@ -915,11 +1035,15 @@ class Application(Generic[ApplicationStateType]): try: if next_action.single_step: result, new_state = _run_single_step_action( - next_action, self._state, action_inputs + next_action, self._state, action_inputs, adapter_set=self._adapter_set ) else: result = _run_function( - next_action, self._state, action_inputs, name=next_action.name + next_action, + self._state, + action_inputs, + name=next_action.name, + adapter_set=self._adapter_set, ) new_state = _run_reducer(next_action, self._state, result, next_action.name) @@ -1051,7 +1175,19 @@ class Application(Generic[ApplicationStateType]): result = None new_state = self._state try: - if not next_action.is_async(): + # Check if there's an async interceptor for this action + has_async_interceptor = False + if self._adapter_set: + interceptor = self._adapter_set.get_first_matching_hook( + 
"intercept_action_execution", + lambda hook: hook.should_intercept(action=next_action) + and hasattr(hook, "intercept_run"), + ) + if interceptor and inspect.iscoroutinefunction(interceptor.intercept_run): + has_async_interceptor = True + + # Only delegate to sync version if action is sync AND no async interceptor + if not next_action.is_async() and not has_async_interceptor: # we can just delegate to the synchronous version, it will block the event loop, # but that's safer than assuming its OK to launch a thread # TODO -- add an option/configuration to launch a thread (yikes, not super safe, but for a pure function @@ -1065,7 +1201,10 @@ class Application(Generic[ApplicationStateType]): action_inputs = self._process_inputs(inputs, next_action) if next_action.single_step: result, new_state = await _arun_single_step_action( - next_action, self._state, inputs=action_inputs + next_action, + self._state, + inputs=action_inputs, + adapter_set=self._adapter_set, ) else: result = await _arun_function( @@ -1073,6 +1212,7 @@ class Application(Generic[ApplicationStateType]): self._state, inputs=action_inputs, name=next_action.name, + adapter_set=self._adapter_set, ) new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py index 4ae24073..991ee984 100644 --- a/burr/lifecycle/__init__.py +++ b/burr/lifecycle/__init__.py @@ -16,18 +16,30 @@ # under the License. 
from burr.lifecycle.base import ( + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, LifecycleAdapter, PostApplicationCreateHook, PostApplicationExecuteCallHook, PostApplicationExecuteCallHookAsync, PostEndSpanHook, + PostEndStreamHookWorker, + PostEndStreamHookWorkerAsync, PostRunStepHook, PostRunStepHookAsync, + PostRunStepHookWorker, + PostRunStepHookWorkerAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PreRunStepHook, PreRunStepHookAsync, + PreRunStepHookWorker, + PreRunStepHookWorkerAsync, PreStartSpanHook, + PreStartStreamHookWorker, + PreStartStreamHookWorkerAsync, + StreamingActionInterceptorHook, + StreamingActionInterceptorHookAsync, ) from burr.lifecycle.default import StateAndResultsFullLogger @@ -45,4 +57,16 @@ __all__ = [ "PostApplicationCreateHook", "PostEndSpanHook", "PreStartSpanHook", + "PreRunStepHookWorker", + "PreRunStepHookWorkerAsync", + "PostRunStepHookWorker", + "PostRunStepHookWorkerAsync", + "PreStartStreamHookWorker", + "PreStartStreamHookWorkerAsync", + "PostEndStreamHookWorker", + "PostEndStreamHookWorkerAsync", + "ActionExecutionInterceptorHook", + "ActionExecutionInterceptorHookAsync", + "StreamingActionInterceptorHook", + "StreamingActionInterceptorHookAsync", ] diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index 66d8bd7e..7474172e 100644 --- a/burr/lifecycle/base.py +++ b/burr/lifecycle/base.py @@ -492,6 +492,348 @@ class PostEndStreamHookAsync(abc.ABC): pass +@lifecycle.base_hook("pre_run_step_worker") +class PreRunStepHookWorker(abc.ABC): + """Hook that runs on the worker (e.g., Ray/Temporal) before action execution. 
+ This hook is designed to be called by execution interceptors on remote workers, + as opposed to PreRunStepHook which always runs on the main orchestrator process.""" + + @abc.abstractmethod + def pre_run_step_worker( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Run before a step is executed on the worker. + + :param action: Action to be executed + :param state: State prior to step execution + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments + """ + pass + + +@lifecycle.base_hook("pre_run_step_worker") +class PreRunStepHookWorkerAsync(abc.ABC): + """Async hook that runs on the worker before action execution.""" + + @abc.abstractmethod + async def pre_run_step_worker( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Async run before a step is executed on the worker. + + :param action: Action to be executed + :param state: State prior to step execution + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments + """ + pass + + +@lifecycle.base_hook("post_run_step_worker") +class PostRunStepHookWorker(abc.ABC): + """Hook that runs on the worker after action execution. + This hook is designed to be called by execution interceptors on remote workers, + as opposed to PostRunStepHook which always runs on the main orchestrator process.""" + + @abc.abstractmethod + def post_run_step_worker( + self, + *, + action: "Action", + state: "State", + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + """Run after a step is executed on the worker. 
+ + :param action: Action that was executed + :param state: State after step execution + :param result: Result of the action + :param exception: Exception that was raised + :param future_kwargs: Future keyword arguments + """ + pass + + +@lifecycle.base_hook("post_run_step_worker") +class PostRunStepHookWorkerAsync(abc.ABC): + """Async hook that runs on the worker after action execution.""" + + @abc.abstractmethod + async def post_run_step_worker( + self, + *, + action: "Action", + state: "State", + result: Optional[dict], + exception: Exception, + **future_kwargs: Any, + ): + """Async run after a step is executed on the worker. + + :param action: Action that was executed + :param state: State after step execution + :param result: Result of the action + :param exception: Exception that was raised + :param future_kwargs: Future keyword arguments + """ + pass + + +@lifecycle.base_hook("pre_start_stream_worker") +class PreStartStreamHookWorker(abc.ABC): + """Hook that runs on the worker after a stream is started.""" + + @abc.abstractmethod + def pre_start_stream_worker( + self, + *, + action: str, + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("pre_start_stream_worker") +class PreStartStreamHookWorkerAsync(abc.ABC): + """Async hook that runs on the worker after a stream is started.""" + + @abc.abstractmethod + async def pre_start_stream_worker( + self, + *, + action: str, + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("post_end_stream_worker") +class PostEndStreamHookWorker(abc.ABC): + """Hook that runs on the worker after a stream is ended.""" + + @abc.abstractmethod + def post_end_stream_worker( + self, + *, + action: str, + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("post_end_stream_worker") +class PostEndStreamHookWorkerAsync(abc.ABC): + """Async hook that runs on the 
worker after a stream is ended.""" + + @abc.abstractmethod + async def post_end_stream_worker( + self, + *, + action: str, + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + pass + + +@lifecycle.interceptor_hook("intercept_action_execution") +class ActionExecutionInterceptorHook(abc.ABC): + """Hook that can wrap/replace action execution (e.g., for Ray/Temporal). + This hook allows you to intercept the execution of an action and run it + on a different execution backend while maintaining the same interface. + + The interceptor receives a worker_adapter_set containing only worker hooks + (PreRunStepHookWorker, PostRunStepHookWorker, etc.) that can be called + on the remote execution environment.""" + + @abc.abstractmethod + def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this action should be intercepted. + + :param action: Action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + def intercept_run( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ) -> dict: + """Replace the action.run() call with custom execution. + + Note: The state passed here is the FULL state, not subsetted. + You are responsible for subsetting it to action.reads if needed. 
+ + :param action: Action to execute + :param state: Current state (FULL state, not subsetted) + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Result dictionary from running the action + """ + pass + + +@lifecycle.interceptor_hook("intercept_action_execution") +class ActionExecutionInterceptorHookAsync(abc.ABC): + """Async version of ActionExecutionInterceptorHook for intercepting async actions.""" + + @abc.abstractmethod + async def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this action should be intercepted. + + :param action: Action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + async def intercept_run( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ) -> dict: + """Replace the action.run() call with custom execution. + + Note: The state passed here is the FULL state, not subsetted. + You are responsible for subsetting it to action.reads if needed. + + :param action: Action to execute + :param state: Current state (FULL state, not subsetted) + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Result dictionary from running the action + """ + pass + + +@lifecycle.interceptor_hook( + "intercept_streaming_action", intercept_method="intercept_stream_run_and_update" +) +class StreamingActionInterceptorHook(abc.ABC): + """Hook to intercept streaming action execution (e.g., for Ray/Temporal). + This hook allows you to wrap streaming actions to execute on different backends. 
+ + The interceptor receives a worker_adapter_set containing only worker hooks + that can be called on the remote execution environment.""" + + @abc.abstractmethod + def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this streaming action should be intercepted. + + :param action: Streaming action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + def intercept_stream_run_and_update( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Replace stream_run_and_update with custom execution. + Must be a generator that yields (result_dict, optional_state) tuples. + + :param action: Streaming action to execute + :param state: Current state + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Generator yielding (dict, Optional[State]) tuples + """ + pass + + +@lifecycle.interceptor_hook( + "intercept_streaming_action", intercept_method="intercept_stream_run_and_update" +) +class StreamingActionInterceptorHookAsync(abc.ABC): + """Async version for intercepting async streaming actions.""" + + @abc.abstractmethod + async def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this streaming action should be intercepted. + + :param action: Streaming action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + def intercept_stream_run_and_update( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Replace stream_run_and_update with custom execution. + Must be an async generator that yields (result_dict, optional_state) tuples. 
+ + :param action: Streaming action to execute + :param state: Current state + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Async generator yielding (dict, Optional[State]) tuples + """ + pass + + # strictly for typing -- this conflicts a bit with the lifecycle decorator above, but its fine for now # This makes IDE completion/type-hinting easier LifecycleAdapter = Union[ @@ -515,4 +857,16 @@ LifecycleAdapter = Union[ PreStartStreamHookAsync, PostStreamItemHookAsync, PostEndStreamHookAsync, + PreRunStepHookWorker, + PreRunStepHookWorkerAsync, + PostRunStepHookWorker, + PostRunStepHookWorkerAsync, + PreStartStreamHookWorker, + PreStartStreamHookWorkerAsync, + PostEndStreamHookWorker, + PostEndStreamHookWorkerAsync, + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, + StreamingActionInterceptorHook, + StreamingActionInterceptorHookAsync, ] diff --git a/burr/lifecycle/internal.py b/burr/lifecycle/internal.py index 1043bd0a..168fedc8 100644 --- a/burr/lifecycle/internal.py +++ b/burr/lifecycle/internal.py @@ -28,9 +28,11 @@ if TYPE_CHECKING: SYNC_HOOK = "hooks" ASYNC_HOOK = "async_hooks" +INTERCEPTOR_TYPE = "interceptor_type" REGISTERED_SYNC_HOOKS: Set[str] = set() REGISTERED_ASYNC_HOOKS: Set[str] = set() +REGISTERED_INTERCEPTORS: Set[str] = set() class InvalidLifecycleHook(Exception): @@ -64,6 +66,36 @@ def validate_hook_fn(fn: Callable): ) +def validate_interceptor_method(fn: Callable, method_name: str): + """Validates that an interceptor method has the correct signature. + Interceptor methods must have keyword-only arguments (including **future_kwargs). 
+ + :param fn: The function to validate + :param method_name: Name of the method being validated + :raises InvalidLifecycleHook: If the function is not a valid interceptor method + """ + if fn is None: + raise InvalidLifecycleHook(f"Interceptor method {method_name} does not exist on the class.") + sig = inspect.signature(fn) + # Check for **future_kwargs + if ( + "future_kwargs" not in sig.parameters + or sig.parameters["future_kwargs"].kind != inspect.Parameter.VAR_KEYWORD + ): + raise InvalidLifecycleHook( + f"Interceptor method {method_name} must have a `**future_kwargs` argument. " + f"Method {fn} does not." + ) + # All non-self, non-future_kwargs parameters must be keyword-only + for param in sig.parameters.values(): + if param.name not in ("future_kwargs", "self"): + if param.kind != inspect.Parameter.KEYWORD_ONLY: + raise InvalidLifecycleHook( + f"Interceptor method {method_name} can only have keyword-only arguments. " + f"Method {fn} has argument {param} that is not keyword-only." + ) + + class lifecycle: """Container class for decorators to register hooks. This is just a container so it looks clean (`@lifecycle.base_hook(...)`), but we could easily move it out. @@ -105,6 +137,41 @@ class lifecycle: return decorator + @classmethod + def interceptor_hook( + cls, + interceptor_type: str, + should_intercept_method: str = "should_intercept", + intercept_method: str = "intercept_run", + ): + """Decorator for interceptor hooks that can wrap/replace action execution. + + Interceptors have two methods: + 1. should_intercept() - determines if an action should be intercepted + 2. 
intercept_run() or intercept_stream_run_and_update() - replaces the execution + + :param interceptor_type: Type identifier for the interceptor (e.g., "intercept_action_execution", "intercept_streaming_action") + :param should_intercept_method: Name of the should_intercept method (default: "should_intercept") + :param intercept_method: Name of the intercept method (default: "intercept_run" or "intercept_stream_run_and_update") + """ + + def decorator(clazz): + # Validate should_intercept method + should_intercept_fn = getattr(clazz, should_intercept_method, None) + validate_interceptor_method(should_intercept_fn, should_intercept_method) + + # Validate intercept method + intercept_fn = getattr(clazz, intercept_method, None) + validate_interceptor_method(intercept_fn, intercept_method) + + # Register the interceptor type + setattr(clazz, INTERCEPTOR_TYPE, interceptor_type) + REGISTERED_INTERCEPTORS.add(interceptor_type) + + return clazz + + return decorator + class LifecycleAdapterSet: """An internal class that groups together all the lifecycle adapters. @@ -119,7 +186,15 @@ class LifecycleAdapterSet: :param adapters: Adapters to group together """ self._adapters = list(adapters) - self.sync_hooks, self.async_hooks = self._get_lifecycle_hooks() + self._sync_hooks, self._async_hooks = self._get_lifecycle_hooks() + + @property + def sync_hooks(self): + return self._sync_hooks + + @property + def async_hooks(self): + return self._async_hooks def with_new_adapters(self, *adapters: "LifecycleAdapter") -> "LifecycleAdapterSet": """Adds new adapters to the set. @@ -212,3 +287,56 @@ class LifecycleAdapterSet: :return: A list of adapters """ return self._adapters + + def get_first_matching_hook( + self, hook_name: str, predicate: Callable[["LifecycleAdapter"], bool] + ): + """Get first hook of given type that matches predicate. + + For interceptor hooks, this uses the registered interceptor types to find + matching interceptors. For standard hooks, it uses the hook registry. 
+ + :param hook_name: Name of the hook to search for (or interceptor type) + :param predicate: Function that takes a hook and returns True if it matches + :return: The first matching hook, or None if no match found + """ + # Check if this is a registered interceptor type + if hook_name in REGISTERED_INTERCEPTORS: + # Search for adapters with this interceptor type + for adapter in self.adapters: + for cls in inspect.getmro(adapter.__class__): + interceptor_type = getattr(cls, INTERCEPTOR_TYPE, None) + if interceptor_type == hook_name: + if predicate(adapter): + return adapter + return None + + # Standard hook lookup for registered hooks + hooks = self.sync_hooks.get(hook_name, []) + self.async_hooks.get(hook_name, []) + for hook in hooks: + if predicate(hook): + return hook + return None + + def get_worker_adapter_set(self) -> "LifecycleAdapterSet": + """Create a new LifecycleAdapterSet containing only worker hooks. + Worker hooks are those with names ending in '_worker' and are designed + to be called on remote execution environments (Ray/Temporal workers). 
+ + :return: A new LifecycleAdapterSet with only worker hooks + """ + worker_hooks = [] + for adapter in self.adapters: + # Check if this adapter is a worker hook by looking at its registered hooks + is_worker = False + for cls in inspect.getmro(adapter.__class__): + sync_hook = getattr(cls, SYNC_HOOK, None) + async_hook = getattr(cls, ASYNC_HOOK, None) + if (sync_hook and sync_hook.endswith("_worker")) or ( + async_hook and async_hook.endswith("_worker") + ): + is_worker = True + break + if is_worker: + worker_hooks.append(adapter) + return LifecycleAdapterSet(*worker_hooks) diff --git a/tests/integration_tests/test_action_interceptor.py b/tests/integration_tests/test_action_interceptor.py new file mode 100644 index 00000000..d9149efd --- /dev/null +++ b/tests/integration_tests/test_action_interceptor.py @@ -0,0 +1,581 @@ +# Tests for action execution interceptor hooks +from typing import Any, Dict, Generator, Optional, Tuple + +import pytest + +from burr.core import Action, ApplicationBuilder, State, action +from burr.core.action import streaming_action +from burr.lifecycle import ( + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, + PostRunStepHookWorker, + PreRunStepHookWorker, + StreamingActionInterceptorHook, +) + + +# Test actions +@action(reads=["x"], writes=["y"]) +def add_one(state: State) -> Tuple[dict, State]: + result = {"y": state["x"] + 1} + return result, state.update(**result) + + +@action(reads=["x"], writes=["z"], tags=["intercepted"]) +def multiply_by_two(state: State) -> Tuple[dict, State]: + result = {"z": state["x"] * 2} + return result, state.update(**result) + + +@streaming_action(reads=["prompt"], writes=["response"], tags=["streaming_intercepted"]) +def streaming_responder(state: State) -> Generator[Tuple[dict, Optional[State]], None, None]: + """Simple streaming action for testing""" + tokens = ["Hello", " ", "World", "!"] + buffer = [] + for token in tokens: + buffer.append(token) + yield {"response": token}, None 
+ full_response = "".join(buffer) + yield {"response": full_response}, state.update(response=full_response) + + +@action(reads=["x"], writes=["w"], tags=["intercepted"]) +async def async_multiply(state: State) -> Tuple[dict, State]: + """Async action for testing""" + import asyncio + + await asyncio.sleep(0.01) # Simulate async work + result = {"w": state["x"] * 3} + return result, state.update(**result) + + +# Mock interceptor that captures execution +class MockActionInterceptor(ActionExecutionInterceptorHook): + """Test interceptor that tracks which actions were intercepted""" + + def __init__(self): + self.intercepted_actions = [] + self.worker_hooks_called = [] + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + # Intercept actions with the "intercepted" tag + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.intercepted_actions.append(action.name) + + # Extract worker_adapter_set if provided + worker_adapter_set = kwargs.get("worker_adapter_set") + + # Call worker pre-hooks if they exist + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "pre_run_step_worker", + action=action, + state=state, + inputs=inputs, + ) + + # Simulate "remote" execution - check if it's a single-step action + # For single-step actions, we need to call run_and_update and handle both result and state + if hasattr(action, "single_step") and action.single_step: + # Store the new state in a special key that _run_single_step_action will extract + result, new_state = action.run_and_update(state, **inputs) + # Store state in result for extraction + result_with_state = result.copy() + result_with_state["__INTERCEPTOR_NEW_STATE__"] = new_state + result = result_with_state + else: + # For multi-step actions, call run + state_to_use = state.subset(*action.reads) + action.validate_inputs(inputs) + result = action.run(state_to_use, **inputs) + + # Call 
worker post-hooks if they exist + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "post_run_step_worker", + action=action, + state=state, + result=result, + exception=None, + ) + + return result + + +class MockStreamingInterceptor(StreamingActionInterceptorHook): + """Test interceptor for streaming actions""" + + def __init__(self): + self.intercepted_actions = [] + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "streaming_intercepted" in action.tags + + def intercept_stream_run_and_update( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ): + self.intercepted_actions.append(action.name) + + # Extract worker_adapter_set if provided + worker_adapter_set = kwargs.get("worker_adapter_set") + + # Call worker pre-stream-hooks if they exist + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "pre_start_stream_worker", + action=action.name, + state=state, + inputs=inputs, + ) + + # Run the streaming action normally (simulating remote execution) + generator = action.stream_run_and_update(state, **inputs) + result = None + for item in generator: + result = item + yield item + + # Call worker post-stream-hooks if they exist + if worker_adapter_set and result: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "post_end_stream_worker", + action=action.name, + result=result[0] if result else None, + exception=None, + ) + + +class WorkerPreHook(PreRunStepHookWorker): + """Test worker hook that runs before action execution""" + + def __init__(self): + self.called_actions = [] + + def pre_run_step_worker( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ): + self.called_actions.append(("pre", action.name)) + + +class WorkerPostHook(PostRunStepHookWorker): + """Test worker hook that runs after action execution""" + + def __init__(self): + self.called_actions = [] + + def post_run_step_worker( + self, + *, + action: Action, + state: State, + 
result: Optional[Dict[str, Any]], + exception: Exception, + **kwargs, + ): + self.called_actions.append(("post", action.name)) + + +def test_interceptor_intercepts_tagged_action(): + """Test that interceptor only intercepts actions with specific tags""" + interceptor = MockActionInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=5) + .with_actions(add_one, multiply_by_two) + .with_transitions( + ("add_one", "multiply_by_two"), + ("multiply_by_two", "add_one"), + ) + .with_entrypoint("add_one") + .with_hooks(interceptor) + .build() + ) + + # Run add_one (not intercepted) + action, result, state = app.step() + assert action.name == "add_one" + assert state["y"] == 6 + assert "add_one" not in interceptor.intercepted_actions + + # Run multiply_by_two (intercepted) + action, result, state = app.step() + assert action.name == "multiply_by_two" + assert state["z"] == 10 # 5 * 2, using original x value + assert "multiply_by_two" in interceptor.intercepted_actions + + +def test_interceptor_calls_worker_hooks(): + """Test that interceptor properly calls worker hooks""" + interceptor = MockActionInterceptor() + worker_pre = WorkerPreHook() + worker_post = WorkerPostHook() + + app = ( + ApplicationBuilder() + .with_state(x=10) + .with_actions(multiply_by_two) + .with_entrypoint("multiply_by_two") + .with_hooks(interceptor, worker_pre, worker_post) + .build() + ) + + action, result, state = app.step() + assert action.name == "multiply_by_two" + assert state["z"] == 20 + + # Verify interceptor ran + assert "multiply_by_two" in interceptor.intercepted_actions + + # Verify worker hooks were called + assert ("pre", "multiply_by_two") in worker_pre.called_actions + assert ("post", "multiply_by_two") in worker_post.called_actions + + +def test_no_interceptor_normal_execution(): + """Test that actions run normally without interceptors""" + app = ( + ApplicationBuilder() + .with_state(x=3) + .with_actions(add_one, multiply_by_two) + .with_transitions( + ("add_one", 
"multiply_by_two"), + ) + .with_entrypoint("add_one") + .build() + ) + + # Both should run normally + action, result, state = app.step() + assert action.name == "add_one" + assert state["y"] == 4 + + action, result, state = app.step() + assert action.name == "multiply_by_two" + assert state["z"] == 6 # 3 * 2 + + +def test_streaming_action_interceptor(): + """Test interceptor for streaming actions""" + streaming_interceptor = MockStreamingInterceptor() + + app = ( + ApplicationBuilder() + .with_state(prompt="test") + .with_actions(streaming_responder) + .with_entrypoint("streaming_responder") + .with_hooks(streaming_interceptor) + .build() + ) + + # Run streaming action + action, streaming_container = app.stream_result( + halt_after=["streaming_responder"], + ) + + # Consume the stream + tokens = [] + for item in streaming_container: + tokens.append(item["response"]) + + result, final_state = streaming_container.get() + + # Verify interceptor ran + assert "streaming_responder" in streaming_interceptor.intercepted_actions + + # Verify streaming worked correctly + assert tokens == ["Hello", " ", "World", "!"] + assert final_state["response"] == "Hello World!" 
+ + +def test_multiple_interceptors_first_wins(): + """Test that when multiple interceptors match, the first one wins""" + + class FirstInterceptor(ActionExecutionInterceptorHook): + def __init__(self): + self.called = False + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.called = True + # Return a custom result with state for single-step actions + result = {"z": 999} + if hasattr(action, "single_step") and action.single_step: + result["__INTERCEPTOR_NEW_STATE__"] = state.update(z=999) + return result + + class SecondInterceptor(ActionExecutionInterceptorHook): + def __init__(self): + self.called = False + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.called = True + result = {"z": 777} + if hasattr(action, "single_step") and action.single_step: + result["__INTERCEPTOR_NEW_STATE__"] = state.update(z=777) + return result + + first = FirstInterceptor() + second = SecondInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=5) + .with_actions(multiply_by_two) + .with_entrypoint("multiply_by_two") + .with_hooks(first, second) # first is registered first + .build() + ) + + action, result, state = app.step() + + # First interceptor should have been called + assert first.called + assert state["z"] == 999 + + # Second interceptor should NOT have been called + assert not second.called + + [email protected] +async def test_async_interceptor_with_sync_action(): + """Test that async interceptors work with sync actions""" + import asyncio + + class AsyncMockInterceptor(ActionExecutionInterceptorHookAsync): + """Async interceptor that simulates async execution (e.g., Ray with asyncio)""" + + def __init__(self): + 
self.intercepted_actions = [] + self.async_calls_made = 0 + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + async def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.intercepted_actions.append(action.name) + + # Simulate async operation (e.g., waiting for Ray actor) + await asyncio.sleep(0.01) + self.async_calls_made += 1 + + # Execute action (sync action, but in async context) + if hasattr(action, "single_step") and action.single_step: + result, new_state = action.run_and_update(state, **inputs) + result_with_state = result.copy() + result_with_state["__INTERCEPTOR_NEW_STATE__"] = new_state + result = result_with_state + else: + state_to_use = state.subset(*action.reads) + result = action.run(state_to_use, **inputs) + + return result + + interceptor = AsyncMockInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=5) + .with_actions(add_one, multiply_by_two) + .with_transitions( + ("add_one", "multiply_by_two"), + ("multiply_by_two", "add_one"), + ) + .with_entrypoint("add_one") + .with_hooks(interceptor) + .build() + ) + + # Run add_one (not intercepted) - should work with astep + action, result, state = await app.astep() + assert action.name == "add_one" + assert state["y"] == 6 + assert "add_one" not in interceptor.intercepted_actions + assert interceptor.async_calls_made == 0 + + # Run multiply_by_two (intercepted) - async interceptor should be called + action, result, state = await app.astep() + assert action.name == "multiply_by_two" + assert state["z"] == 10 # 5 * 2 + assert "multiply_by_two" in interceptor.intercepted_actions + assert interceptor.async_calls_made == 1 + + +def test_interceptor_with_field_level_serde(): + """Test that interceptors properly handle non-serializable objects via field-level serde""" + + # Create a mock non-serializable object (simulating DB client) + class DummyDBClient: + def __init__(self, 
connection_string: str): + self.connection_string = connection_string + + def query(self, sql: str): + return f"Result from {self.connection_string}: {sql}" + + # Register field-level serde for db_client + from burr.core.state import register_field_serde + + def serialize_db_client(value: Any, **kwargs) -> dict: + """Serialize DB client to connection string""" + return { + "connection_string": value.connection_string, + "type": "db_client", + } + + def deserialize_db_client(value: dict, **kwargs) -> Any: + """Recreate DB client from connection string""" + return DummyDBClient(value["connection_string"]) + + register_field_serde("db_client", serialize_db_client, deserialize_db_client) + + # Create interceptor that uses serialize/deserialize + class SerdeAwareInterceptor(ActionExecutionInterceptorHook): + def __init__(self): + self.intercepted_actions = [] + self.serialized_states = [] + self.deserialized_states = [] + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.intercepted_actions.append(action.name) + + # Serialize state (this will use field-level serde for db_client) + state_subset = state.subset(*action.reads) if action.reads else state + state_dict = state_subset.serialize() + self.serialized_states.append(state_dict) + + # Deserialize on "worker" side + worker_state = State.deserialize(state_dict) + self.deserialized_states.append(worker_state) + + # Execute action + if hasattr(action, "single_step") and action.single_step: + result, new_state = action.run_and_update(worker_state, **inputs) + # Serialize new_state before returning + new_state_dict = new_state.serialize() + # Deserialize when reconstructing + reconstructed_state = State.deserialize(new_state_dict) + result_with_state = result.copy() + result_with_state["__INTERCEPTOR_NEW_STATE__"] = reconstructed_state + return 
result_with_state + else: + state_to_use = worker_state.subset(*action.reads) + result = action.run(state_to_use, **inputs) + return result + + # Create action that uses db_client + @action(reads=["x", "db_client"], writes=["y"], tags=["intercepted"]) + def query_db(state: State) -> Tuple[dict, State]: + """Action that uses db_client from state""" + db_client = state["db_client"] + query_result = db_client.query(f"SELECT * FROM table WHERE x={state['x']}") + result = {"y": query_result} + return result, state.update(**result) + + interceptor = SerdeAwareInterceptor() + db_client = DummyDBClient("postgresql://localhost/db") + + app = ( + ApplicationBuilder() + .with_state(x=5, db_client=db_client) + .with_actions(query_db) + .with_entrypoint("query_db") + .with_hooks(interceptor) + .build() + ) + + # Run action + executed_action, result, state = app.step() + + # Verify interceptor ran + assert "query_db" in interceptor.intercepted_actions + + # Verify state was serialized (db_client should be converted to dict) + serialized_state = interceptor.serialized_states[0] + assert "db_client" in serialized_state + assert isinstance(serialized_state["db_client"], dict) + assert serialized_state["db_client"]["type"] == "db_client" + assert serialized_state["db_client"]["connection_string"] == "postgresql://localhost/db" + + # Verify state was deserialized (db_client should be recreated) + deserialized_state = interceptor.deserialized_states[0] + assert "db_client" in deserialized_state + assert isinstance(deserialized_state["db_client"], DummyDBClient) + assert deserialized_state["db_client"].connection_string == "postgresql://localhost/db" + + # Verify final state has working db_client + assert "db_client" in state + assert isinstance(state["db_client"], DummyDBClient) + assert "Result from postgresql://localhost/db" in state["y"] + + [email protected] +async def test_async_interceptor_with_async_action(): + """Test that async interceptors work with async actions""" + import 
asyncio + + class AsyncMockInterceptor(ActionExecutionInterceptorHookAsync): + def __init__(self): + self.intercepted_actions = [] + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + async def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.intercepted_actions.append(action.name) + + # Simulate async execution + await asyncio.sleep(0.01) + + # Execute async action + if hasattr(action, "single_step") and action.single_step: + result, new_state = await action.run_and_update(state, **inputs) + result_with_state = result.copy() + result_with_state["__INTERCEPTOR_NEW_STATE__"] = new_state + return result_with_state + else: + state_to_use = state.subset(*action.reads) + result = await action.run(state_to_use, **inputs) + return result + + interceptor = AsyncMockInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=7) + .with_actions(async_multiply) + .with_entrypoint("async_multiply") + .with_hooks(interceptor) + .build() + ) + + # Run async action with async interceptor + action, result, state = await app.astep() + assert action.name == "async_multiply" + assert state["w"] == 21 # 7 * 3 + assert "async_multiply" in interceptor.intercepted_actions + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/integration_tests/test_async_streaming_interceptor.py b/tests/integration_tests/test_async_streaming_interceptor.py new file mode 100644 index 00000000..6dd796ac --- /dev/null +++ b/tests/integration_tests/test_async_streaming_interceptor.py @@ -0,0 +1,342 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Integration tests for async streaming action interceptors.""" +import asyncio +from typing import Any, AsyncGenerator, Dict, Optional, Tuple + +import pytest + +from burr.core import Action, ApplicationBuilder, State +from burr.core.action import streaming_action +from burr.lifecycle import ( + PostEndStreamHookWorkerAsync, + PreStartStreamHookWorkerAsync, + StreamingActionInterceptorHookAsync, +) + + +@streaming_action(reads=["prompt"], writes=["response"], tags=["async_streaming_intercepted"]) +async def async_streaming_responder( + state: State, prompt: str = "" +) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + """Async streaming action that yields tokens one by one.""" + tokens = ["Hello", " ", "Async", " ", "World", "!"] + buffer = [] + for token in tokens: + # Simulate async work (e.g., API call) + await asyncio.sleep(0.001) + buffer.append(token) + yield {"response": token}, None + full_response = "".join(buffer) + yield {"response": full_response}, state.update(response=full_response) + + +@streaming_action(reads=["count"], writes=["numbers"], tags=["async_streaming_intercepted"]) +async def async_count_streamer( + state: State, count: int = 5 +) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + """Async streaming action that counts from 1 to count.""" + numbers = [] + for i in range(1, count + 1): + await asyncio.sleep(0.001) + numbers.append(i) + yield {"numbers": 
i}, None + yield {"numbers": numbers}, state.update(numbers=numbers) + + +class AsyncStreamingWorkerPreHook(PreStartStreamHookWorkerAsync): + """Async worker hook that runs before streaming action execution.""" + + def __init__(self): + self.called_actions = [] + self.call_count = 0 + + async def pre_start_stream_worker( + self, *, action: str, state: State, inputs: Dict[str, Any], **future_kwargs: Any + ): + self.called_actions.append(("pre_stream", action)) + self.call_count += 1 + + +class AsyncStreamingWorkerPostHook(PostEndStreamHookWorkerAsync): + """Async worker hook that runs after streaming action execution.""" + + def __init__(self): + self.called_actions = [] + self.call_count = 0 + + async def post_end_stream_worker( + self, + *, + action: str, + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + self.called_actions.append(("post_stream", action)) + self.call_count += 1 + + +class AsyncStreamingInterceptor(StreamingActionInterceptorHookAsync): + """Async streaming interceptor that wraps streaming action execution.""" + + def __init__(self): + self.intercepted_actions = [] + self.intercept_count = 0 + self.stream_items_processed = [] + + def should_intercept(self, *, action: Action, **future_kwargs: Any) -> bool: + """Intercept actions tagged with 'async_streaming_intercepted'.""" + return "async_streaming_intercepted" in action.tags + + async def intercept_stream_run_and_update( + self, + *, + action: Action, + state: State, + inputs: Dict[str, Any], + **future_kwargs: Any, + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + """Intercept and wrap the streaming action execution.""" + self.intercepted_actions.append(action.name) + self.intercept_count += 1 + + # Extract worker_adapter_set if provided + worker_adapter_set = future_kwargs.get("worker_adapter_set") + + # Call worker pre-stream-hooks if they exist + if worker_adapter_set: + await worker_adapter_set.call_all_lifecycle_hooks_async( + 
"pre_start_stream_worker", + action=action.name, + state=state, + inputs=inputs, + ) + + # Run the streaming action normally (simulating remote execution) + # This is an async generator, so we need to iterate with async for + generator = action.stream_run_and_update(state, **inputs) + result = None + async for item in generator: + result = item + self.stream_items_processed.append(item[0]) # Store the result dict + yield item + + # Call worker post-stream-hooks if they exist + if worker_adapter_set and result: + await worker_adapter_set.call_all_lifecycle_hooks_async( + "post_end_stream_worker", + action=action.name, + result=result[0] if result else None, + exception=None, + ) + + [email protected] +async def test_async_streaming_interceptor_intercepts_action(): + """Test that async streaming interceptor intercepts tagged actions.""" + interceptor = AsyncStreamingInterceptor() + + app = ( + ApplicationBuilder() + .with_state(prompt="test") + .with_actions(async_streaming_responder) + .with_entrypoint("async_streaming_responder") + .with_hooks(interceptor) + .build() + ) + + # Run async streaming action + action, streaming_container = await app.astream_result( + halt_after=["async_streaming_responder"], + ) + + # Consume the stream + tokens = [] + async for item in streaming_container: + tokens.append(item["response"]) + + result, final_state = await streaming_container.get() + + # Verify interceptor ran + assert "async_streaming_responder" in interceptor.intercepted_actions + assert interceptor.intercept_count == 1 + + # Verify streaming worked correctly + assert tokens == ["Hello", " ", "Async", " ", "World", "!"] + assert final_state["response"] == "Hello Async World!" + assert result["response"] == "Hello Async World!" 
+ + # Verify interceptor processed all stream items + assert len(interceptor.stream_items_processed) == 7 # 6 intermediate + 1 final + + [email protected] +async def test_async_streaming_interceptor_with_worker_hooks(): + """Test that async streaming interceptor properly calls worker hooks.""" + interceptor = AsyncStreamingInterceptor() + worker_pre = AsyncStreamingWorkerPreHook() + worker_post = AsyncStreamingWorkerPostHook() + + app = ( + ApplicationBuilder() + .with_state(prompt="test") + .with_actions(async_streaming_responder) + .with_entrypoint("async_streaming_responder") + .with_hooks(interceptor, worker_pre, worker_post) + .build() + ) + + # Run async streaming action + action, streaming_container = await app.astream_result( + halt_after=["async_streaming_responder"], + ) + + # Consume the stream + async for item in streaming_container: + pass # Consume all items + + result, final_state = await streaming_container.get() + + # Verify interceptor ran + assert "async_streaming_responder" in interceptor.intercepted_actions + + # Verify worker hooks were called + assert ("pre_stream", "async_streaming_responder") in worker_pre.called_actions + assert ("post_stream", "async_streaming_responder") in worker_post.called_actions + assert worker_pre.call_count == 1 + assert worker_post.call_count == 1 + + [email protected] +async def test_async_streaming_interceptor_only_intercepts_tagged_actions(): + """Test that interceptor only intercepts actions with the correct tag.""" + + @streaming_action(reads=["x"], writes=["y"], tags=["not_intercepted"]) + async def non_intercepted_streaming( + state: State, + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + """Streaming action that should NOT be intercepted.""" + yield {"y": "not intercepted"}, None + yield {"y": "not intercepted"}, state.update(y="not intercepted") + + interceptor = AsyncStreamingInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=5) + .with_actions(async_streaming_responder, 
non_intercepted_streaming) + .with_transitions( + ("async_streaming_responder", "non_intercepted_streaming"), + ) + .with_entrypoint("async_streaming_responder") + .with_hooks(interceptor) + .build() + ) + + # Run first action (should be intercepted) + action1, streaming_container1 = await app.astream_result( + halt_after=["async_streaming_responder"], + ) + async for item in streaming_container1: + pass + await streaming_container1.get() + + # Run second action (should NOT be intercepted) + action2, streaming_container2 = await app.astream_result( + halt_after=["non_intercepted_streaming"], + ) + async for item in streaming_container2: + pass + await streaming_container2.get() + + # Verify only tagged action was intercepted + assert "async_streaming_responder" in interceptor.intercepted_actions + assert "non_intercepted_streaming" not in interceptor.intercepted_actions + assert interceptor.intercept_count == 1 + + [email protected] +async def test_async_streaming_interceptor_with_multiple_stream_items(): + """Test async streaming interceptor with an action that yields many items.""" + interceptor = AsyncStreamingInterceptor() + + app = ( + ApplicationBuilder() + .with_state(count=10) + .with_actions(async_count_streamer) + .with_entrypoint("async_count_streamer") + .with_hooks(interceptor) + .build() + ) + + # Run async streaming action + action, streaming_container = await app.astream_result( + halt_after=["async_count_streamer"], + inputs={"count": 10}, # Pass count as input + ) + + # Consume the stream + numbers = [] + async for item in streaming_container: + numbers.append(item["numbers"]) + + result, final_state = await streaming_container.get() + + # Verify interceptor ran + assert "async_count_streamer" in interceptor.intercepted_actions + + # Verify all stream items were processed + assert numbers == list(range(1, 11)) # 1 to 10 + assert final_state["numbers"] == list(range(1, 11)) + assert result["numbers"] == list(range(1, 11)) + + # Verify interceptor 
processed all items (10 intermediate + 1 final) + assert len(interceptor.stream_items_processed) == 11 + + [email protected] +async def test_async_streaming_interceptor_preserves_state_updates(): + """Test that async streaming interceptor preserves state updates correctly.""" + interceptor = AsyncStreamingInterceptor() + + app = ( + ApplicationBuilder() + .with_state(prompt="test", counter=0) + .with_actions(async_streaming_responder) + .with_entrypoint("async_streaming_responder") + .with_hooks(interceptor) + .build() + ) + + # Run async streaming action + action, streaming_container = await app.astream_result( + halt_after=["async_streaming_responder"], + ) + + # Consume the stream + async for item in streaming_container: + pass + + result, final_state = await streaming_container.get() + + # Verify state was updated correctly + assert "response" in final_state + assert final_state["response"] == "Hello Async World!" + assert final_state["prompt"] == "test" # Original state preserved + assert final_state["counter"] == 0 # Original state preserved diff --git a/tests/lifecycle/__init__.py b/tests/lifecycle/__init__.py new file mode 100644 index 00000000..13a83393 --- /dev/null +++ b/tests/lifecycle/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/lifecycle/test_internal.py b/tests/lifecycle/test_internal.py new file mode 100644 index 00000000..4a8c14c6 --- /dev/null +++ b/tests/lifecycle/test_internal.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Unit tests for lifecycle internal functions and decorators.""" +import abc +import inspect +from typing import Any, Dict + +import pytest + +from burr.lifecycle.internal import ( + INTERCEPTOR_TYPE, + REGISTERED_INTERCEPTORS, + InvalidLifecycleHook, + LifecycleAdapterSet, + lifecycle, + validate_interceptor_method, +) + + +class TestValidateInterceptorMethod: + """Tests for validate_interceptor_method function.""" + + def test_valid_interceptor_method_with_future_kwargs(self): + """Test that a valid interceptor method with **future_kwargs passes validation.""" + + def valid_method(self, *, action: Any, **future_kwargs: Any) -> bool: + return True + + # Should not raise + validate_interceptor_method(valid_method, "valid_method") + + def test_valid_interceptor_method_with_multiple_keyword_args(self): + """Test that a valid interceptor method with multiple keyword-only args passes.""" + + def valid_method( + self, *, action: Any, state: Any, inputs: Dict[str, Any], **future_kwargs: Any + ) -> dict: + return {} + + # Should not raise + validate_interceptor_method(valid_method, "valid_method") + + def test_valid_async_interceptor_method(self): + """Test that async interceptor methods are validated correctly.""" + + async def valid_async_method(self, *, action: Any, **future_kwargs: Any) -> bool: + return True + + # Should not raise + validate_interceptor_method(valid_async_method, "valid_async_method") + + def test_missing_future_kwargs_raises_error(self): + """Test that missing **future_kwargs raises InvalidLifecycleHook.""" + + def invalid_method(self, *, action: Any) -> bool: + return True + + with pytest.raises(InvalidLifecycleHook) as exc_info: + validate_interceptor_method(invalid_method, "invalid_method") + + assert "must have a `**future_kwargs` argument" in str(exc_info.value) + + def test_positional_args_raises_error(self): + """Test that positional arguments (non-keyword-only) raise error.""" + + def invalid_method(self, action: Any, **future_kwargs: 
Any) -> bool: + return True + + with pytest.raises(InvalidLifecycleHook) as exc_info: + validate_interceptor_method(invalid_method, "invalid_method") + + assert "can only have keyword-only arguments" in str(exc_info.value) + + def test_none_method_raises_error(self): + """Test that None method raises InvalidLifecycleHook.""" + + with pytest.raises(InvalidLifecycleHook) as exc_info: + validate_interceptor_method(None, "missing_method") + + assert "does not exist on the class" in str(exc_info.value) + + def test_var_keyword_not_named_future_kwargs_raises_error(self): + """Test that **kwargs (not **future_kwargs) raises error.""" + + def invalid_method(self, *, action: Any, **kwargs: Any) -> bool: + return True + + with pytest.raises(InvalidLifecycleHook) as exc_info: + validate_interceptor_method(invalid_method, "invalid_method") + + assert "must have a `**future_kwargs` argument" in str(exc_info.value) + + +class TestInterceptorHookDecorator: + """Tests for @lifecycle.interceptor_hook decorator.""" + + def test_interceptor_hook_registers_type(self): + """Test that @lifecycle.interceptor_hook registers the interceptor type.""" + + @lifecycle.interceptor_hook("test_interceptor_type") + class TestInterceptor(abc.ABC): + @abc.abstractmethod + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + pass + + @abc.abstractmethod + def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict: + pass + + # Check that interceptor type is registered + assert "test_interceptor_type" in REGISTERED_INTERCEPTORS + + # Check that class has interceptor_type attribute + assert hasattr(TestInterceptor, INTERCEPTOR_TYPE) + assert getattr(TestInterceptor, INTERCEPTOR_TYPE) == "test_interceptor_type" + + def test_interceptor_hook_with_custom_method_names(self): + """Test that @lifecycle.interceptor_hook works with custom method names.""" + + @lifecycle.interceptor_hook( + "custom_interceptor", should_intercept_method="should_handle", 
intercept_method="handle" + ) + class CustomInterceptor(abc.ABC): + @abc.abstractmethod + def should_handle(self, *, action: Any, **future_kwargs: Any) -> bool: + pass + + @abc.abstractmethod + def handle(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict: + pass + + # Check that interceptor type is registered + assert "custom_interceptor" in REGISTERED_INTERCEPTORS + assert getattr(CustomInterceptor, INTERCEPTOR_TYPE) == "custom_interceptor" + + def test_interceptor_hook_validates_should_intercept_method(self): + """Test that decorator validates should_intercept method signature.""" + + with pytest.raises(InvalidLifecycleHook): + + @lifecycle.interceptor_hook("invalid_interceptor") + class InvalidInterceptor(abc.ABC): + # Missing **future_kwargs + @abc.abstractmethod + def should_intercept(self, *, action: Any) -> bool: + pass + + @abc.abstractmethod + def intercept_run(self, *, action: Any, **future_kwargs: Any) -> dict: + pass + + def test_interceptor_hook_validates_intercept_method(self): + """Test that decorator validates intercept method signature.""" + + with pytest.raises(InvalidLifecycleHook): + + @lifecycle.interceptor_hook("invalid_interceptor") + class InvalidInterceptor(abc.ABC): + @abc.abstractmethod + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + pass + + # Missing **future_kwargs + @abc.abstractmethod + def intercept_run(self, *, action: Any) -> dict: + pass + + def test_interceptor_hook_validates_missing_method(self): + """Test that decorator raises error if method doesn't exist.""" + + with pytest.raises(InvalidLifecycleHook): + + @lifecycle.interceptor_hook("missing_method_interceptor") + class MissingMethodInterceptor(abc.ABC): + @abc.abstractmethod + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + pass + + # intercept_run is missing + + def test_interceptor_hook_with_streaming_method(self): + """Test that decorator works with intercept_stream_run_and_update method.""" + + 
@lifecycle.interceptor_hook( + "streaming_interceptor", intercept_method="intercept_stream_run_and_update" + ) + class StreamingInterceptor(abc.ABC): + @abc.abstractmethod + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + pass + + @abc.abstractmethod + def intercept_stream_run_and_update( + self, *, action: Any, state: Any, **future_kwargs: Any + ): + pass + + assert "streaming_interceptor" in REGISTERED_INTERCEPTORS + assert getattr(StreamingInterceptor, INTERCEPTOR_TYPE) == "streaming_interceptor" + + def test_interceptor_hook_preserves_class(self): + """Test that decorator returns the class unchanged (for chaining).""" + + @lifecycle.interceptor_hook("preserved_interceptor") + class PreservedInterceptor(abc.ABC): + @abc.abstractmethod + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + pass + + @abc.abstractmethod + def intercept_run(self, *, action: Any, **future_kwargs: Any) -> dict: + pass + + # Class should still be usable + assert PreservedInterceptor.__name__ == "PreservedInterceptor" + assert inspect.isabstract(PreservedInterceptor) + + +class TestGetFirstMatchingHookWithInterceptors: + """Tests for get_first_matching_hook with registered interceptors.""" + + def test_get_first_matching_interceptor_by_type(self): + """Test that get_first_matching_hook finds interceptors by registered type.""" + + @lifecycle.interceptor_hook("test_find_interceptor") + class FindableInterceptor: + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + return True + + def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict: + return {} + + interceptor = FindableInterceptor() + adapter_set = LifecycleAdapterSet(interceptor) + + # Should find the interceptor + found = adapter_set.get_first_matching_hook( + "test_find_interceptor", lambda hook: hook.should_intercept(action=None) + ) + + assert found is interceptor + + def test_get_first_matching_interceptor_with_predicate(self): 
+ """Test that predicate filters interceptors correctly.""" + + @lifecycle.interceptor_hook("test_predicate_interceptor") + class MatchingInterceptor: + def __init__(self, tag: str): + self.tag = tag + + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + return getattr(action, "tag", None) == self.tag + + def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict: + return {} + + class MockAction: + def __init__(self, tag: str): + self.tag = tag + + interceptor1 = MatchingInterceptor("tag1") + interceptor2 = MatchingInterceptor("tag2") + adapter_set = LifecycleAdapterSet(interceptor1, interceptor2) + + # Should find first matching interceptor + found = adapter_set.get_first_matching_hook( + "test_predicate_interceptor", + lambda hook: hook.should_intercept(action=MockAction("tag1")), + ) + + assert found is interceptor1 + + def test_get_first_matching_interceptor_returns_none_if_no_match(self): + """Test that get_first_matching_hook returns None if no interceptor matches.""" + + @lifecycle.interceptor_hook("test_no_match_interceptor") + class NonMatchingInterceptor: + def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool: + return False + + def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict: + return {} + + interceptor = NonMatchingInterceptor() + adapter_set = LifecycleAdapterSet(interceptor) + + # Should return None when predicate doesn't match + found = adapter_set.get_first_matching_hook( + "test_no_match_interceptor", lambda hook: hook.should_intercept(action=None) + ) + + assert found is None + + def test_get_first_matching_interceptor_returns_none_if_not_registered(self): + """Test that unregistered interceptor types return None.""" + + adapter_set = LifecycleAdapterSet() + + # Should return None for unregistered interceptor type + found = adapter_set.get_first_matching_hook( + "unregistered_interceptor_type", lambda hook: True + ) + + assert found is None + + 
def test_get_first_matching_interceptor_inheritance(self):
    """A subclass of a decorated abstract interceptor is found under the base's type name."""

    @lifecycle.interceptor_hook("test_inheritance_interceptor")
    class BaseInterceptor(abc.ABC):
        @abc.abstractmethod
        def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool:
            ...

        @abc.abstractmethod
        def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict:
            ...

    class ConcreteInterceptor(BaseInterceptor):
        def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool:
            return True

        def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict:
            return {}

    def accepts(hook):
        # Ask the candidate itself whether it wants to intercept.
        return hook.should_intercept(action=None)

    concrete = ConcreteInterceptor()
    adapters = LifecycleAdapterSet(concrete)

    # Only the subclass instance is installed; it carries no decorator of its own.
    match = adapters.get_first_matching_hook("test_inheritance_interceptor", accepts)
    assert match is concrete

def test_get_first_matching_interceptor_multiple_types(self):
    """Two interceptor types installed side by side are each looked up by their own name."""

    @lifecycle.interceptor_hook("type_a_interceptor")
    class TypeAInterceptor:
        def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool:
            return True

        def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict:
            return {"type": "A"}

    @lifecycle.interceptor_hook("type_b_interceptor")
    class TypeBInterceptor:
        def should_intercept(self, *, action: Any, **future_kwargs: Any) -> bool:
            return True

        def intercept_run(self, *, action: Any, state: Any, **future_kwargs: Any) -> dict:
            return {"type": "B"}

    def accepts(hook):
        return hook.should_intercept(action=None)

    interceptor_a, interceptor_b = TypeAInterceptor(), TypeBInterceptor()
    adapters = LifecycleAdapterSet(interceptor_a, interceptor_b)

    # Run both lookups up front (type A first, then type B), then check the results.
    by_type = {
        hook_name: adapters.get_first_matching_hook(hook_name, accepts)
        for hook_name in ("type_a_interceptor", "type_b_interceptor")
    }
    match_a = by_type["type_a_interceptor"]
    match_b = by_type["type_b_interceptor"]

    assert match_a is interceptor_a
    assert match_b is interceptor_b
    assert match_a.intercept_run(action=None, state=None) == {"type": "A"}
    assert match_b.intercept_run(action=None, state=None) == {"type": "B"}

def test_get_first_matching_hook_falls_back_to_standard_hooks(self):
    """A hook declared with @lifecycle.base_hook is still returned by get_first_matching_hook."""

    @lifecycle.base_hook("test_standard_hook")
    class StandardHook:
        def test_standard_hook(self, *, app_id: str, **future_kwargs: Any):
            pass

    def accept_any(h):
        return True

    standard = StandardHook()

    # The name below is a base hook, not an interceptor type.
    match = LifecycleAdapterSet(standard).get_first_matching_hook("test_standard_hook", accept_any)
    assert match is standard
