This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 388e5e59ff6 Supports Asynchronous Runs in Interactive Beam (#36853)
388e5e59ff6 is described below
commit 388e5e59ff6955dfc9bd3be264f32a8d079eb04d
Author: Ian Liao <[email protected]>
AuthorDate: Tue Nov 25 07:05:45 2025 -0800
Supports Asynchronous Runs in Interactive Beam (#36853)
* Supports Asynchronous Runs in Interactive Beam
* use PEP-585 generics
* Skip some tests for non-interactive_env and fix errors in unit tests
---
.../runners/interactive/interactive_beam.py | 99 +++-
.../runners/interactive/interactive_beam_test.py | 391 ++++++++++++++++
.../runners/interactive/interactive_environment.py | 19 +
.../interactive/interactive_environment_test.py | 41 ++
.../runners/interactive/recording_manager.py | 478 +++++++++++++++++++-
.../runners/interactive/recording_manager_test.py | 500 +++++++++++++++++++++
.../apache_beam/runners/interactive/utils_test.py | 12 +
7 files changed, 1521 insertions(+), 19 deletions(-)
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index 76c4ea0aa66..7b773fda5db 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -35,11 +35,9 @@ this module in your notebook or application code.
# pytype: skip-file
import logging
+from collections.abc import Iterable
from datetime import timedelta
from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import List
from typing import Optional
from typing import Union
@@ -57,6 +55,7 @@ from apache_beam.runners.interactive.display import
pipeline_graph
from apache_beam.runners.interactive.display.pcoll_visualization import
visualize
from apache_beam.runners.interactive.display.pcoll_visualization import
visualize_computed_pcoll
from apache_beam.runners.interactive.options import interactive_options
+from apache_beam.runners.interactive.recording_manager import
AsyncComputationResult
from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
from apache_beam.runners.interactive.utils import elements_to_df
from apache_beam.runners.interactive.utils import find_pcoll_name
@@ -275,7 +274,7 @@ class Recordings():
"""
def describe(
self,
- pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa:
F821
+ pipeline: Optional[beam.Pipeline] = None) -> dict[str, Any]: # noqa:
F821
"""Returns a description of all the recordings for the given pipeline.
If no pipeline is given then this returns a dictionary of descriptions for
@@ -417,10 +416,10 @@ class Clusters:
# DATAPROC_IMAGE_VERSION = '2.0.XX-debian10'
def __init__(self) -> None:
- self.dataproc_cluster_managers: Dict[ClusterMetadata,
+ self.dataproc_cluster_managers: dict[ClusterMetadata,
DataprocClusterManager] = {}
- self.master_urls: Dict[str, ClusterMetadata] = {}
- self.pipelines: Dict[beam.Pipeline, DataprocClusterManager] = {}
+ self.master_urls: dict[str, ClusterMetadata] = {}
+ self.pipelines: dict[beam.Pipeline, DataprocClusterManager] = {}
self.default_cluster_metadata: Optional[ClusterMetadata] = None
def create(
@@ -511,7 +510,7 @@ class Clusters:
def describe(
self,
cluster_identifier: Optional[ClusterIdentifier] = None
- ) -> Union[ClusterMetadata, List[ClusterMetadata]]:
+ ) -> Union[ClusterMetadata, list[ClusterMetadata]]:
"""Describes the ClusterMetadata by a ClusterIdentifier.
If no cluster_identifier is given or if the cluster_identifier is unknown,
@@ -679,7 +678,7 @@ def watch(watchable):
@progress_indicated
def show(
- *pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection],
+ *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection],
include_window_info: bool = False,
visualize_data: bool = False,
n: Union[int, str] = 'inf',
@@ -1012,6 +1011,88 @@ def collect(
return result_tuple
+@progress_indicated
+def compute(
+ *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection],
+ wait_for_inputs: bool = True,
+ blocking: bool = False,
+ runner=None,
+ options=None,
+ force_compute=False,
+) -> Optional[AsyncComputationResult]:
+ """Computes the given PCollections, potentially asynchronously.
+
+ Args:
+ *pcolls: PCollections to compute. Can be a single PCollection, an iterable
+ of PCollections, or a dictionary with PCollections as values.
+ wait_for_inputs: Whether to wait until the asynchronous dependencies are
+ computed. Setting this to False allows the computation to be scheduled
+ immediately, but may also result in running the same pipeline
+ stages multiple times.
+ blocking: If False, the computation will run in non-blocking fashion. In
+ Colab/IPython environment this mode will also provide the controls for
the
+ running pipeline. If True, the computation will block until the pipeline
+ is done.
+ runner: (optional) the runner with which to compute the results.
+ options: (optional) any additional pipeline options to use to compute the
+ results.
+ force_compute: (optional) if True, forces recomputation rather than using
+ cached PCollections.
+
+ Returns:
+ An AsyncComputationResult object if blocking is False, otherwise None.
+ """
+ flatten_pcolls = []
+ for pcoll_container in pcolls:
+ if isinstance(pcoll_container, dict):
+ flatten_pcolls.extend(pcoll_container.values())
+ elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)):
+ flatten_pcolls.append(pcoll_container)
+ else:
+ try:
+ flatten_pcolls.extend(iter(pcoll_container))
+ except TypeError:
+ raise ValueError(
+ f'The given pcoll {pcoll_container} is not a dict, an iterable or '
+ 'a PCollection.')
+
+ pcolls_set = set()
+ for pcoll in flatten_pcolls:
+ if isinstance(pcoll, DeferredBase):
+ pcoll, _ = deferred_df_to_pcollection(pcoll)
+ watch({f'anonymous_pcollection_{id(pcoll)}': pcoll})
+ assert isinstance(
+ pcoll, beam.pvalue.PCollection
+ ), f'{pcoll} is not an apache_beam.pvalue.PCollection.'
+ pcolls_set.add(pcoll)
+
+ if not pcolls_set:
+ _LOGGER.info('No PCollections to compute.')
+ return None
+
+ pcoll_pipeline = next(iter(pcolls_set)).pipeline
+ user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline)
+ if not user_pipeline:
+ watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline})
+ user_pipeline = pcoll_pipeline
+
+ for pcoll in pcolls_set:
+ if pcoll.pipeline is not user_pipeline:
+ raise ValueError('All PCollections must belong to the same pipeline.')
+
+ recording_manager = ie.current_env().get_recording_manager(
+ user_pipeline, create_if_absent=True)
+
+ return recording_manager.compute_async(
+ pcolls_set,
+ wait_for_inputs=wait_for_inputs,
+ blocking=blocking,
+ runner=runner,
+ options=options,
+ force_compute=force_compute,
+ )
+
+
@progress_indicated
def show_graph(pipeline):
"""Shows the current pipeline shape of a given Beam pipeline as a DAG.
diff --git
a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
index 37cd63842b1..21163fc121c 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
@@ -23,11 +23,16 @@ import importlib
import sys
import time
import unittest
+from concurrent.futures import TimeoutError
from typing import NamedTuple
+from unittest.mock import ANY
+from unittest.mock import MagicMock
+from unittest.mock import call
from unittest.mock import patch
import apache_beam as beam
from apache_beam import dataframe as frames
+from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.options.pipeline_options import FlinkRunnerOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.interactive import interactive_beam as ib
@@ -36,6 +41,7 @@ from apache_beam.runners.interactive import
interactive_runner as ir
from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import
DataprocClusterManager
from apache_beam.runners.interactive.dataproc.types import ClusterMetadata
from apache_beam.runners.interactive.options.capture_limiters import Limiter
+from apache_beam.runners.interactive.recording_manager import
AsyncComputationResult
from apache_beam.runners.interactive.testing.mock_env import isolated_env
from apache_beam.runners.runner import PipelineState
from apache_beam.testing.test_stream import TestStream
@@ -65,6 +71,9 @@ def _get_watched_pcollections_with_variable_names():
return watched_pcollections
[email protected](
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
@isolated_env
class InteractiveBeamTest(unittest.TestCase):
def setUp(self):
@@ -671,5 +680,387 @@ class InteractiveBeamClustersTest(unittest.TestCase):
self.assertEqual(meta.num_workers, 2)
[email protected](
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
+@isolated_env
+class InteractiveBeamComputeTest(unittest.TestCase):
+ def setUp(self):
+ self.env = ie.current_env()
+ self.env._is_in_ipython = False # Default to non-IPython
+
+ def test_compute_blocking(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ data = list(range(10))
+ pcoll = p | 'Create' >> beam.Create(data)
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ result = ib.compute(pcoll, blocking=True)
+ self.assertIsNone(result) # Blocking returns None
+ self.assertTrue(pcoll in self.env.computed_pcollections)
+ collected = ib.collect(pcoll, raw_records=True)
+ self.assertEqual(collected, data)
+
+ def test_compute_non_blocking(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ data = list(range(5))
+ pcoll = p | 'Create' >> beam.Create(data)
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ async_result = ib.compute(pcoll, blocking=False)
+ self.assertIsInstance(async_result, AsyncComputationResult)
+
+ pipeline_result = async_result.result(timeout=60)
+ self.assertTrue(async_result.done())
+ self.assertIsNone(async_result.exception())
+ self.assertEqual(pipeline_result.state, PipelineState.DONE)
+ self.assertTrue(pcoll in self.env.computed_pcollections)
+ collected = ib.collect(pcoll, raw_records=True)
+ self.assertEqual(collected, data)
+
+ def test_compute_with_list_input(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+ pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ ib.compute([pcoll1, pcoll2], blocking=True)
+ self.assertTrue(pcoll1 in self.env.computed_pcollections)
+ self.assertTrue(pcoll2 in self.env.computed_pcollections)
+
+ def test_compute_with_dict_input(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+ pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ ib.compute({'a': pcoll1, 'b': pcoll2}, blocking=True)
+ self.assertTrue(pcoll1 in self.env.computed_pcollections)
+ self.assertTrue(pcoll2 in self.env.computed_pcollections)
+
+ def test_compute_empty_input(self):
+ result = ib.compute([], blocking=True)
+ self.assertIsNone(result)
+ result_async = ib.compute([], blocking=False)
+ self.assertIsNone(result_async)
+
+ def test_compute_force_recompute(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll = p | 'Create' >> beam.Create([1, 2, 3])
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ ib.compute(pcoll, blocking=True)
+ self.assertTrue(pcoll in self.env.computed_pcollections)
+
+ # Mock evict_computed_pcollections to check if it's called
+ with patch.object(self.env, 'evict_computed_pcollections') as mock_evict:
+ ib.compute(pcoll, blocking=True, force_compute=True)
+ mock_evict.assert_called_once_with(p)
+ self.assertTrue(pcoll in self.env.computed_pcollections)
+
+ def test_compute_non_blocking_exception(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+
+ def raise_error(elem):
+ raise ValueError('Test Error')
+
+ pcoll = p | 'Create' >> beam.Create([1]) | 'Error' >> beam.Map(raise_error)
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ async_result = ib.compute(pcoll, blocking=False)
+ self.assertIsInstance(async_result, AsyncComputationResult)
+
+ with self.assertRaises(ValueError):
+ async_result.result(timeout=60)
+
+ self.assertTrue(async_result.done())
+ self.assertIsInstance(async_result.exception(), ValueError)
+ self.assertFalse(pcoll in self.env.computed_pcollections)
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ @patch('apache_beam.runners.interactive.recording_manager.display')
+ @patch('apache_beam.runners.interactive.recording_manager.clear_output')
+ @patch('apache_beam.runners.interactive.recording_manager.HTML')
+ @patch('ipywidgets.Button')
+ @patch('ipywidgets.FloatProgress')
+ @patch('ipywidgets.Output')
+ @patch('ipywidgets.HBox')
+ @patch('ipywidgets.VBox')
+ def test_compute_non_blocking_ipython_widgets(
+ self,
+ mock_vbox,
+ mock_hbox,
+ mock_output,
+ mock_progress,
+ mock_button,
+ mock_html,
+ mock_clear_output,
+ mock_display,
+ ):
+ self.env._is_in_ipython = True
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll = p | 'Create' >> beam.Create(range(3))
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ mock_controls = mock_vbox.return_value
+ mock_html_instance = mock_html.return_value
+
+ async_result = ib.compute(pcoll, blocking=False)
+ self.assertIsNotNone(async_result)
+ mock_button.assert_called_once_with(description='Cancel')
+ mock_progress.assert_called_once()
+ mock_output.assert_called_once()
+ mock_hbox.assert_called_once()
+ mock_vbox.assert_called_once()
+ mock_html.assert_called_once_with('<p>Initializing...</p>')
+
+ self.assertEqual(mock_display.call_count, 2)
+ mock_display.assert_has_calls([
+ call(mock_controls, display_id=async_result._display_id),
+ call(mock_html_instance)
+ ])
+
+ mock_clear_output.assert_called_once()
+ async_result.result(timeout=60) # Let it finish
+
+ def test_compute_dependency_wait_true(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+ pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2)
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ rm = self.env.get_recording_manager(p)
+
+ # Start pcoll1 computation
+ async_res1 = ib.compute(pcoll1, blocking=False)
+ self.assertTrue(self.env.is_pcollection_computing(pcoll1))
+
+ # Spy on _wait_for_dependencies
+ with patch.object(rm,
+ '_wait_for_dependencies',
+ wraps=rm._wait_for_dependencies) as spy_wait:
+ async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=True)
+
+ # Check that wait_for_dependencies was called for pcoll2
+ spy_wait.assert_called_with({pcoll2}, async_res2)
+
+ # Let pcoll1 finish
+ async_res1.result(timeout=60)
+ self.assertTrue(pcoll1 in self.env.computed_pcollections)
+ self.assertFalse(self.env.is_pcollection_computing(pcoll1))
+
+ # pcoll2 should now run and complete
+ async_res2.result(timeout=60)
+ self.assertTrue(pcoll2 in self.env.computed_pcollections)
+
+ @patch.object(ie.InteractiveEnvironment, 'is_pcollection_computing')
+ def test_compute_dependency_wait_false(self, mock_is_computing):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+ pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2)
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ rm = self.env.get_recording_manager(p)
+
+ # Pretend pcoll1 is computing
+ mock_is_computing.side_effect = lambda pcoll: pcoll is pcoll1
+
+ with patch.object(rm,
+ '_execute_pipeline_fragment',
+ wraps=rm._execute_pipeline_fragment) as spy_execute:
+ async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=False)
+ async_res2.result(timeout=60)
+
+ # Assert that execute was called for pcoll2 without waiting
+ spy_execute.assert_called_with({pcoll2}, async_res2, ANY, ANY)
+ self.assertTrue(pcoll2 in self.env.computed_pcollections)
+
+ def test_async_computation_result_cancel(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ # A stream that never finishes to test cancellation
+ pcoll = p | beam.Create([1]) | beam.Map(lambda x: time.sleep(100))
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ async_result = ib.compute(pcoll, blocking=False)
+ self.assertIsInstance(async_result, AsyncComputationResult)
+
+ # Give it a moment to start
+ time.sleep(0.1)
+
+ # Mock the pipeline result's cancel method
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.RUNNING
+ async_result.set_pipeline_result(mock_pipeline_result)
+
+ self.assertTrue(async_result.cancel())
+ mock_pipeline_result.cancel.assert_called_once()
+
+ # The future should be cancelled eventually by the runner
+ # This part is hard to test without deeper runner integration
+ with self.assertRaises(TimeoutError):
+ async_result.result(timeout=1) # It should not complete successfully
+
+ @patch(
+ 'apache_beam.runners.interactive.recording_manager.RecordingManager.'
+ '_execute_pipeline_fragment')
+ def test_compute_multiple_async(self, mock_execute_fragment):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+ pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
+ pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.DONE
+ mock_execute_fragment.return_value = mock_pipeline_result
+
+ res1 = ib.compute(pcoll1, blocking=False)
+ res2 = ib.compute(pcoll2, blocking=False)
+ res3 = ib.compute(pcoll3, blocking=False) # Depends on pcoll1
+
+ self.assertIsNotNone(res1)
+ self.assertIsNotNone(res2)
+ self.assertIsNotNone(res3)
+
+ res1.result(timeout=60)
+ res2.result(timeout=60)
+ res3.result(timeout=60)
+
+ time.sleep(0.1)
+
+ self.assertTrue(
+ pcoll1 in self.env.computed_pcollections, "pcoll1 not marked computed")
+ self.assertTrue(
+ pcoll2 in self.env.computed_pcollections, "pcoll2 not marked computed")
+ self.assertTrue(
+ pcoll3 in self.env.computed_pcollections, "pcoll3 not marked computed")
+
+ self.assertEqual(mock_execute_fragment.call_count, 3)
+
+ @patch(
+ 'apache_beam.runners.interactive.interactive_beam.'
+ 'deferred_df_to_pcollection')
+ def test_compute_input_flattening(self, mock_deferred_to_pcoll):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p | 'C1' >> beam.Create([1])
+ pcoll2 = p | 'C2' >> beam.Create([2])
+ pcoll3 = p | 'C3' >> beam.Create([3])
+ pcoll4 = p | 'C4' >> beam.Create([4])
+
+ class MockDeferred(DeferredBase):
+ def __init__(self, pcoll):
+ mock_expr = MagicMock()
+ super().__init__(mock_expr)
+ self._pcoll = pcoll
+
+ def _get_underlying_pcollection(self):
+ return self._pcoll
+
+ deferred_pcoll = MockDeferred(pcoll4)
+
+ mock_deferred_to_pcoll.return_value = (pcoll4, p)
+
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ with patch.object(self.env, 'get_recording_manager') as mock_get_rm:
+ mock_rm = MagicMock()
+ mock_get_rm.return_value = mock_rm
+ ib.compute(pcoll1, [pcoll2], {'a': pcoll3}, deferred_pcoll)
+
+ expected_pcolls = {pcoll1, pcoll2, pcoll3, pcoll4}
+ mock_rm.compute_async.assert_called_once_with(
+ expected_pcolls,
+ wait_for_inputs=True,
+ blocking=False,
+ runner=None,
+ options=None,
+ force_compute=False)
+
+ def test_compute_invalid_input_type(self):
+ with self.assertRaisesRegex(ValueError,
+ "not a dict, an iterable or a PCollection"):
+ ib.compute(123)
+
+ def test_compute_mixed_pipelines(self):
+ p1 = beam.Pipeline(ir.InteractiveRunner())
+ pcoll1 = p1 | 'C1' >> beam.Create([1])
+ p2 = beam.Pipeline(ir.InteractiveRunner())
+ pcoll2 = p2 | 'C2' >> beam.Create([2])
+ ib.watch(locals())
+ self.env.track_user_pipelines()
+
+ with self.assertRaisesRegex(
+ ValueError, "All PCollections must belong to the same pipeline"):
+ ib.compute(pcoll1, pcoll2)
+
+ @patch(
+ 'apache_beam.runners.interactive.interactive_beam.'
+ 'deferred_df_to_pcollection')
+ @patch.object(ib, 'watch')
+ def test_compute_with_deferred_base(self, mock_watch,
mock_deferred_to_pcoll):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll = p | 'C1' >> beam.Create([1])
+
+ class MockDeferred(DeferredBase):
+ def __init__(self, pcoll):
+ # Provide a dummy expression to satisfy DeferredBase.__init__
+ mock_expr = MagicMock()
+ super().__init__(mock_expr)
+ self._pcoll = pcoll
+
+ def _get_underlying_pcollection(self):
+ return self._pcoll
+
+ deferred = MockDeferred(pcoll)
+
+ mock_deferred_to_pcoll.return_value = (pcoll, p)
+
+ with patch.object(self.env, 'get_recording_manager') as mock_get_rm:
+ mock_rm = MagicMock()
+ mock_get_rm.return_value = mock_rm
+ ib.compute(deferred)
+
+ mock_deferred_to_pcoll.assert_called_once_with(deferred)
+ self.assertEqual(mock_watch.call_count, 2)
+ mock_watch.assert_has_calls([
+ call({f'anonymous_pcollection_{id(pcoll)}': pcoll}),
+ call({f'anonymous_pipeline_{id(p)}': p})
+ ],
+ any_order=False)
+ mock_rm.compute_async.assert_called_once_with({pcoll},
+ wait_for_inputs=True,
+ blocking=False,
+ runner=None,
+ options=None,
+ force_compute=False)
+
+ def test_compute_new_pipeline(self):
+ p = beam.Pipeline(ir.InteractiveRunner())
+ pcoll = p | 'Create' >> beam.Create([1])
+ # NOT calling ib.watch() or track_user_pipelines()
+
+ with patch.object(self.env, 'get_recording_manager') as mock_get_rm, \
+ patch.object(ib, 'watch') as mock_watch:
+ mock_rm = MagicMock()
+ mock_get_rm.return_value = mock_rm
+ ib.compute(pcoll)
+
+ mock_watch.assert_called_with({f'anonymous_pipeline_{id(p)}': p})
+ mock_get_rm.assert_called_once_with(p, create_if_absent=True)
+ mock_rm.compute_async.assert_called_once()
+
+
if __name__ == '__main__':
unittest.main()
diff --git
a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index e9ff86c6276..2a8fc23088a 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -175,6 +175,9 @@ class InteractiveEnvironment(object):
# Tracks the computation completeness of PCollections. PCollections tracked
# here don't need to be re-computed when data introspection is needed.
self._computed_pcolls = set()
+
+ self._computing_pcolls = set()
+
# Always watch __main__ module.
self.watch('__main__')
# Check if [interactive] dependencies are installed.
@@ -720,3 +723,19 @@ class InteractiveEnvironment(object):
bucket_name = cache_dir_path.parts[1]
assert_bucket_exists(bucket_name)
return 'gs://{}/{}'.format('/'.join(cache_dir_path.parts[1:]),
id(pipeline))
+
+ @property
+ def computing_pcollections(self):
+ return self._computing_pcolls
+
+ def mark_pcollection_computing(self, pcolls):
+ """Marks the given pcolls as currently being computed."""
+ self._computing_pcolls.update(pcolls)
+
+ def unmark_pcollection_computing(self, pcolls):
+ """Removes the given pcolls from the computing set."""
+ self._computing_pcolls.difference_update(pcolls)
+
+ def is_pcollection_computing(self, pcoll):
+ """Checks if the given pcollection is currently being computed."""
+ return pcoll in self._computing_pcolls
diff --git
a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
index 4d5f3f36ce6..eb3b4b51482 100644
---
a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
+++
b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
@@ -34,6 +34,9 @@ from apache_beam.runners.interactive.testing.mock_env import
isolated_env
_module_name = 'apache_beam.runners.interactive.interactive_environment_test'
[email protected](
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
@isolated_env
class InteractiveEnvironmentTest(unittest.TestCase):
def setUp(self):
@@ -341,6 +344,44 @@ class InteractiveEnvironmentTest(unittest.TestCase):
with self.assertRaises(ValueError):
env._get_gcs_cache_dir(p, cache_root)
+ def test_pcollection_computing_state(self):
+ env = ie.InteractiveEnvironment()
+ p = beam.Pipeline()
+ pcoll1 = p | 'Create1' >> beam.Create([1])
+ pcoll2 = p | 'Create2' >> beam.Create([2])
+
+ self.assertFalse(env.is_pcollection_computing(pcoll1))
+ self.assertFalse(env.is_pcollection_computing(pcoll2))
+ self.assertEqual(env.computing_pcollections, set())
+
+ env.mark_pcollection_computing({pcoll1})
+ self.assertTrue(env.is_pcollection_computing(pcoll1))
+ self.assertFalse(env.is_pcollection_computing(pcoll2))
+ self.assertEqual(env.computing_pcollections, {pcoll1})
+
+ env.mark_pcollection_computing({pcoll2})
+ self.assertTrue(env.is_pcollection_computing(pcoll1))
+ self.assertTrue(env.is_pcollection_computing(pcoll2))
+ self.assertEqual(env.computing_pcollections, {pcoll1, pcoll2})
+
+ env.unmark_pcollection_computing({pcoll1})
+ self.assertFalse(env.is_pcollection_computing(pcoll1))
+ self.assertTrue(env.is_pcollection_computing(pcoll2))
+ self.assertEqual(env.computing_pcollections, {pcoll2})
+
+ env.unmark_pcollection_computing({pcoll2})
+ self.assertFalse(env.is_pcollection_computing(pcoll1))
+ self.assertFalse(env.is_pcollection_computing(pcoll2))
+ self.assertEqual(env.computing_pcollections, set())
+
+ def test_mark_unmark_empty(self):
+ env = ie.InteractiveEnvironment()
+ # Ensure no errors with empty sets
+ env.mark_pcollection_computing(set())
+ self.assertEqual(env.computing_pcollections, set())
+ env.unmark_pcollection_computing(set())
+ self.assertEqual(env.computing_pcollections, set())
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py
b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index f72ec2fe8e1..c19b60b64fd 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -15,13 +15,17 @@
# limitations under the License.
#
+import collections
import logging
+import os
import threading
import time
+import uuid
import warnings
+from concurrent.futures import Future
+from concurrent.futures import ThreadPoolExecutor
from typing import Any
-from typing import Dict
-from typing import List
+from typing import Optional
from typing import Union
import pandas as pd
@@ -37,11 +41,176 @@ from apache_beam.runners.interactive import
interactive_runner as ir
from apache_beam.runners.interactive import pipeline_fragment as pf
from apache_beam.runners.interactive import utils
from apache_beam.runners.interactive.caching.cacheable import CacheKey
+from apache_beam.runners.interactive.display.pipeline_graph import
PipelineGraph
from apache_beam.runners.interactive.options import capture_control
from apache_beam.runners.runner import PipelineState
_LOGGER = logging.getLogger(__name__)
+try:
+ import ipywidgets as widgets
+ from IPython.display import HTML
+ from IPython.display import clear_output
+ from IPython.display import display
+
+ IS_IPYTHON = True
+except ImportError:
+ IS_IPYTHON = False
+
+
+class AsyncComputationResult:
+ """Represents the result of an asynchronous computation."""
+ def __init__(
+ self,
+ future: Future,
+ pcolls: set[beam.pvalue.PCollection],
+ user_pipeline: beam.Pipeline,
+ recording_manager: 'RecordingManager',
+ ):
+ self._future = future
+ self._pcolls = pcolls
+ self._user_pipeline = user_pipeline
+ self._env = ie.current_env()
+ self._recording_manager = recording_manager
+ self._pipeline_result: Optional[beam.runners.runner.PipelineResult] = None
+ self._display_id = str(uuid.uuid4())
+ self._output_widget = widgets.Output() if IS_IPYTHON else None
+ self._cancel_button = (
+ widgets.Button(description='Cancel') if IS_IPYTHON else None)
+ self._progress_bar = (
+ widgets.FloatProgress(
+ value=0.0,
+ min=0.0,
+ max=1.0,
+ description='Running:',
+ bar_style='info',
+ ) if IS_IPYTHON else None)
+ self._cancel_requested = False
+
+ if IS_IPYTHON:
+ self._cancel_button.on_click(self._cancel_clicked)
+ controls = widgets.VBox([
+ widgets.HBox([self._cancel_button, self._progress_bar]),
+ self._output_widget,
+ ])
+ display(controls, display_id=self._display_id)
+ self.update_display('Initializing...')
+
+ self._future.add_done_callback(self._on_done)
+
+ def _cancel_clicked(self, b):
+ self._cancel_requested = True
+ self._cancel_button.disabled = True
+ self.update_display('Cancel requested...')
+ self.cancel()
+
+ def update_display(self, msg: str, progress: Optional[float] = None):
+ if not IS_IPYTHON:
+ print(f'AsyncCompute: {msg}')
+ return
+
+ with self._output_widget:
+ clear_output(wait=True)
+ display(HTML(f'<p>{msg}</p>'))
+
+ if progress is not None:
+ self._progress_bar.value = progress
+
+ if self.done():
+ self._cancel_button.disabled = True
+ if self.exception():
+ self._progress_bar.bar_style = 'danger'
+ self._progress_bar.description = 'Failed'
+ elif self._future.cancelled():
+ self._progress_bar.bar_style = 'warning'
+ self._progress_bar.description = 'Cancelled'
+ else:
+ self._progress_bar.bar_style = 'success'
+ self._progress_bar.description = 'Done'
+ elif self._cancel_requested:
+ self._cancel_button.disabled = True
+ self._progress_bar.description = 'Cancelling...'
+ else:
+ self._cancel_button.disabled = False
+
+ def set_pipeline_result(
+ self, pipeline_result: beam.runners.runner.PipelineResult):
+ self._pipeline_result = pipeline_result
+ if self._cancel_requested:
+ self.cancel()
+
+ def result(self, timeout=None):
+ return self._future.result(timeout=timeout)
+
+ def done(self):
+ return self._future.done()
+
+ def exception(self, timeout=None):
+ try:
+ return self._future.exception(timeout=timeout)
+ except TimeoutError:
+ return None
+
+ def _on_done(self, future: Future):
+ self._env.unmark_pcollection_computing(self._pcolls)
+ self._recording_manager._async_computations.pop(self._display_id, None)
+
+ if future.cancelled():
+ self.update_display('Computation Cancelled.', 1.0)
+ return
+
+ exc = future.exception()
+ if exc:
+ self.update_display(f'Error: {exc}', 1.0)
+ _LOGGER.error('Asynchronous computation failed: %s', exc, exc_info=exc)
+ else:
+ self.update_display('Computation Finished Successfully.', 1.0)
+ res = future.result()
+ if res and res.state == PipelineState.DONE:
+ self._env.mark_pcollection_computed(self._pcolls)
+ else:
+ _LOGGER.warning(
+ 'Async computation finished but state is not DONE: %s',
+ res.state if res else 'Unknown')
+
+ def cancel(self):
+ if self._future.done():
+ self.update_display('Cannot cancel: Computation already finished.')
+ return False
+
+ self._cancel_requested = True
+ self._cancel_button.disabled = True
+ self.update_display('Attempting to cancel...')
+
+ if self._pipeline_result:
+ try:
+ # Check pipeline state before cancelling
+ current_state = self._pipeline_result.state
+ if PipelineState.is_terminal(current_state):
+ self.update_display(
+ 'Cannot cancel: Pipeline already in terminal state'
+ f' {current_state}.')
+ return False
+
+ self._pipeline_result.cancel()
+ self.update_display('Cancel signal sent to pipeline.')
+ # The future will be cancelled by the runner if successful
+ return True
+ except Exception as e:
+ self.update_display(f'Error sending cancel signal: {e}')
+ _LOGGER.warning('Error during pipeline cancel(): %s', e, exc_info=e)
+ # Still try to cancel the future as a fallback
+ return self._future.cancel()
+ else:
+ self.update_display('Pipeline not yet fully started, cancelling future.')
+ return self._future.cancel()
+
+ def __repr__(self):
+ return (
+ f'<AsyncComputationResult({self._display_id}) for'
+ f' {len(self._pcolls)} PCollections, status:'
+ f" {'done' if self.done() else 'running'}>")
+
class ElementStream:
"""A stream of elements from a given PCollection."""
@@ -151,7 +320,7 @@ class Recording:
def __init__(
self,
user_pipeline: beam.Pipeline,
- pcolls: List[beam.pvalue.PCollection], # noqa: F821
+ pcolls: list[beam.pvalue.PCollection], # noqa: F821
result: 'beam.runner.PipelineResult',
max_n: int,
max_duration_secs: float,
@@ -244,7 +413,7 @@ class Recording:
self._mark_computed.join()
return self._result.state
- def describe(self) -> Dict[str, int]:
+ def describe(self) -> dict[str, int]:
"""Returns a dictionary describing the cache and recording."""
cache_manager = ie.current_env().get_cache_manager(self._user_pipeline)
@@ -259,15 +428,97 @@ class RecordingManager:
self,
user_pipeline: beam.Pipeline,
pipeline_var: str = None,
- test_limiters: List['Limiter'] = None) -> None: # noqa: F821
+ test_limiters: list['Limiter'] = None) -> None: # noqa: F821
self.user_pipeline: beam.Pipeline = user_pipeline
self.pipeline_var: str = pipeline_var if pipeline_var else ''
self._recordings: set[Recording] = set()
self._start_time_sec: float = 0
self._test_limiters = test_limiters if test_limiters else []
+ self._executor = ThreadPoolExecutor(max_workers=os.cpu_count())
+ self._env = ie.current_env()
+ self._async_computations: dict[str, AsyncComputationResult] = {}
+ self._pipeline_graph = None
+
+ def _execute_pipeline_fragment(
+ self,
+ pcolls_to_compute: set[beam.pvalue.PCollection],
+ async_result: Optional['AsyncComputationResult'] = None,
+ runner: runner.PipelineRunner = None,
+ options: pipeline_options.PipelineOptions = None,
+ ) -> beam.runners.runner.PipelineResult:
+ """Synchronously executes a pipeline fragment for the given
PCollections."""
+ merged_options = pipeline_options.PipelineOptions(**{
+ **self.user_pipeline.options.get_all_options(
+ drop_default=True, retain_unknown_options=True
+ ),
+ **(
+ options.get_all_options(
+ drop_default=True, retain_unknown_options=True
+ )
+ if options
+ else {}
+ ),
+ })
+
+ fragment = pf.PipelineFragment(
+ list(pcolls_to_compute), merged_options, runner=runner)
+
+ if async_result:
+ async_result.update_display('Building pipeline fragment...', 0.1)
+
+ pipeline_to_run = fragment.deduce_fragment()
+ if async_result:
+      async_result.update_display('Pipeline running, awaiting finish...', 0.2)
+
+ pipeline_result = pipeline_to_run.run()
+ if async_result:
+ async_result.set_pipeline_result(pipeline_result)
+
+ pipeline_result.wait_until_finish()
+ return pipeline_result
+
+ def _run_async_computation(
+ self,
+ pcolls_to_compute: set[beam.pvalue.PCollection],
+ async_result: 'AsyncComputationResult',
+ wait_for_inputs: bool,
+ runner: runner.PipelineRunner = None,
+ options: pipeline_options.PipelineOptions = None,
+ ):
+ """The function to be run in the thread pool for async computation."""
+ try:
+ if wait_for_inputs:
+ if not self._wait_for_dependencies(pcolls_to_compute, async_result):
+ raise RuntimeError('Dependency computation failed or was cancelled.')
+
+ _LOGGER.info(
+ 'Starting asynchronous computation for %d PCollections.',
+ len(pcolls_to_compute))
+
+ pipeline_result = self._execute_pipeline_fragment(
+ pcolls_to_compute, async_result, runner, options)
+
+      # Success/failure bookkeeping (marking PCollections computed, logging)
+      # is handled in AsyncComputationResult._on_done.
+ return pipeline_result
+ except Exception as e:
+ _LOGGER.exception('Exception during asynchronous computation: %s', e)
+ raise
- def _watch(self, pcolls: List[beam.pvalue.PCollection]) -> None:
+ def _watch(self, pcolls: list[beam.pvalue.PCollection]) -> None:
"""Watch any pcollections not being watched.
This allows for the underlying caching layer to identify the PCollection as
@@ -337,7 +588,7 @@ class RecordingManager:
# evict the BCJ after they complete.
ie.current_env().evict_background_caching_job(self.user_pipeline)
- def describe(self) -> Dict[str, int]:
+ def describe(self) -> dict[str, int]:
"""Returns a dictionary describing the cache and recording."""
cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
@@ -386,9 +637,213 @@ class RecordingManager:
return True
return False
+ def compute_async(
+ self,
+ pcolls: set[beam.pvalue.PCollection],
+ wait_for_inputs: bool = True,
+ blocking: bool = False,
+ runner: runner.PipelineRunner = None,
+ options: pipeline_options.PipelineOptions = None,
+ force_compute: bool = False,
+ ) -> Optional[AsyncComputationResult]:
+ """Computes the given PCollections, potentially asynchronously."""
+
+ if force_compute:
+ self._env.evict_computed_pcollections(self.user_pipeline)
+
+ computed_pcolls = {
+ pcoll
+ for pcoll in pcolls if pcoll in self._env.computed_pcollections
+ }
+ computing_pcolls = {
+ pcoll
+ for pcoll in pcolls if self._env.is_pcollection_computing(pcoll)
+ }
+ pcolls_to_compute = pcolls - computed_pcolls - computing_pcolls
+
+ if not pcolls_to_compute:
+ _LOGGER.info(
+ 'All requested PCollections are already computed or are being'
+ ' computed.')
+ return None
+
+ self._watch(list(pcolls_to_compute))
+ self.record_pipeline()
+
+ if blocking:
+ self._env.mark_pcollection_computing(pcolls_to_compute)
+ try:
+ if wait_for_inputs:
+ if not self._wait_for_dependencies(pcolls_to_compute):
+ raise RuntimeError(
+ 'Dependency computation failed or was cancelled.')
+ pipeline_result = self._execute_pipeline_fragment(
+ pcolls_to_compute, None, runner, options)
+ if pipeline_result.state == PipelineState.DONE:
+ self._env.mark_pcollection_computed(pcolls_to_compute)
+ else:
+ _LOGGER.error(
+ 'Blocking computation failed. State: %s', pipeline_result.state)
+          raise RuntimeError(
+              f'Blocking computation failed. State: {pipeline_result.state}')
+ finally:
+ self._env.unmark_pcollection_computing(pcolls_to_compute)
+ return None
+
+ else: # Asynchronous
+ future = Future()
+ async_result = AsyncComputationResult(
+ future, pcolls_to_compute, self.user_pipeline, self)
+ self._async_computations[async_result._display_id] = async_result
+ self._env.mark_pcollection_computing(pcolls_to_compute)
+
+ def task():
+ try:
+ result = self._run_async_computation(
+ pcolls_to_compute, async_result, wait_for_inputs, runner,
options)
+ future.set_result(result)
+ except Exception as e:
+ if not future.cancelled():
+ future.set_exception(e)
+
+ self._executor.submit(task)
+ return async_result
+
+ def _get_pipeline_graph(self):
+ """Lazily initializes and returns the PipelineGraph."""
+ if self._pipeline_graph is None:
+ try:
+ # Try to create the graph.
+ self._pipeline_graph = PipelineGraph(self.user_pipeline)
+ except (ImportError, NameError, AttributeError):
+ # If pydot is missing, PipelineGraph() might crash.
+ _LOGGER.warning(
+ "Could not create PipelineGraph (pydot missing?). " \
+ "Async features disabled."
+ )
+ self._pipeline_graph = None
+ return self._pipeline_graph
+
+ def _get_pcoll_id_map(self):
+ """Creates a map from PCollection object to its ID in the proto."""
+ pcoll_to_id = {}
+ graph = self._get_pipeline_graph()
+ if graph and graph._pipeline_instrument:
+ pcoll_to_id = graph._pipeline_instrument._pcoll_to_pcoll_id
+ return {v: k for k, v in pcoll_to_id.items()}
+
+ def _get_all_dependencies(
+ self,
+ pcolls: set[beam.pvalue.PCollection]) -> set[beam.pvalue.PCollection]:
+ """Gets all upstream PCollection dependencies
+ for the given set of PCollections."""
+ graph = self._get_pipeline_graph()
+ if not graph:
+ return set()
+
+ analyzer = graph._pipeline_instrument
+ if not analyzer:
+ return set()
+
+ pcoll_to_id = analyzer._pcoll_to_pcoll_id
+
+ target_pcoll_ids = {
+ pcoll_to_id.get(str(pcoll))
+ for pcoll in pcolls if str(pcoll) in pcoll_to_id
+ }
+
+ if not target_pcoll_ids:
+ return set()
+
+ # Build a map from PCollection ID to the actual PCollection object
+ id_to_pcoll_obj = {}
+ for _, inspectable in self._env.inspector.inspectables.items():
+ value = inspectable['value']
+ if isinstance(value, beam.pvalue.PCollection):
+ pcoll_id = pcoll_to_id.get(str(value))
+ if pcoll_id:
+ id_to_pcoll_obj[pcoll_id] = value
+
+ dependencies = set()
+ queue = collections.deque(target_pcoll_ids)
+ visited_pcoll_ids = set(target_pcoll_ids)
+
+ producers = graph._producers
+ transforms = graph._pipeline_proto.components.transforms
+
+ while queue:
+ pcoll_id = queue.popleft()
+ if pcoll_id not in producers:
+ continue
+
+ producer_id = producers[pcoll_id]
+ transform_proto = transforms.get(producer_id)
+ if not transform_proto:
+ continue
+
+ for input_pcoll_id in transform_proto.inputs.values():
+ if input_pcoll_id not in visited_pcoll_ids:
+ visited_pcoll_ids.add(input_pcoll_id)
+ queue.append(input_pcoll_id)
+
+ dep_obj = id_to_pcoll_obj.get(input_pcoll_id)
+ if dep_obj and dep_obj not in pcolls:
+ dependencies.add(dep_obj)
+
+ return dependencies
+
+ def _wait_for_dependencies(
+ self,
+ pcolls: set[beam.pvalue.PCollection],
+ async_result: Optional[AsyncComputationResult] = None,
+ ) -> bool:
+ """Waits for any dependencies of the given
+ PCollections that are currently being computed."""
+ dependencies = self._get_all_dependencies(pcolls)
+ computing_deps: dict[beam.pvalue.PCollection, AsyncComputationResult] = {}
+
+ for dep in dependencies:
+ if self._env.is_pcollection_computing(dep):
+ for comp in self._async_computations.values():
+ if dep in comp._pcolls:
+ computing_deps[dep] = comp
+ break
+
+ if not computing_deps:
+ return True
+
+ if async_result:
+      async_result.update_display(
+          f'Waiting for {len(computing_deps)} dependencies to finish...')
+ _LOGGER.info(
+ 'Waiting for %d dependencies: %s',
+ len(computing_deps),
+ computing_deps.keys())
+
+ futures_to_wait = list(
+ set(comp._future for comp in computing_deps.values()))
+
+ try:
+ for i, future in enumerate(futures_to_wait):
+ if async_result:
+ async_result.update_display(
+ f'Waiting for dependency {i + 1}/{len(futures_to_wait)}...',
+ progress=0.05 + 0.05 * (i / len(futures_to_wait)),
+ )
+ future.result()
+ if async_result:
+ async_result.update_display('Dependencies finished.', progress=0.1)
+ _LOGGER.info('Dependencies finished successfully.')
+ return True
+ except Exception as e:
+ if async_result:
+ async_result.update_display(f'Dependency failed: {e}')
+ _LOGGER.error('Dependency computation failed: %s', e, exc_info=e)
+ return False
+
def record(
self,
- pcolls: List[beam.pvalue.PCollection],
+ pcolls: list[beam.pvalue.PCollection],
*,
max_n: int,
max_duration: Union[int, str],
@@ -431,8 +886,11 @@ class RecordingManager:
# Start a pipeline fragment to start computing the PCollections.
uncomputed_pcolls = set(pcolls).difference(computed_pcolls)
if uncomputed_pcolls:
- # Clear the cache of the given uncomputed PCollections because they are
- # incomplete.
+ if not self._wait_for_dependencies(uncomputed_pcolls):
+ raise RuntimeError(
+ 'Cannot record because a dependency failed to compute'
+ ' asynchronously.')
+
self._clear()
merged_options = pipeline_options.PipelineOptions(
diff --git
a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
index 698a464ae73..d2038719f67 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
@@ -17,7 +17,9 @@
import time
import unittest
+from concurrent.futures import Future
from unittest.mock import MagicMock
+from unittest.mock import call
from unittest.mock import patch
import apache_beam as beam
@@ -30,6 +32,8 @@ from apache_beam.runners.interactive import
interactive_environment as ie
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.interactive_runner import
InteractiveRunner
from apache_beam.runners.interactive.options.capture_limiters import Limiter
+from apache_beam.runners.interactive.recording_manager import _LOGGER
+from apache_beam.runners.interactive.recording_manager import
AsyncComputationResult
from apache_beam.runners.interactive.recording_manager import ElementStream
from apache_beam.runners.interactive.recording_manager import Recording
from apache_beam.runners.interactive.recording_manager import RecordingManager
@@ -43,6 +47,386 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.windowed_value import WindowedValue
[email protected](
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
+class AsyncComputationResultTest(unittest.TestCase):
+ def setUp(self):
+ self.mock_future = MagicMock(spec=Future)
+ self.pcolls = {MagicMock(spec=beam.pvalue.PCollection)}
+ self.user_pipeline = MagicMock(spec=beam.Pipeline)
+ self.recording_manager = MagicMock(spec=RecordingManager)
+ self.recording_manager._async_computations = {}
+ self.env = ie.InteractiveEnvironment()
+ patch.object(ie, 'current_env', return_value=self.env).start()
+
+ self.mock_button = patch('ipywidgets.Button', autospec=True).start()
+ self.mock_float_progress = patch(
+ 'ipywidgets.FloatProgress', autospec=True).start()
+ self.mock_output = patch('ipywidgets.Output', autospec=True).start()
+ self.mock_hbox = patch('ipywidgets.HBox', autospec=True).start()
+ self.mock_vbox = patch('ipywidgets.VBox', autospec=True).start()
+ self.mock_display = patch(
+ 'apache_beam.runners.interactive.recording_manager.display',
+ autospec=True).start()
+ self.mock_clear_output = patch(
+ 'apache_beam.runners.interactive.recording_manager.clear_output',
+ autospec=True).start()
+ self.mock_html = patch(
+ 'apache_beam.runners.interactive.recording_manager.HTML',
+ autospec=True).start()
+
+ self.addCleanup(patch.stopall)
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_async_result_init_non_ipython(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.assertIsNotNone(async_res)
+ self.mock_future.add_done_callback.assert_called_once()
+ self.assertIsNone(async_res._cancel_button)
+
+ def test_on_done_success(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.DONE
+ self.mock_future.result.return_value = mock_pipeline_result
+ self.mock_future.exception.return_value = None
+ self.mock_future.cancelled.return_value = False
+ async_res._display_id = 'test_id'
+ self.recording_manager._async_computations['test_id'] = async_res
+
+ with patch.object(
+ self.env, 'unmark_pcollection_computing'
+ ) as mock_unmark, patch.object(
+ self.env, 'mark_pcollection_computed'
+ ) as mock_mark_computed, patch.object(
+ async_res, 'update_display'
+ ) as mock_update:
+ async_res._on_done(self.mock_future)
+ mock_unmark.assert_called_once_with(self.pcolls)
+ mock_mark_computed.assert_called_once_with(self.pcolls)
+ self.assertNotIn('test_id', self.recording_manager._async_computations)
+ mock_update.assert_called_with('Computation Finished Successfully.', 1.0)
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_on_done_failure(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ test_exception = ValueError('Test')
+ self.mock_future.exception.return_value = test_exception
+ self.mock_future.cancelled.return_value = False
+
+ with patch.object(
+ self.env, 'unmark_pcollection_computing'
+ ) as mock_unmark, patch.object(
+ self.env, 'mark_pcollection_computed'
+ ) as mock_mark_computed:
+ async_res._on_done(self.mock_future)
+ mock_unmark.assert_called_once_with(self.pcolls)
+ mock_mark_computed.assert_not_called()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_on_done_cancelled(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.mock_future.cancelled.return_value = True
+
+ with patch.object(self.env, 'unmark_pcollection_computing') as mock_unmark:
+ async_res._on_done(self.mock_future)
+ mock_unmark.assert_called_once_with(self.pcolls)
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ def test_cancel(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.RUNNING
+ async_res.set_pipeline_result(mock_pipeline_result)
+ self.mock_future.done.return_value = False
+
+ self.assertTrue(async_res.cancel())
+ mock_pipeline_result.cancel.assert_called_once()
+ self.assertTrue(async_res._cancel_requested)
+ self.assertTrue(async_res._cancel_button.disabled)
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_cancel_already_done(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.mock_future.done.return_value = True
+ self.assertFalse(async_res.cancel())
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ @patch('apache_beam.runners.interactive.recording_manager.display')
+ @patch('ipywidgets.Button')
+ @patch('ipywidgets.FloatProgress')
+ @patch('ipywidgets.Output')
+ @patch('ipywidgets.HBox')
+ @patch('ipywidgets.VBox')
+ def test_async_result_init_ipython(
+ self,
+ mock_vbox,
+ mock_hbox,
+ mock_output,
+ mock_progress,
+ mock_button,
+ mock_display,
+ ):
+ mock_btn_instance = mock_button.return_value
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.assertIsNotNone(async_res)
+ mock_button.assert_called_once_with(description='Cancel')
+ mock_progress.assert_called_once()
+ mock_output.assert_called_once()
+ mock_hbox.assert_called_once()
+ mock_vbox.assert_called_once()
+ mock_display.assert_called()
+ mock_btn_instance.on_click.assert_called_once_with(
+ async_res._cancel_clicked)
+ self.mock_future.add_done_callback.assert_called_once()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ @patch(
+ 'apache_beam.runners.interactive.recording_manager.display', MagicMock())
+ @patch('ipywidgets.Button', MagicMock())
+ @patch('ipywidgets.FloatProgress', MagicMock())
+ @patch('ipywidgets.Output', MagicMock())
+ def test_cancel_clicked(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ with patch.object(async_res, 'cancel') as mock_cancel, patch.object(
+ async_res, 'update_display'
+ ) as mock_update:
+ async_res._cancel_clicked(None)
+ self.assertTrue(async_res._cancel_requested)
+ self.assertTrue(async_res._cancel_button.disabled)
+ mock_update.assert_called_once_with('Cancel requested...')
+ mock_cancel.assert_called_once()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_update_display_non_ipython(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ with patch('builtins.print') as mock_print:
+ async_res.update_display('Test Message')
+ mock_print.assert_called_once_with('AsyncCompute: Test Message')
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ def test_update_display_ipython(self):
+ mock_prog_instance = self.mock_float_progress.return_value
+ mock_btn_instance = self.mock_button.return_value
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+
+ update_call_count = 1
+ self.assertEqual(self.mock_clear_output.call_count, update_call_count)
+
+ # State: Running
+ self.mock_future.done.return_value = False
+ async_res._cancel_requested = False
+ async_res.update_display('Running Test', 0.5)
+ update_call_count += 1
+ self.mock_display.assert_called()
+ self.assertEqual(self.mock_clear_output.call_count, update_call_count)
+ self.assertEqual(mock_prog_instance.value, 0.5)
+ self.assertFalse(mock_btn_instance.disabled)
+ self.mock_html.assert_called_with('<p>Running Test</p>')
+
+ # State: Done Success
+ self.mock_future.done.return_value = True
+ self.mock_future.exception.return_value = None
+ self.mock_future.cancelled.return_value = False
+ async_res.update_display('Done')
+ update_call_count += 1
+ self.assertEqual(self.mock_clear_output.call_count, update_call_count)
+ self.assertTrue(mock_btn_instance.disabled)
+ self.assertEqual(mock_prog_instance.bar_style, 'success')
+ self.assertEqual(mock_prog_instance.description, 'Done')
+
+ # State: Done Failed
+ self.mock_future.exception.return_value = Exception()
+ async_res.update_display('Failed')
+ update_call_count += 1
+ self.assertEqual(self.mock_clear_output.call_count, update_call_count)
+ self.assertEqual(mock_prog_instance.bar_style, 'danger')
+ self.assertEqual(mock_prog_instance.description, 'Failed')
+
+ # State: Done Cancelled
+ self.mock_future.exception.return_value = None
+ self.mock_future.cancelled.return_value = True
+ async_res.update_display('Cancelled')
+ update_call_count += 1
+ self.assertEqual(self.mock_clear_output.call_count, update_call_count)
+ self.assertEqual(mock_prog_instance.bar_style, 'warning')
+ self.assertEqual(mock_prog_instance.description, 'Cancelled')
+
+ # State: Cancelling
+ self.mock_future.done.return_value = False
+ async_res._cancel_requested = True
+ async_res.update_display('Cancelling')
+ update_call_count += 1
+ self.assertEqual(self.mock_clear_output.call_count, update_call_count)
+ self.assertTrue(mock_btn_instance.disabled)
+ self.assertEqual(mock_prog_instance.description, 'Cancelling...')
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_set_pipeline_result_cancel_requested(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ async_res._cancel_requested = True
+ mock_pipeline_result = MagicMock()
+ with patch.object(async_res, 'cancel') as mock_cancel:
+ async_res.set_pipeline_result(mock_pipeline_result)
+ self.assertIs(async_res._pipeline_result, mock_pipeline_result)
+ mock_cancel.assert_called_once()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ def test_exception_timeout(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.mock_future.exception.side_effect = TimeoutError
+ self.assertIsNone(async_res.exception(timeout=0.1))
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False)
+ @patch.object(_LOGGER, 'warning')
+ def test_on_done_not_done_state(self, mock_logger_warning):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.FAILED
+ self.mock_future.result.return_value = mock_pipeline_result
+ self.mock_future.exception.return_value = None
+ self.mock_future.cancelled.return_value = False
+
+ with patch.object(self.env,
+ 'mark_pcollection_computed') as mock_mark_computed:
+ async_res._on_done(self.mock_future)
+ mock_mark_computed.assert_not_called()
+ mock_logger_warning.assert_called_once()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ def test_cancel_no_pipeline_result(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.mock_future.done.return_value = False
+ self.mock_future.cancel.return_value = True
+ with patch.object(async_res, 'update_display') as mock_update:
+ self.assertTrue(async_res.cancel())
+ mock_update.assert_any_call(
+ 'Pipeline not yet fully started, cancelling future.')
+ self.mock_future.cancel.assert_called_once()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ def test_cancel_pipeline_terminal_state(self):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.mock_future.done.return_value = False
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.DONE
+ async_res.set_pipeline_result(mock_pipeline_result)
+
+ with patch.object(async_res, 'update_display') as mock_update:
+ self.assertFalse(async_res.cancel())
+ mock_update.assert_any_call(
+ 'Cannot cancel: Pipeline already in terminal state DONE.')
+ mock_pipeline_result.cancel.assert_not_called()
+
+ @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
+ @patch.object(_LOGGER, 'warning')
+ @patch.object(AsyncComputationResult, 'update_display')
+ def test_cancel_pipeline_exception(
+ self, mock_update_display, mock_logger_warning):
+ async_res = AsyncComputationResult(
+ self.mock_future,
+ self.pcolls,
+ self.user_pipeline,
+ self.recording_manager,
+ )
+ self.mock_future.done.return_value = False
+ mock_pipeline_result = MagicMock()
+ mock_pipeline_result.state = PipelineState.RUNNING
+ test_exception = RuntimeError('Cancel Failed')
+ mock_pipeline_result.cancel.side_effect = test_exception
+ async_res.set_pipeline_result(mock_pipeline_result)
+ self.mock_future.cancel.return_value = False
+
+ self.assertFalse(async_res.cancel())
+
+ expected_calls = [
+ call('Initializing...'), # From __init__
+ call('Attempting to cancel...'), # From cancel() start
+ call('Error sending cancel signal: %s',
+ test_exception) # From except block
+ ]
+ mock_update_display.assert_has_calls(expected_calls, any_order=False)
+
+ mock_logger_warning.assert_called_once()
+ self.mock_future.cancel.assert_called_once()
+
+
class MockPipelineResult(beam.runners.runner.PipelineResult):
"""Mock class for controlling a PipelineResult."""
def __init__(self):
@@ -283,6 +667,9 @@ class RecordingTest(unittest.TestCase):
cache_manager.size('full', letters_stream.cache_key))
[email protected](
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
class RecordingManagerTest(unittest.TestCase):
def test_basic_execution(self):
"""A basic pipeline to be used as a smoke test."""
@@ -565,6 +952,119 @@ class RecordingManagerTest(unittest.TestCase):
# Reset cache_root value.
ib.options.cache_root = None
+ def test_compute_async_blocking(self):
+ p = beam.Pipeline(InteractiveRunner())
+ pcoll = p | beam.Create([1, 2, 3])
+ ib.watch(locals())
+ ie.current_env().track_user_pipelines()
+ rm = RecordingManager(p)
+
+ with patch.object(rm, '_execute_pipeline_fragment') as mock_execute:
+ mock_result = MagicMock()
+ mock_result.state = PipelineState.DONE
+ mock_execute.return_value = mock_result
+ res = rm.compute_async({pcoll}, blocking=True)
+ self.assertIsNone(res)
+ mock_execute.assert_called_once()
+ self.assertTrue(pcoll in ie.current_env().computed_pcollections)
+
+ @patch(
+
'apache_beam.runners.interactive.recording_manager.AsyncComputationResult'
+ )
+ @patch(
+ 'apache_beam.runners.interactive.recording_manager.ThreadPoolExecutor.'
+ 'submit')
+ def test_compute_async_non_blocking(self, mock_submit,
mock_async_result_cls):
+ p = beam.Pipeline(InteractiveRunner())
+ pcoll = p | beam.Create([1, 2, 3])
+ ib.watch(locals())
+ ie.current_env().track_user_pipelines()
+ rm = RecordingManager(p)
+ mock_async_res_instance = mock_async_result_cls.return_value
+
+ # Capture the task
+ task_submitted = None
+
+ def capture_task(task):
+ nonlocal task_submitted
+ task_submitted = task
+ # Return a mock future
+ return MagicMock()
+
+ mock_submit.side_effect = capture_task
+
+ with patch.object(
+ rm, '_wait_for_dependencies', return_value=True
+ ), patch.object(
+ rm, '_execute_pipeline_fragment'
+ ) as _, patch.object(
+ ie.current_env(),
+ 'mark_pcollection_computing',
+ wraps=ie.current_env().mark_pcollection_computing,
+ ) as wrapped_mark:
+
+ res = rm.compute_async({pcoll}, blocking=False)
+ wrapped_mark.assert_called_once_with({pcoll})
+
+ # Run the task to trigger the marks
+ self.assertIs(res, mock_async_res_instance)
+ mock_submit.assert_called_once()
+ self.assertIsNotNone(task_submitted)
+
+ with patch.object(
+ rm, '_wait_for_dependencies', return_value=True
+ ), patch.object(
+ rm, '_execute_pipeline_fragment'
+ ) as _:
+ task_submitted()
+
+ self.assertTrue(pcoll in ie.current_env().computing_pcollections)
+
+ def test_get_all_dependencies(self):
+ p = beam.Pipeline(InteractiveRunner())
+ p1 = p | 'C1' >> beam.Create([1])
+ p2 = p | 'C2' >> beam.Create([2])
+ p3 = p1 | 'M1' >> beam.Map(lambda x: x)
+ p4 = (p2, p3) | 'F1' >> beam.Flatten()
+ p5 = p3 | 'M2' >> beam.Map(lambda x: x)
+ ib.watch(locals())
+ ie.current_env().track_user_pipelines()
+ rm = RecordingManager(p)
+ rm.record_pipeline() # Analyze pipeline
+
+ self.assertEqual(rm._get_all_dependencies({p1}), set())
+ self.assertEqual(rm._get_all_dependencies({p3}), {p1})
+ self.assertEqual(rm._get_all_dependencies({p4}), {p1, p2, p3})
+ self.assertEqual(rm._get_all_dependencies({p5}), {p1, p3})
+ self.assertEqual(rm._get_all_dependencies({p4, p5}), {p1, p2, p3})
+
+ @patch(
+
'apache_beam.runners.interactive.recording_manager.AsyncComputationResult'
+ )
+ def test_wait_for_dependencies(self, mock_async_result_cls):
+ p = beam.Pipeline(InteractiveRunner())
+ p1 = p | 'C1' >> beam.Create([1])
+ p2 = p1 | 'M1' >> beam.Map(lambda x: x)
+ ib.watch(locals())
+ ie.current_env().track_user_pipelines()
+ rm = RecordingManager(p)
+ rm.record_pipeline()
+
+ # Scenario 1: No dependencies computing
+ self.assertTrue(rm._wait_for_dependencies({p2}))
+
+ # Scenario 2: Dependency is computing
+ mock_future = MagicMock(spec=Future)
+ mock_async_res = MagicMock(spec=AsyncComputationResult)
+ mock_async_res._future = mock_future
+ mock_async_res._pcolls = {p1}
+ rm._async_computations['dep_id'] = mock_async_res
+ ie.current_env().mark_pcollection_computing({p1})
+
+ self.assertTrue(rm._wait_for_dependencies({p2}))
+ mock_future.result.assert_called_once()
+ ie.current_env().unmark_pcollection_computing({p1})
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/utils_test.py
b/sdks/python/apache_beam/runners/interactive/utils_test.py
index 5fb41df3586..3dba6dfaa3f 100644
--- a/sdks/python/apache_beam/runners/interactive/utils_test.py
+++ b/sdks/python/apache_beam/runners/interactive/utils_test.py
@@ -244,6 +244,9 @@ class IPythonLogHandlerTest(unittest.TestCase):
reason='[interactive] dependency is not installed.')
class ProgressIndicatorTest(unittest.TestCase):
def setUp(self):
+ self.gcs_patcher = patch(
+ 'apache_beam.io.gcp.gcsfilesystem.GCSFileSystem.delete')
+ self.gcs_patcher.start()
ie.new_env()
@patch('IPython.get_ipython', new_callable=mock_get_ipython)
@@ -279,6 +282,9 @@ class ProgressIndicatorTest(unittest.TestCase):
mocked_html.assert_called()
mocked_js.assert_called()
+ def tearDown(self):
+ self.gcs_patcher.stop()
+
@unittest.skipIf(
not ie.current_env().is_interactive_ready,
@@ -287,6 +293,9 @@ class MessagingUtilTest(unittest.TestCase):
SAMPLE_DATA = {'a': [1, 2, 3], 'b': 4, 'c': '5', 'd': {'e': 'f'}}
def setUp(self):
+ self.gcs_patcher = patch(
+ 'apache_beam.io.gcp.gcsfilesystem.GCSFileSystem.delete')
+ self.gcs_patcher.start()
ie.new_env()
def test_as_json_decorator(self):
@@ -298,6 +307,9 @@ class MessagingUtilTest(unittest.TestCase):
# dictionaries remember the order of items inserted.
self.assertEqual(json.loads(dummy()), MessagingUtilTest.SAMPLE_DATA)
+ def tearDown(self):
+ self.gcs_patcher.stop()
+
class GeneralUtilTest(unittest.TestCase):
def test_pcoll_by_name(self):