This is an automated email from the ASF dual-hosted git repository.
cvandermerwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new d51ea3afb88 Add access to pipeline options during pipeline construction and translation. (#37522)
d51ea3afb88 is described below
commit d51ea3afb88b1921380b4622dba85220a63f0907
Author: claudevdm <[email protected]>
AuthorDate: Tue Feb 10 14:15:48 2026 -0500
Add access to pipeline options during pipeline construction and translation. (#37522)
* Add pipeline construction options.
* use apply_internal
* lint.
* remove concurrent prism jobs.
* comments.
* comments.
* comments.
* comments.
* lint.
* use apply in recursive calls
---------
Co-authored-by: Claude <[email protected]>
---
.../options/pipeline_options_context.py | 65 ++++
.../options/pipeline_options_context_test.py | 335 +++++++++++++++++++++
sdks/python/apache_beam/pipeline.py | 16 +
3 files changed, 416 insertions(+)
diff --git a/sdks/python/apache_beam/options/pipeline_options_context.py b/sdks/python/apache_beam/options/pipeline_options_context.py
new file mode 100644
index 00000000000..5f2475e334a
--- /dev/null
+++ b/sdks/python/apache_beam/options/pipeline_options_context.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Context-scoped access to pipeline options during graph construction and
+translation.
+
+This module provides thread-safe and async-safe access to the PipelineOptions
+of the current pipeline via contextvars, so concurrently constructed pipelines
+never observe each other's options.
+It allows components like transforms and coders to access the pipeline's
+configuration without requiring explicit parameter passing through every
+level of the call stack.
+
+For internal use only; no backwards-compatibility guarantees.
+"""
+
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import TYPE_CHECKING
+from typing import Optional
+
+if TYPE_CHECKING:
+ from apache_beam.options.pipeline_options import PipelineOptions
+
+# The contextvar holding the current pipeline's options.
+# Each thread and each asyncio task gets its own isolated copy.
+_pipeline_options: ContextVar[Optional['PipelineOptions']] = ContextVar(
+ 'pipeline_options', default=None)
+
+
+def get_pipeline_options() -> Optional['PipelineOptions']:
+ """Get the current pipeline's options from the context.
+
+ Returns:
+ The PipelineOptions for the currently executing pipeline operation,
+ or None if called outside of a pipeline context.
+ """
+ return _pipeline_options.get()
+
+
+@contextmanager
+def scoped_pipeline_options(options: Optional['PipelineOptions']):
+ """Context manager that sets pipeline options for the duration of a block.
+
+ Args:
+ options: The PipelineOptions to make available during this scope.
+ """
+ token = _pipeline_options.set(options)
+ try:
+ yield
+ finally:
+ _pipeline_options.reset(token)
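
A minimal usage sketch of the new helpers (illustrative only, not part of the
commit; the flag value is arbitrary):

    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.options.pipeline_options_context import (
        get_pipeline_options, scoped_pipeline_options)

    options = PipelineOptions(['--job_name=demo'])
    assert get_pipeline_options() is None  # No scope is active yet.
    with scoped_pipeline_options(options):
      # Anywhere down the call stack, code can now read the options
      # without having them passed explicitly as a parameter.
      assert get_pipeline_options() is options
    assert get_pipeline_options() is None  # reset(token) restores the default.
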
diff --git a/sdks/python/apache_beam/options/pipeline_options_context_test.py b/sdks/python/apache_beam/options/pipeline_options_context_test.py
new file mode 100644
index 00000000000..da9c8fce4aa
--- /dev/null
+++ b/sdks/python/apache_beam/options/pipeline_options_context_test.py
@@ -0,0 +1,335 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for pipeline_options_context module.
+
+These tests verify that the contextvar-based approach properly isolates
+pipeline options across threads and async tasks, preventing race conditions.
+"""
+
+import asyncio
+import threading
+import unittest
+
+import apache_beam as beam
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options_context import get_pipeline_options
+from apache_beam.options.pipeline_options_context import scoped_pipeline_options
+
+
+class PipelineConstructionOptionsTest(unittest.TestCase):
+ def test_nested_scoping(self):
+ """Test that nested scopes properly restore outer options."""
+ outer_options = PipelineOptions(['--runner=DirectRunner'])
+ inner_options = PipelineOptions(['--runner=DataflowRunner'])
+
+ with scoped_pipeline_options(outer_options):
+ self.assertIs(get_pipeline_options(), outer_options)
+
+ with scoped_pipeline_options(inner_options):
+ self.assertIs(get_pipeline_options(), inner_options)
+
+ self.assertIs(get_pipeline_options(), outer_options)
+
+ self.assertIsNone(get_pipeline_options())
+
+ def test_exception_in_scope_restores_options(self):
+ """Test that options are restored even when an exception is raised."""
+ outer_options = PipelineOptions(['--runner=DirectRunner'])
+ inner_options = PipelineOptions(['--runner=DataflowRunner'])
+
+ with scoped_pipeline_options(outer_options):
+ try:
+ with scoped_pipeline_options(inner_options):
+ self.assertIs(get_pipeline_options(), inner_options)
+ raise ValueError("Test exception")
+ except ValueError:
+ pass
+
+ self.assertIs(get_pipeline_options(), outer_options)
+
+ def test_different_threads_see_their_own_isolated_options(self):
+ """Test that different threads see their own isolated options."""
+ results = {}
+ errors = []
+ barrier = threading.Barrier(2)
+
+ def thread_worker(thread_id, runner_name):
+ try:
+ options = PipelineOptions([f'--runner={runner_name}'])
+ with scoped_pipeline_options(options):
+ barrier.wait(timeout=5)
+
+ current = get_pipeline_options()
+ results[thread_id] = current.get_all_options()['runner']
+ import time
+ time.sleep(0.01)
+
+ current_after = get_pipeline_options()
+ if current_after is not current:
+ errors.append(
+ f"Thread {thread_id}: options changed during execution")
+ except Exception as e:
+ errors.append(f"Thread {thread_id}: {e}")
+
+ thread1 = threading.Thread(target=thread_worker, args=(1, 'DirectRunner'))
+    thread2 = threading.Thread(target=thread_worker, args=(2, 'DataflowRunner'))
+
+ thread1.start()
+ thread2.start()
+
+ thread1.join(timeout=5)
+ thread2.join(timeout=5)
+
+ self.assertEqual(errors, [])
+ self.assertEqual(results[1], 'DirectRunner')
+ self.assertEqual(results[2], 'DataflowRunner')
+
+ def test_asyncio_task_isolation(self):
+ """Test that different asyncio tasks see their own isolated options."""
+ async def async_worker(
+ task_id, runner_name, results, ready_event, go_event):
+ options = PipelineOptions([f'--runner={runner_name}'])
+ with scoped_pipeline_options(options):
+ ready_event.set()
+ await go_event.wait()
+ current = get_pipeline_options()
+ results[task_id] = current.get_all_options()['runner']
+ await asyncio.sleep(0.01)
+ current_after = get_pipeline_options()
+ assert current_after is current, \
+ f"Task {task_id}: options changed during execution"
+
+ async def run_test():
+ results = {}
+ ready_events = [asyncio.Event() for _ in range(2)]
+ go_event = asyncio.Event()
+
+ task1 = asyncio.create_task(
+ async_worker(1, 'DirectRunner', results, ready_events[0], go_event))
+ task2 = asyncio.create_task(
+          async_worker(2, 'DataflowRunner', results, ready_events[1], go_event))
+
+ # Wait for both tasks to be ready
+ await asyncio.gather(*[e.wait() for e in ready_events])
+ # Signal all tasks to proceed
+ go_event.set()
+
+ await asyncio.gather(task1, task2)
+ return results
+
+ results = asyncio.run(run_test())
+ self.assertEqual(results[1], 'DirectRunner')
+ self.assertEqual(results[2], 'DataflowRunner')
+
+ def test_transform_sees_pipeline_options(self):
+ """Test that a transform can access pipeline options during expand()."""
+ class OptionsCapturingTransform(beam.PTransform):
+ """Transform that captures pipeline options during expand()."""
+ def __init__(self, expected_job_name):
+ self.expected_job_name = expected_job_name
+ self.captured_options = None
+
+ def expand(self, pcoll):
+ # This runs during pipeline construction
+ self.captured_options = get_pipeline_options()
+ return pcoll | beam.Map(lambda x: x)
+
+ options = PipelineOptions(['--job_name=test_job_123'])
+ transform = OptionsCapturingTransform('test_job_123')
+
+ with beam.Pipeline(options=options) as p:
+ _ = p | beam.Create([1, 2, 3]) | transform
+
+ # Verify the transform saw the correct options
+ self.assertIsNotNone(transform.captured_options)
+ self.assertEqual(
+ transform.captured_options.get_all_options()['job_name'],
+ 'test_job_123')
+
+ def test_coder_sees_correct_options_during_run(self):
+ """Test that coders see correct pipeline options during proto conversion.
+
+ This tests the run path where as_deterministic_coder() is called during
+ to_runner_api() proto conversion.
+ """
+ from apache_beam.coders import coders
+ from apache_beam.utils import shared
+
+ errors = []
+
+    # A dict subclass is used because plain dict instances cannot be weakly
+    # referenced; shared.Shared keeps only a weak reference to the value.
+    class WeakRefDict(dict):
+      pass
+
+ class TestKey:
+ def __init__(self, value):
+ self.value = value
+
+ def __eq__(self, other):
+ return isinstance(other, TestKey) and self.value == other.value
+
+ def __hash__(self):
+ return hash(self.value)
+
+ class OptionsCapturingKeyCoder(coders.Coder):
+ """Coder that captures pipeline options in as_deterministic_coder."""
+ shared_handle = shared.Shared()
+
+ def encode(self, value):
+ return str(value.value).encode('utf-8')
+
+ def decode(self, encoded):
+ return TestKey(encoded.decode('utf-8'))
+
+ def is_deterministic(self):
+ return False
+
+ def as_deterministic_coder(self, step_label, error_message=None):
+ opts = get_pipeline_options()
+ if opts is not None:
+ results = OptionsCapturingKeyCoder.shared_handle.acquire(WeakRefDict)
+ job_name = opts.get_all_options().get('job_name')
+ results['Worker1'] = job_name
+ return self
+
+ beam.coders.registry.register_coder(TestKey, OptionsCapturingKeyCoder)
+
+ results = OptionsCapturingKeyCoder.shared_handle.acquire(WeakRefDict)
+
+ job_name = 'gbk_job'
+ options = PipelineOptions([f'--job_name={job_name}'])
+
+ with beam.Pipeline(options=options) as p:
+ _ = (
+ p
+ | beam.Create([(TestKey(1), 'a'), (TestKey(2), 'b')])
+ | beam.GroupByKey())
+
+ self.assertEqual(errors, [], f"Errors occurred: {errors}")
+ self.assertEqual(
+ results.get('Worker1'),
+ job_name,
+ f"Worker1 saw wrong options: {results}")
+    self.assertIsNone(get_pipeline_options())
+
+ def test_barrier_inside_default_type_hints(self):
+ """Test race condition detection with barrier inside default_type_hints.
+
+ This test reliably detects race conditions because:
+ 1. Both threads start pipeline construction simultaneously
+ 2. Inside default_type_hints (which is called during Pipeline.apply()),
+ both threads hit a barrier and wait for each other
+ 3. At this point, BOTH threads are inside scoped_pipeline_options
+ 4. When they continue, they read options - with a global var, they'd see
+ the wrong values because the last thread to set options would win
+ """
+
+ results = {}
+ errors = []
+ inner_barrier = threading.Barrier(2)
+
+ class BarrierTransform(beam.PTransform):
+ """Transform that synchronizes threads INSIDE default_type_hints."""
+ def __init__(self, worker_id, results_dict, barrier):
+ self.worker_id = worker_id
+ self.results_dict = results_dict
+ self.barrier = barrier
+
+ def expand(self, pcoll):
+ return pcoll | beam.Map(lambda x: x)
+
+ def default_type_hints(self):
+ self.barrier.wait(timeout=5)
+ opts = get_pipeline_options()
+ if opts is not None:
+ job_name = opts.get_all_options().get('job_name')
+ if self.worker_id not in self.results_dict:
+ self.results_dict[self.worker_id] = job_name
+
+ return super().default_type_hints()
+
+ def construct_pipeline(worker_id):
+ try:
+ job_name = f'barrier_job_{worker_id}'
+ options = PipelineOptions([f'--job_name={job_name}'])
+ transform = BarrierTransform(worker_id, results, inner_barrier)
+
+ with beam.Pipeline(options=options) as p:
+ _ = p | beam.Create([1, 2, 3]) | transform
+
+ except Exception as e:
+ import traceback
+ errors.append(f"Worker {worker_id}: {e}\n{traceback.format_exc()}")
+
+ thread1 = threading.Thread(
+ target=construct_pipeline, args=(1, ))
+ thread2 = threading.Thread(
+ target=construct_pipeline, args=(2, ))
+
+ thread1.start()
+ thread2.start()
+
+ thread1.join(timeout=10)
+ thread2.join(timeout=10)
+
+ self.assertEqual(errors, [], f"Errors occurred: {errors}")
+ self.assertEqual(
+ results.get(1),
+ 'barrier_job_1',
+ f"Worker 1 saw wrong options: {results}")
+ self.assertEqual(
+ results.get(2),
+ 'barrier_job_2',
+ f"Worker 2 saw wrong options: {results}")
+
+
+class PipelineSubclassApplyTest(unittest.TestCase):
+ def test_subclass_apply_called_on_recursive_paths(self):
+ """Test that Pipeline subclass overrides of apply() are respected.
+
+ _apply_internal's recursive calls must go through self.apply(), not
+ self._apply_internal(), so that subclass interceptions are not skipped.
+ """
+ apply_calls = []
+
+ class TrackingPipeline(beam.Pipeline):
+ def apply(self, transform, pvalueish=None, label=None):
+ apply_calls.append(label or transform.label)
+ return super().apply(transform, pvalueish, label)
+
+ options = PipelineOptions(['--job_name=subclass_test'])
+ with TrackingPipeline(options=options) as p:
+ # "my_label" >> transform creates a _NamedPTransform, which triggers
+ # two recursive apply() calls: one to unwrap _NamedPTransform, and
+ # one to handle the label argument.
+ _ = p | beam.Create([1, 2, 3]) | "my_label" >> beam.Map(lambda x: x)
+
+ # beam.Create goes through apply() once (no recursion).
+ # "my_label" >> Map triggers: apply(_NamedPTransform) -> apply(Map,
+ # label="my_label") -> apply(Map). That's 3 calls through apply().
+ # Total: 1 (Create) + 3 (Map) = 4 calls minimum.
+ map_calls = [c for c in apply_calls if c == 'my_label' or c == 'Map']
+ self.assertGreaterEqual(
+ len(map_calls),
+ 3,
+ f"Expected at least 3 apply() calls for the Map transform "
+ f"(NamedPTransform unwrap + label handling + final), "
+ f"got {len(map_calls)}. All calls: {apply_calls}")
+
+
+if __name__ == '__main__':
+ unittest.main()
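
The barrier-based tests above rely on a property of ContextVar that a plain
module-level global lacks: after both threads have called set(), each thread
still reads back its own value. A standalone sketch of just that property
(hypothetical demo, independent of Beam):

    import threading
    from contextvars import ContextVar

    var: ContextVar[str] = ContextVar('var', default='unset')
    barrier = threading.Barrier(2)
    seen = {}

    def worker(name):
      token = var.set(name)      # Analogous to entering scoped_pipeline_options.
      try:
        barrier.wait(timeout=5)  # Both threads have called set() by now.
        seen[name] = var.get()   # Each thread still reads its own value.
      finally:
        var.reset(token)

    threads = [threading.Thread(target=worker, args=(n, )) for n in ('a', 'b')]
    for t in threads:
      t.start()
    for t in threads:
      t.join()
    assert seen == {'a': 'a', 'b': 'b'}  # A single global would lose one value.
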
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 6ef06abb743..a6080f2f3e7 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -81,6 +81,7 @@ from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import StreamingOptions
from apache_beam.options.pipeline_options import TypeOptions
+from apache_beam.options.pipeline_options_context import scoped_pipeline_options
from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
@@ -559,6 +560,12 @@ class Pipeline(HasDisplayData):
def run(self, test_runner_api: Union[bool, str] = 'AUTO') -> 'PipelineResult':
"""Runs the pipeline. Returns whatever our runner returns after running."""
+ with scoped_pipeline_options(self._options):
+ return self._run_internal(test_runner_api)
+
+ def _run_internal(
+ self, test_runner_api: Union[bool, str] = 'AUTO') -> 'PipelineResult':
+ """Internal implementation of run(), called within scoped options."""
# All pipeline options are finalized at this point.
# Call get_all_options to print warnings on invalid options.
self.options.get_all_options(
@@ -698,6 +705,15 @@ class Pipeline(HasDisplayData):
RuntimeError: if the transform object was already applied to
this pipeline and needs to be cloned in order to apply again.
"""
+ with scoped_pipeline_options(self._options):
+ return self._apply_internal(transform, pvalueish, label)
+
+ def _apply_internal(
+ self,
+ transform: ptransform.PTransform,
+ pvalueish: Optional[pvalue.PValue] = None,
+ label: Optional[str] = None) -> pvalue.PValue:
+ """Internal implementation of apply(), called within scoped options."""
if isinstance(transform, ptransform._NamedPTransform):
return self.apply(
transform.transform, pvalueish, label or transform.label)
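
The pipeline.py change follows a scope-then-delegate shape: the public apply()
establishes the options scope once, the moved body lives in _apply_internal(),
and recursive calls deliberately go back through self.apply() so that subclass
overrides are honored. A simplified sketch of that shape (stand-in classes,
not Beam's real ones):

    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.options.pipeline_options_context import (
        get_pipeline_options, scoped_pipeline_options)

    class MiniPipeline:
      """Stand-in for Pipeline, showing only the delegation shape."""
      def __init__(self, options):
        self._options = options

      def apply(self, transform, label=None):
        # Public entry point: set the scope, then delegate to the body.
        with scoped_pipeline_options(self._options):
          return self._apply_internal(transform, label)

      def _apply_internal(self, transform, label=None):
        if isinstance(transform, tuple):  # Stand-in for _NamedPTransform.
          inner, inner_label = transform
          # Recurse through self.apply(), not self._apply_internal(), so a
          # subclass override of apply() still intercepts the unwrapped call.
          # Re-entering the scope is safe: set()/reset() nest cleanly.
          return self.apply(inner, label or inner_label)
        assert get_pipeline_options() is self._options
        return transform

    p = MiniPipeline(PipelineOptions(['--job_name=demo']))
    p.apply(((lambda x: x), 'my_label'))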