This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch stefan/fix_nested_otel in repository https://gitbox.apache.org/repos/asf/burr.git
commit 8d50ef1044a719700bc7507614357aae1d7a0134 Author: Stefan Krawczyk <[email protected]> AuthorDate: Thu Oct 23 14:39:03 2025 -0700 Fixes some otel and nested burr application bugs This was caught trying to run nested burr with otel. --- burr/core/application.py | 2 +- burr/integrations/opentelemetry.py | 19 ++-- tests/core/test_application.py | 46 +++++++++ tests/integrations/test_burr_opentelemetry.py | 141 +++++++++++++++++++++++++- 4 files changed, 197 insertions(+), 11 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 9831e1af..1bcea94e 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -232,7 +232,7 @@ def _state_update(state_to_modify: State, modified_state: State) -> State: def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> None: required_writes = reducer.writes - missing_writes = set(reducer.writes) - state.keys() + missing_writes = set(reducer.writes) - set(state.keys()) if len(missing_writes) > 0: raise ValueError( f"State is missing write keys after running: {name}. Missing keys are: {missing_writes}. " diff --git a/burr/integrations/opentelemetry.py b/burr/integrations/opentelemetry.py index 78944569..32dc4dd7 100644 --- a/burr/integrations/opentelemetry.py +++ b/burr/integrations/opentelemetry.py @@ -499,19 +499,20 @@ class BurrTrackingSpanProcessor(SpanProcessor): app_id=parent_span.app_id, ), ) - self.tracker.pre_start_span( - action=context.action_span.action, - action_sequence_id=context.action_span.action_sequence_id, - span=context.action_span, - span_dependencies=[], # TODO -- log - app_id=context.app_id, - partition_key=context.partition_key, - ) + if self.tracker is not None: + self.tracker.pre_start_span( + action=context.action_span.action, + action_sequence_id=context.action_span.action_sequence_id, + span=context.action_span, + span_dependencies=[], # TODO -- log + app_id=context.app_id, + partition_key=context.partition_key, + ) def on_end(self, span: "Span") -> None: cached_span = get_cached_span(span.get_span_context().span_id) # If this is none it means we're outside of the burr context - if cached_span is not None: + if cached_span is not None and self.tracker is not None: # TODO -- get tracker context to work self.tracker.post_end_span( action=cached_span.action_span.action, diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 5ebd3341..8bab570a 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -486,6 +486,52 @@ def test__run_reducer_deletes_state(): assert "count" not in state +def test__validate_reducer_writes_with_state_keys_returning_list(): + """Tests that _validate_reducer_writes works when state.keys() returns a list. + + This is a regression test for a bug where state.keys() could return a list + instead of a set, causing a TypeError when trying to do set subtraction. + """ + from burr.core.application import _validate_reducer_writes + + # Create a reducer with some expected writes + reducer = PassedInAction( + reads=["input"], + writes=["output", "result"], + fn=..., + update_fn=lambda result, state: state.update(**result), + inputs=[], + ) + + # Create a state that has all the required writes + state = State({"input": 1, "output": 2, "result": 3}) + + # This should not raise a TypeError even if state.keys() returns a list + # (which was the original bug) + _validate_reducer_writes(reducer, state, "test_action") + + +def test__validate_reducer_writes_raises_on_missing_keys(): + """Tests that _validate_reducer_writes raises ValueError when required keys are missing.""" + from burr.core.application import _validate_reducer_writes + + # Create a reducer with some expected writes + reducer = PassedInAction( + reads=["input"], + writes=["output", "result", "missing_key"], + fn=..., + update_fn=lambda result, state: state.update(**result), + inputs=[], + ) + + # Create a state that is missing some required writes + state = State({"input": 1, "output": 2, "result": 3}) + + # This should raise a ValueError for missing "missing_key" + with pytest.raises(ValueError, match="missing_key"): + _validate_reducer_writes(reducer, state, "test_action") + + async def test__arun_function(): """Tests that we can run an async function""" action = base_counter_action_async diff --git a/tests/integrations/test_burr_opentelemetry.py b/tests/integrations/test_burr_opentelemetry.py index e5568202..6fc7ee3a 100644 --- a/tests/integrations/test_burr_opentelemetry.py +++ b/tests/integrations/test_burr_opentelemetry.py @@ -16,12 +16,17 @@ # under the License. import json +from unittest.mock import Mock, patch import pydantic import pytest from burr.core import serde -from burr.integrations.opentelemetry import convert_to_otel_attribute +from burr.integrations.opentelemetry import ( + BurrTrackingSpanProcessor, + convert_to_otel_attribute, + tracker_context, +) class SampleModel(pydantic.BaseModel): @@ -43,3 +48,137 @@ class SampleModel(pydantic.BaseModel): ) def test_convert_to_otel_attribute(value, expected): assert convert_to_otel_attribute(value) == expected + + +def test_burr_tracking_span_processor_on_start_with_none_tracker(): + """Test that on_start handles None tracker gracefully without raising an error.""" + processor = BurrTrackingSpanProcessor() + + # Mock a span with a parent + mock_span = Mock() + mock_span.parent = Mock() + mock_span.parent.span_id = 12345 + mock_span.name = "test_span" + + # Mock the get_cached_span to return a parent span context + with patch("burr.integrations.opentelemetry.get_cached_span") as mock_get_cached: + mock_parent_context = Mock() + mock_parent_context.action_span = Mock() + mock_parent_context.action_span.spawn = Mock(return_value=Mock()) + mock_parent_context.partition_key = "test_partition" + mock_parent_context.app_id = "test_app" + mock_get_cached.return_value = mock_parent_context + + # Mock cache_span + with patch("burr.integrations.opentelemetry.cache_span"): + # Set tracker_context to None (simulating no tracker in context) + token = tracker_context.set(None) + try: + # This should not raise an error even though tracker is None + processor.on_start(mock_span, parent_context=None) + finally: + tracker_context.reset(token) + + +def test_burr_tracking_span_processor_on_end_with_none_tracker(): + """Test that on_end handles None tracker gracefully without raising an error.""" + processor = BurrTrackingSpanProcessor() + + # Mock a span + mock_span = Mock() + mock_span_context = Mock() + mock_span_context.span_id = 67890 + mock_span.get_span_context = Mock(return_value=mock_span_context) + mock_span.attributes = {} + + # Mock the get_cached_span to return a cached span + with patch("burr.integrations.opentelemetry.get_cached_span") as mock_get_cached: + mock_cached_span = Mock() + mock_cached_span.action_span = Mock() + mock_cached_span.action_span.action = "test_action" + mock_cached_span.action_span.action_sequence_id = 1 + mock_cached_span.app_id = "test_app" + mock_cached_span.partition_key = "test_partition" + mock_get_cached.return_value = mock_cached_span + + # Mock uncache_span + with patch("burr.integrations.opentelemetry.uncache_span"): + # Set tracker_context to None (simulating no tracker in context) + token = tracker_context.set(None) + try: + # This should not raise an error even though tracker is None + processor.on_end(mock_span) + finally: + tracker_context.reset(token) + + +def test_burr_tracking_span_processor_on_start_with_valid_tracker(): + """Test that on_start calls tracker methods when tracker is available.""" + processor = BurrTrackingSpanProcessor() + + # Mock a span with a parent + mock_span = Mock() + mock_span.parent = Mock() + mock_span.parent.span_id = 12345 + mock_span.name = "test_span" + + # Mock tracker + mock_tracker = Mock() + + # Mock the get_cached_span to return a parent span context + with patch("burr.integrations.opentelemetry.get_cached_span") as mock_get_cached: + mock_parent_context = Mock() + mock_parent_context.action_span = Mock() + mock_parent_context.action_span.spawn = Mock(return_value=Mock(action="test_action")) + mock_parent_context.partition_key = "test_partition" + mock_parent_context.app_id = "test_app" + mock_get_cached.return_value = mock_parent_context + + # Mock cache_span + with patch("burr.integrations.opentelemetry.cache_span"): + # Set tracker_context to a valid tracker + token = tracker_context.set(mock_tracker) + try: + processor.on_start(mock_span, parent_context=None) + + # Verify that pre_start_span was called on the tracker + assert mock_tracker.pre_start_span.called + finally: + tracker_context.reset(token) + + +def test_burr_tracking_span_processor_on_end_with_valid_tracker(): + """Test that on_end calls tracker methods when tracker is available.""" + processor = BurrTrackingSpanProcessor() + + # Mock a span + mock_span = Mock() + mock_span_context = Mock() + mock_span_context.span_id = 67890 + mock_span.get_span_context = Mock(return_value=mock_span_context) + mock_span.attributes = {} + + # Mock tracker + mock_tracker = Mock() + + # Mock the get_cached_span to return a cached span + with patch("burr.integrations.opentelemetry.get_cached_span") as mock_get_cached: + mock_cached_span = Mock() + mock_cached_span.action_span = Mock() + mock_cached_span.action_span.action = "test_action" + mock_cached_span.action_span.action_sequence_id = 1 + mock_cached_span.app_id = "test_app" + mock_cached_span.partition_key = "test_partition" + mock_get_cached.return_value = mock_cached_span + + # Mock uncache_span + with patch("burr.integrations.opentelemetry.uncache_span"): + # Set tracker_context to a valid tracker + token = tracker_context.set(mock_tracker) + try: + processor.on_end(mock_span) + + # Verify that post_end_span was called on the tracker + assert mock_tracker.post_end_span.called + finally: + tracker_context.reset(token)
