gemini-code-assist[bot] commented on code in PR #39161:
URL: https://github.com/apache/beam/pull/39161#discussion_r3494746941
##########
sdks/python/apache_beam/runners/interactive/recording_manager.py:
##########
@@ -800,10 +799,15 @@ def _wait_for_dependencies(
"""Waits for any dependencies of the given
PCollections that are currently being computed."""
dependencies = self._get_all_dependencies(pcolls)
+ if async_result is None:
+ pcolls_to_check = dependencies.union(pcolls)
+ else:
+ pcolls_to_check = dependencies
computing_deps: dict[beam.pvalue.PCollection, AsyncComputationResult] = {}
- for dep in dependencies:
- if self._env.is_pcollection_computing(dep):
+ for dep in pcolls_to_check:
+ is_computing = self._env.is_pcollection_computing(dep)
+ if is_computing:
Review Comment:

Iterating over `self._async_computations.values()` (line 811) while another
thread concurrently modifies the dictionary (e.g., in `_on_done` via `pop()`)
can raise `RuntimeError: dictionary changed size during iteration`.
To prevent this, we should protect accesses to `self._async_computations`
using a lock.
Please define a lock in `RecordingManager.__init__`:
```python
self._lock = threading.Lock()
```
And update `_wait_for_dependencies` to copy the values under the lock before
iterating:
```python
with self._lock:
  async_computations_copy = list(self._async_computations.values())
for dep in pcolls_to_check:
  is_computing = self._env.is_pcollection_computing(dep)
  if is_computing:
    for comp in async_computations_copy:
      if dep in comp._pcolls:
        computing_deps[dep] = comp
```
##########
sdks/python/apache_beam/runners/interactive/recording_manager.py:
##########
@@ -151,27 +152,32 @@ def exception(self, timeout=None):
except TimeoutError:
return None
- def _on_done(self, future: Future):
- self._env.unmark_pcollection_computing(self._pcolls)
- self._recording_manager._async_computations.pop(self._display_id, None)
-
- if future.cancelled():
- self.update_display('Computation Cancelled.', 1.0)
- return
+ def wait_for_completion(self):
+ self._completed_event.wait()
- exc = future.exception()
- if exc:
- self.update_display(f'Error: {exc}', 1.0)
- _LOGGER.error('Asynchronous computation failed: %s', exc, exc_info=exc)
- else:
- self.update_display('Computation Finished Successfully.', 1.0)
- res = future.result()
- if res and res.state == PipelineState.DONE:
- self._env.mark_pcollection_computed(self._pcolls)
+ def _on_done(self, future: Future):
+ try:
+ if future.cancelled():
+ self.update_display('Computation Cancelled.', 1.0)
+ return
+
+ exc = future.exception()
+ if exc:
+ self.update_display(f'Error: {exc}', 1.0)
+ _LOGGER.error('Asynchronous computation failed: %s', exc, exc_info=exc)
else:
- _LOGGER.warning(
- 'Async computation finished but state is not DONE: %s',
- res.state if res else 'Unknown')
+ self.update_display('Computation Finished Successfully.', 1.0)
+ res = future.result()
+ if res and res.state == PipelineState.DONE:
+ self._env.mark_pcollection_computed(self._pcolls)
+ else:
+ _LOGGER.warning(
+ 'Async computation finished but state is not DONE: %s',
+ res.state if res else 'Unknown')
+ finally:
+ self._env.unmark_pcollection_computing(self._pcolls)
+ self._recording_manager._async_computations.pop(self._display_id, None)
Review Comment:

To ensure thread safety and prevent `RuntimeError` when
`_wait_for_dependencies` iterates over `_async_computations`, we should pop the
computation from the dictionary under the same lock.
```suggestion
      self._env.unmark_pcollection_computing(self._pcolls)
      with self._recording_manager._lock:
        self._recording_manager._async_computations.pop(self._display_id, None)
```
##########
sdks/python/apache_beam/runners/interactive/recording_manager.py:
##########
@@ -891,23 +894,32 @@ def record(
'Cannot record because a dependency failed to compute'
' asynchronously.')
- self._clear()
-
- merged_options = pipeline_options.PipelineOptions(
- **{
- **self.user_pipeline.options.get_all_options(
- drop_default=True, retain_unknown_options=True),
- **options.get_all_options(
- drop_default=True, retain_unknown_options=True)
- }) if options else self.user_pipeline.options
-
- cache_path = ie.current_env().options.cache_root
- is_remote_run = cache_path and ie.current_env(
- ).options.cache_root.startswith('gs://')
- pf.PipelineFragment(
- list(uncomputed_pcolls), merged_options,
- runner=runner).run(blocking=is_remote_run)
- result = ie.current_env().pipeline_result(self.user_pipeline)
+ # Recalculate uncomputed PCollections because some may have finished computing during the wait
+ computed_pcolls = set(
+ pcoll for pcoll in pcolls
+ if pcoll in ie.current_env().computed_pcollections)
+ uncomputed_pcolls = set(pcolls).difference(computed_pcolls)
Review Comment:

This `set(...)` construction and the subsequent difference can be simplified
into a single, more efficient, and idiomatic statement by passing
`ie.current_env().computed_pcollections` directly to `set.difference`.
```python
uncomputed_pcolls = set(pcolls).difference(
    ie.current_env().computed_pcollections)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]