This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d545ece  Merge pull request #11128 from [BEAM-9524] Fix for ib.show() executing indefinitely
d545ece is described below

commit d545ece09d24f428518a8ebbad09b1559ce5aa23
Author: Sam sam <[email protected]>
AuthorDate: Thu Mar 19 18:31:14 2020 -0700

    Merge pull request #11128 from [BEAM-9524] Fix for ib.show() executing indefinitely
    
    * Fix ib.show() spinning forever when re-executing cells without kernel restart
    
    Change-Id: I53aa32a75645086efffa091a53880a076c3a689d
    
    * Add CacheKey class
    
    Change-Id: I1ab6e7036172d7e2d07c774778a50e165df6bdca
    
    * Fix dependency loop
    
    Change-Id: I247f37cd7acffb6ad796ce0fa8b54b0feff400d1
---
 .../runners/interactive/background_caching_job.py  |  22 ++--
 .../runners/interactive/caching/streaming_cache.py |  15 +--
 .../interactive/caching/streaming_cache_test.py    |  34 +++---
 .../runners/interactive/interactive_environment.py |  35 +++---
 .../runners/interactive/interactive_runner_test.py |   6 +-
 .../runners/interactive/pipeline_instrument.py     |  82 +++++++++++---
 .../interactive/pipeline_instrument_test.py        | 118 +++++++++------------
 7 files changed, 182 insertions(+), 130 deletions(-)
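
At the heart of the change is the new CacheKey class (added to
pipeline_instrument.py below): every cache label now carries the owning
pipeline's id, so the streaming-cache reader can check completion against the
right pipeline instead of relying on a closure captured at wrapping time. A
minimal sketch of the round-trip; the numeric ids are hypothetical:

    from apache_beam.runners.interactive.pipeline_instrument import CacheKey

    # Keys serialize to 'var|version|producer_version|pipeline_id'.
    key = CacheKey('init_pcoll', '140234', '140235', '139740')
    assert repr(key) == 'init_pcoll|140234|140235|139740'
    assert CacheKey.from_str(repr(key)).pipeline_id == '139740'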

diff --git a/sdks/python/apache_beam/runners/interactive/background_caching_job.py b/sdks/python/apache_beam/runners/interactive/background_caching_job.py
index e002809..6690fc8 100644
--- a/sdks/python/apache_beam/runners/interactive/background_caching_job.py
+++ b/sdks/python/apache_beam/runners/interactive/background_caching_job.py
@@ -189,6 +189,20 @@ def is_background_caching_job_needed(user_pipeline):
           cache_changed))
 
 
+def is_cache_complete(pipeline_id):
+  # type: (str) -> bool
+
+  """Returns True if the background cache for the given pipeline is done.
+  """
+  user_pipeline = ie.current_env().pipeline_id_to_pipeline(pipeline_id)
+  job = ie.current_env().get_background_caching_job(user_pipeline)
+  is_done = job and job.is_done()
+  cache_changed = is_source_to_cache_changed(
+      user_pipeline, update_cached_source_signature=False)
+
+  return is_done and not cache_changed
+
+
 def has_source_to_cache(user_pipeline):
   """Determines if a user-defined pipeline contains any source that needs to be
   cached. If so, also immediately wrap current cache manager held by current
@@ -208,14 +222,6 @@ def has_source_to_cache(user_pipeline):
   if has_cache:
     if not isinstance(ie.current_env().cache_manager(),
                       streaming_cache.StreamingCache):
-      # Wrap the cache manager into a streaming cache manager. Note this
-      # does not invalidate the current cache manager.
-      def is_cache_complete():
-        job = ie.current_env().get_background_caching_job(user_pipeline)
-        is_done = job and job.is_done()
-        cache_changed = is_source_to_cache_changed(
-            user_pipeline, update_cached_source_signature=False)
-        return is_done and not cache_changed
 
       ie.current_env().set_cache_manager(
           streaming_cache.StreamingCache(
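
The closure over user_pipeline is replaced by the module-level
is_cache_complete(pipeline_id) above, which is handed to the cache manager by
reference. A sketch of the new wiring, assuming StreamingCache accepts the
predicate as its is_cache_complete argument (as the call sites in the next
file suggest):

    from apache_beam.runners.interactive import background_caching_job as bcj
    from apache_beam.runners.interactive.caching import streaming_cache

    # The pipeline id the predicate needs now travels inside each cache
    # label rather than being captured in a closure here.
    cache = streaming_cache.StreamingCache(
        cache_dir=None, is_cache_complete=bcj.is_cache_complete)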
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 17d8a5f..9b4b20b 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -153,7 +153,10 @@ class StreamingCacheSource:
     self._coder = coder
     self._labels = labels
     self._is_cache_complete = (
-        is_cache_complete if is_cache_complete else lambda: True)
+        is_cache_complete if is_cache_complete else lambda _: True)
+
+    from apache_beam.runners.interactive.pipeline_instrument import CacheKey
+    self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
 
   def _wait_until_file_exists(self, timeout_secs=30):
     """Blocks until the file exists for a maximum of timeout_secs.
@@ -186,7 +189,7 @@ class StreamingCacheSource:
       # Check if we are at EOF or if we have an incomplete line.
       if not line or (line and line[-1] != b'\n'[0]):
         # Complete reading only when the cache is complete.
-        if self._is_cache_complete():
+        if self._is_cache_complete(self._pipeline_id):
           break
 
         if not tail:
@@ -273,8 +276,7 @@ class StreamingCache(CacheManager):
       return iter([]), -1
 
     reader = StreamingCacheSource(
-        self._cache_dir, labels,
-        is_cache_complete=self._is_cache_complete).read(tail=False)
+        self._cache_dir, labels, self._is_cache_complete).read(tail=False)
     header = next(reader)
     return StreamingCache.Reader([header], [reader]).read(), 1
 
@@ -286,9 +288,8 @@ class StreamingCache(CacheManager):
     pipeline runtime which needs to block.
     """
     readers = [
-        StreamingCacheSource(
-            self._cache_dir, l,
-            is_cache_complete=self._is_cache_complete).read(tail=True)
+        StreamingCacheSource(self._cache_dir, l,
+                             self._is_cache_complete).read(tail=True)
         for l in labels
     ]
     headers = [next(r) for r in readers]
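
StreamingCacheSource now recovers the owning pipeline's id by parsing the
last cache label rather than holding a pipeline reference. A small sketch of
that lookup; the label values are hypothetical:

    from apache_beam.runners.interactive.pipeline_instrument import CacheKey

    labels = ['full', 'init_pcoll|140234|140235|139740']
    assert CacheKey.from_str(labels[-1]).pipeline_id == '139740'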
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index 002a05a..c73134e 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -29,6 +29,7 @@ from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileR
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
 from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.pipeline_instrument import CacheKey
+from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
@@ -56,7 +57,7 @@ class StreamingCacheTest(unittest.TestCase):
   def test_single_reader(self):
     """Tests that we expect to see all the correctly emitted TestStreamPayloads.
     """
-    CACHED_PCOLLECTION_KEY = 'arbitrary_key'
+    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
 
     values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
               .add_element(element=0, event_time_secs=0)
@@ -109,9 +110,9 @@ class StreamingCacheTest(unittest.TestCase):
     """Tests that the service advances the clock with multiple outputs.
     """
 
-    CACHED_LETTERS = 'letters'
-    CACHED_NUMBERS = 'numbers'
-    CACHED_LATE = 'late'
+    CACHED_LETTERS = repr(CacheKey('letters', '', '', ''))
+    CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
+    CACHED_LATE = repr(CacheKey('late', '', '', ''))
 
     letters = (FileRecordsBuilder(CACHED_LETTERS)
                .advance_processing_time(1)
@@ -235,13 +236,14 @@ class StreamingCacheTest(unittest.TestCase):
     This ensures that the sink and source speak the same language in terms of
     coders, protos, order, and units.
     """
+    CACHED_RECORDS = repr(CacheKey('records', '', '', ''))
 
     # Units here are in seconds.
     test_stream = (TestStream()
-                   .advance_watermark_to(0, tag='records')
+                   .advance_watermark_to(0, tag=CACHED_RECORDS)
                    .advance_processing_time(5)
-                   .add_elements(['a', 'b', 'c'], tag='records')
-                   .advance_watermark_to(10, tag='records')
+                   .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
+                   .advance_watermark_to(10, tag=CACHED_RECORDS)
                    .advance_processing_time(1)
                    .add_elements(
                        [
@@ -249,7 +251,7 @@ class StreamingCacheTest(unittest.TestCase):
                            TimestampedValue('2', 15),
                            TimestampedValue('3', 15)
                        ],
-                       tag='records')) # yapf: disable
+                       tag=CACHED_RECORDS)) # yapf: disable
 
     coder = SafeFastPrimitivesCoder()
     cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
@@ -259,9 +261,9 @@ class StreamingCacheTest(unittest.TestCase):
         'passthrough_pcollection_output_ids')
     with TestPipeline(options=options) as p:
       # pylint: disable=expression-not-assigned
-      p | test_stream | cache.sink(['records'])
+      p | test_stream | cache.sink([CACHED_RECORDS])
 
-    reader, _ = cache.read('records')
+    reader, _ = cache.read(CACHED_RECORDS)
     actual_events = list(reader)
 
     # Units here are in microseconds.
@@ -271,7 +273,7 @@ class StreamingCacheTest(unittest.TestCase):
                 advance_duration=5 * 10**6)),
         TestStreamPayload.Event(
             watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=0, tag='records')),
+                new_watermark=0, tag=CACHED_RECORDS)),
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
@@ -282,13 +284,13 @@ class StreamingCacheTest(unittest.TestCase):
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode('c'), timestamp=0),
                 ],
-                tag='records')),
+                tag=CACHED_RECORDS)),
         TestStreamPayload.Event(
             processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                 advance_duration=1 * 10**6)),
         TestStreamPayload.Event(
             watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=10 * 10**6, tag='records')),
+                new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
@@ -302,7 +304,7 @@ class StreamingCacheTest(unittest.TestCase):
                         encoded_element=coder.encode('3'), timestamp=15 *
                         10**6),
                 ],
-                tag='records')),
+                tag=CACHED_RECORDS)),
     ]
     self.assertEqual(actual_events, expected_events)
 
@@ -312,8 +314,8 @@ class StreamingCacheTest(unittest.TestCase):
     This tests that the StreamingCache reads from multiple
     files and combines them into a single sorted output.
     """
-    LETTERS_TAG = 'letters'
-    NUMBERS_TAG = 'numbers'
+    LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
+    NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))
 
     # Units here are in seconds.
     test_stream = (TestStream()
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index c686b30..662e513 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -134,19 +134,19 @@ class InteractiveEnvironment(object):
     self._watching_set = set()
     # Holds variables list of (Dict[str, object]).
     self._watching_dict_list = []
-    # Holds results of main jobs as Dict[Pipeline, PipelineResult].
+    # Holds results of main jobs as Dict[str, PipelineResult].
     # Each key is the str id of a pipeline instance defined by the end user. The
     # InteractiveRunner is responsible for populating this dictionary
     # implicitly.
     self._main_pipeline_results = {}
-    # Holds background caching jobs as Dict[Pipeline, BackgroundCachingJob].
+    # Holds background caching jobs as Dict[str, BackgroundCachingJob].
     # Each key is the str id of a pipeline instance defined by the end user. The
     # InteractiveRunner or its enclosing scope is responsible for populating
     # this dictionary implicitly when a background caching job is started.
     self._background_caching_jobs = {}
     # Holds TestStreamServiceControllers that controls gRPC servers serving
     # events as test stream of TestStreamPayload.Event.
-    # Dict[Pipeline, TestStreamServiceController]. Each key is a pipeline
+    # Dict[str, TestStreamServiceController]. Each key is the str id of a pipeline
     # instance defined by the end user. The InteractiveRunner or its enclosing
     # scope is responsible for populating this dictionary implicitly when a new
     # controller is created to start a new gRPC server. The server stays alive
@@ -301,15 +301,15 @@ class InteractiveEnvironment(object):
     assert issubclass(type(result), runner.PipelineResult), (
         'result must be an instance of '
         'apache_beam.runners.runner.PipelineResult or its subclass')
-    self._main_pipeline_results[pipeline] = result
+    self._main_pipeline_results[str(id(pipeline))] = result
 
   def evict_pipeline_result(self, pipeline):
     """Evicts the tracking of given pipeline run. Noop if absent."""
-    return self._main_pipeline_results.pop(pipeline, None)
+    return self._main_pipeline_results.pop(str(id(pipeline)), None)
 
   def pipeline_result(self, pipeline):
     """Gets the pipeline run result. None if absent."""
-    return self._main_pipeline_results.get(pipeline, None)
+    return self._main_pipeline_results.get(str(id(pipeline)), None)
 
   def set_background_caching_job(self, pipeline, background_caching_job):
     """Sets the background caching job started from the given pipeline."""
@@ -318,32 +318,32 @@ class InteractiveEnvironment(object):
     from apache_beam.runners.interactive.background_caching_job import BackgroundCachingJob
     assert isinstance(background_caching_job, BackgroundCachingJob), (
         'background_caching job must be an instance of BackgroundCachingJob')
-    self._background_caching_jobs[pipeline] = background_caching_job
+    self._background_caching_jobs[str(id(pipeline))] = background_caching_job
 
   def get_background_caching_job(self, pipeline):
     """Gets the background caching job started from the given pipeline."""
-    return self._background_caching_jobs.get(pipeline, None)
+    return self._background_caching_jobs.get(str(id(pipeline)), None)
 
   def set_test_stream_service_controller(self, pipeline, controller):
     """Sets the test stream service controller that has started a gRPC server
     serving the test stream for any job started from the given user-defined
     pipeline.
     """
-    self._test_stream_service_controllers[pipeline] = controller
+    self._test_stream_service_controllers[str(id(pipeline))] = controller
 
   def get_test_stream_service_controller(self, pipeline):
     """Gets the test stream service controller that has started a gRPC server
     serving the test stream for any job started from the given user-defined
     pipeline.
     """
-    return self._test_stream_service_controllers.get(pipeline, None)
+    return self._test_stream_service_controllers.get(str(id(pipeline)), None)
 
   def evict_test_stream_service_controller(self, pipeline):
     """Evicts and pops the test stream service controller that has started a
     gRPC server serving the test stream for any job started from the given
     user-defined pipeline.
     """
-    return self._test_stream_service_controllers.pop(pipeline, None)
+    return self._test_stream_service_controllers.pop(str(id(pipeline)), None)
 
   def is_terminated(self, pipeline):
     """Queries if the most recent job (by executing the given pipeline) state
@@ -354,14 +354,14 @@ class InteractiveEnvironment(object):
     return True
 
   def set_cached_source_signature(self, pipeline, signature):
-    self._cached_source_signature[pipeline] = signature
+    self._cached_source_signature[str(id(pipeline))] = signature
 
   def get_cached_source_signature(self, pipeline):
-    return self._cached_source_signature.get(pipeline, set())
+    return self._cached_source_signature.get(str(id(pipeline)), set())
 
   def evict_cached_source_signature(self, pipeline=None):
     if pipeline:
-      self._cached_source_signature.pop(pipeline, None)
+      self._cached_source_signature.pop(str(id(pipeline)), None)
     else:
       self._cached_source_signature.clear()
 
@@ -395,6 +395,13 @@ class InteractiveEnvironment(object):
   def tracked_user_pipelines(self):
     return self._tracked_user_pipelines
 
+  def pipeline_id_to_pipeline(self, pid):
+    """Converts a pipeline id to a user pipeline.
+    """
+
+    pid_to_pipelines = {str(id(p)): p for p in self._tracked_user_pipelines}
+    return pid_to_pipelines[pid]
+
   def mark_pcollection_computed(self, pcolls):
     """Marks computation completeness for the given pcolls.
 
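Keying the environment's bookkeeping dicts by str(id(pipeline)) rather than
by the Pipeline object itself is what lets a bare pipeline id, parsed from a
cache label, reach back to the live pipeline. A sketch of the mapping; the
commented lookup assumes the pipeline has already been tracked:

    import apache_beam as beam

    p = beam.Pipeline()
    pid = str(id(p))  # the key format now used throughout the environment
    # Once p is tracked (ib.watch(locals()) plus track_user_pipelines()),
    # the new helper inverts the mapping:
    #   ie.current_env().pipeline_id_to_pipeline(pid) is p
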
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index 5b7090e..4c16ae9 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -185,11 +185,11 @@ class InteractiveRunnerTest(unittest.TestCase):
     ib.watch(locals())
     result = p.run()
     self.assertTrue(init in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 2, 3, 4], list(result.get(init)))
+    self.assertEqual({0, 1, 2, 3, 4}, set(result.get(init)))
     self.assertTrue(square in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))
+    self.assertEqual({0, 1, 4, 9, 16}, set(result.get(square)))
     self.assertTrue(cube in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 8, 27, 64], list(result.get(cube)))
+    self.assertEqual({0, 1, 8, 27, 64}, set(result.get(cube)))
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
index b88517c..9b97c50 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
@@ -39,6 +39,48 @@ READ_CACHE = "_ReadCache_"
 WRITE_CACHE = "_WriteCache_"
 
 
+# TODO: turn this into a dataclass object when we finally get off of Python2.
+class Cacheable:
+  def __init__(self, pcoll_id, var, version, pcoll, producer_version):
+    self.pcoll_id = pcoll_id
+    self.var = var
+    self.version = version
+    self.pcoll = pcoll
+    self.producer_version = producer_version
+
+  def __eq__(self, other):
+    return (
+        self.pcoll_id == other.pcoll_id and self.var == other.var and
+        self.version == other.version and self.pcoll == other.pcoll and
+        self.producer_version == other.producer_version)
+
+  def __hash__(self):
+    return hash((
+        self.pcoll_id,
+        self.var,
+        self.version,
+        self.pcoll,
+        self.producer_version))
+
+
+# TODO: turn this into a dataclass object when we finally get off of Python2.
+class CacheKey:
+  def __init__(self, var, version, producer_version, pipeline_id):
+    self.var = var
+    self.version = version
+    self.producer_version = producer_version
+    self.pipeline_id = pipeline_id
+
+  @staticmethod
+  def from_str(r):
+    split = r.split('|')
+    return CacheKey(split[0], split[1], split[2], split[3])
+
+  def __repr__(self):
+    return '|'.join(
+        [self.var, self.version, self.producer_version, self.pipeline_id])
+
+
 class PipelineInstrument(object):
   """A pipeline instrument for pipeline to be executed by interactive runner.
 
@@ -445,7 +487,7 @@ class PipelineInstrument(object):
     # Write cache for all cacheables.
     for _, cacheable in self.cacheables.items():
       self._write_cache(
-          self._pipeline, cacheable['pcoll'], ignore_unbounded_reads=True)
+          self._pipeline, cacheable.pcoll, ignore_unbounded_reads=True)
 
     # Instrument the background caching pipeline if we can.
     if self.has_unbounded_sources:
@@ -509,7 +551,7 @@ class PipelineInstrument(object):
         pcoll_id = self._pin.pcolls_to_pcoll_id.get(str(pcoll), '')
         if pcoll_id in self._pin._pcoll_version_map:
           cacheable_key = self._pin._cacheable_key(pcoll)
-          user_pcoll = self._pin.cacheables[cacheable_key]['pcoll']
+          user_pcoll = self._pin.cacheables[cacheable_key].pcoll
           if (cacheable_key in self._pin.cacheables and user_pcoll != pcoll):
             if not self._pin._user_pipeline:
               # Retrieve a reference to the user defined pipeline instance.
@@ -523,7 +565,7 @@ class PipelineInstrument(object):
                   self._pin._user_pipeline):
                 self._pin._cache_manager = ie.current_env().cache_manager()
             self._pin._runner_pcoll_to_user_pcoll[pcoll] = user_pcoll
-            self._pin.cacheables[cacheable_key]['pcoll'] = pcoll
+            self._pin.cacheables[cacheable_key].pcoll = pcoll
 
     v = PreprocessVisitor(self)
     self._pipeline.visit(v)
@@ -757,9 +799,17 @@ class PipelineInstrument(object):
     """
     cacheable = self.cacheables.get(self._cacheable_key(pcoll), None)
     if cacheable:
-      return '_'.join((
-          cacheable['var'], cacheable['version'],
-          cacheable['producer_version']))
+      if cacheable.pcoll in self.runner_pcoll_to_user_pcoll:
+        user_pcoll = self.runner_pcoll_to_user_pcoll[cacheable.pcoll]
+      else:
+        user_pcoll = cacheable.pcoll
+
+      return repr(
+          CacheKey(
+              cacheable.var,
+              cacheable.version,
+              cacheable.producer_version,
+              str(id(user_pcoll.pipeline))))
     return ''
 
   def cacheable_var_by_pcoll_id(self, pcoll_id):
@@ -829,18 +879,22 @@ def cacheables(pcolls_to_pcoll_id):
       # TODO(BEAM-8288): cleanup the attribute check when py2 is not supported.
       if hasattr(val, '__class__') and isinstance(val, beam.pvalue.PCollection):
         cacheable = {}
-        cacheable['pcoll_id'] = pcolls_to_pcoll_id.get(str(val), None)
+
+        pcoll_id = pcolls_to_pcoll_id.get(str(val), None)
         # It's highly possible that PCollection str is not unique across
         # multiple pipelines; a further check during instrumentation is needed.
-        if not cacheable['pcoll_id']:
+        if not pcoll_id:
           continue
-        cacheable['var'] = key
-        cacheable['version'] = str(id(val))
-        cacheable['pcoll'] = val
-        cacheable['producer_version'] = str(id(val.producer))
-        pcoll_version_map[cacheable['pcoll_id']] = cacheable['version']
+
+        cacheable = Cacheable(
+            pcoll_id=pcoll_id,
+            var=key,
+            version=str(id(val)),
+            pcoll=val,
+            producer_version=str(id(val.producer)))
+        pcoll_version_map[cacheable.pcoll_id] = cacheable.version
         cacheables[cacheable_key(val, pcolls_to_pcoll_id)] = cacheable
-        cacheable_var_by_pcoll_id[cacheable['pcoll_id']] = key
+        cacheable_var_by_pcoll_id[cacheable.pcoll_id] = key
 
   return pcoll_version_map, cacheables, cacheable_var_by_pcoll_id
 
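Cacheable replaces the ad-hoc dict, giving entries attribute access plus
value-based __eq__ and __hash__ (which the updated assertions in the tests
below rely on). A minimal sketch; the pcoll_id value is hypothetical:

    import apache_beam as beam
    from apache_beam.runners.interactive.pipeline_instrument import Cacheable

    p = beam.Pipeline()
    init_pcoll = p | beam.Create([0, 1, 2])

    c = Cacheable(
        pcoll_id='ref_PCollection_PCollection_10',
        var='init_pcoll',
        version=str(id(init_pcoll)),
        pcoll=init_pcoll,
        producer_version=str(id(init_pcoll.producer)))
    d = Cacheable(c.pcoll_id, c.var, c.version, c.pcoll, c.producer_version)
    assert c == d and hash(c) == hash(d)  # value semantics; usable as dict keys
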
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
index f10e98d..8ef6f8c 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
@@ -45,6 +45,14 @@ class PipelineInstrumentTest(unittest.TestCase):
   def setUp(self):
     ie.new_env(cache_manager=InMemoryCache())
 
+  def cache_key_of(self, name, pcoll):
+    return repr(
+        instr.CacheKey(
+            name,
+            str(id(pcoll)),
+            str(id(pcoll.producer)),
+            str(id(pcoll.pipeline))))
+
   def test_pcolls_to_pcoll_id(self):
     p = beam.Pipeline(interactive_runner.InteractiveRunner())
     # pylint: disable=range-builtin-not-iterating
@@ -101,14 +109,12 @@ class PipelineInstrumentTest(unittest.TestCase):
     pipeline_instrument = instr.build_pipeline_instrument(p)
     self.assertEqual(
         pipeline_instrument.cache_key(init_pcoll),
-        'init_pcoll_' + str(id(init_pcoll)) + '_' +
-        str(id(init_pcoll.producer)))
+        self.cache_key_of('init_pcoll', init_pcoll))
     self.assertEqual(
         pipeline_instrument.cache_key(squares),
-        'squares_' + str(id(squares)) + '_' + str(id(squares.producer)))
+        self.cache_key_of('squares', squares))
     self.assertEqual(
-        pipeline_instrument.cache_key(cubes),
-        'cubes_' + str(id(cubes)) + '_' + str(id(cubes.producer)))
+        pipeline_instrument.cache_key(cubes), self.cache_key_of('cubes', cubes))
 
   def test_cacheables(self):
     p = beam.Pipeline(interactive_runner.InteractiveRunner())
@@ -122,27 +128,24 @@ class PipelineInstrumentTest(unittest.TestCase):
     self.assertEqual(
         pipeline_instrument.cacheables,
         {
-            pipeline_instrument._cacheable_key(init_pcoll): {
-                'var': 'init_pcoll',
-                'version': str(id(init_pcoll)),
-                'pcoll_id': 'ref_PCollection_PCollection_10',
-                'producer_version': str(id(init_pcoll.producer)),
-                'pcoll': init_pcoll
-            },
-            pipeline_instrument._cacheable_key(squares): {
-                'var': 'squares',
-                'version': str(id(squares)),
-                'pcoll_id': 'ref_PCollection_PCollection_11',
-                'producer_version': str(id(squares.producer)),
-                'pcoll': squares
-            },
-            pipeline_instrument._cacheable_key(cubes): {
-                'var': 'cubes',
-                'version': str(id(cubes)),
-                'pcoll_id': 'ref_PCollection_PCollection_12',
-                'producer_version': str(id(cubes.producer)),
-                'pcoll': cubes
-            }
+            pipeline_instrument._cacheable_key(init_pcoll): instr.Cacheable(
+                var='init_pcoll',
+                version=str(id(init_pcoll)),
+                pcoll_id='ref_PCollection_PCollection_10',
+                producer_version=str(id(init_pcoll.producer)),
+                pcoll=init_pcoll),
+            pipeline_instrument._cacheable_key(squares): instr.Cacheable(
+                var='squares',
+                version=str(id(squares)),
+                pcoll_id='ref_PCollection_PCollection_11',
+                producer_version=str(id(squares.producer)),
+                pcoll=squares),
+            pipeline_instrument._cacheable_key(cubes): instr.Cacheable(
+                var='cubes',
+                version=str(id(cubes)),
+                pcoll_id='ref_PCollection_PCollection_12',
+                producer_version=str(id(cubes.producer)),
+                pcoll=cubes)
         })
 
   def test_has_unbounded_source(self):
@@ -251,11 +254,9 @@ class PipelineInstrumentTest(unittest.TestCase):
     p_copy, _, _ = self._example_pipeline(False)
 
     # Mock as if cacheable PCollections are cached.
-    init_pcoll_cache_key = 'init_pcoll_' + str(id(init_pcoll)) + '_' + str(
-        id(init_pcoll.producer))
+    init_pcoll_cache_key = self.cache_key_of('init_pcoll', init_pcoll)
     self._mock_write_cache([b'1', b'2', b'3'], init_pcoll_cache_key)
-    second_pcoll_cache_key = 'second_pcoll_' + str(
-        id(second_pcoll)) + '_' + str(id(second_pcoll.producer))
+    second_pcoll_cache_key = self.cache_key_of('second_pcoll', second_pcoll)
     self._mock_write_cache([b'1', b'4', b'9'], second_pcoll_cache_key)
     # Mark the completeness of PCollections from the original(user) pipeline.
     ie.current_env().mark_pcollection_computed(
@@ -320,13 +321,10 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Mock as if cacheable PCollections are cached.
     ib.watch(locals())
 
-    def cache_key_of(name, pcoll):
-      return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))
-
     for name, pcoll in locals().items():
       if not isinstance(pcoll, beam.pvalue.PCollection):
         continue
-      cache_key = cache_key_of(name, pcoll)
+      cache_key = self.cache_key_of(name, pcoll)
       self._mock_write_cache([b''], cache_key)
 
     # Instrument the original pipeline to create the pipeline the user will see.
@@ -338,11 +336,11 @@ class PipelineInstrumentTest(unittest.TestCase):
 
     # Now, build the expected pipeline which replaces the unbounded source with
     # a TestStream.
-    source_1_cache_key = cache_key_of('source_1', source_1)
+    source_1_cache_key = self.cache_key_of('source_1', source_1)
     p_expected = beam.Pipeline()
     test_stream = (
         p_expected
-        | TestStream(output_tags=[cache_key_of('source_1', source_1)]))
+        | TestStream(output_tags=[self.cache_key_of('source_1', source_1)]))
     # pylint: disable=expression-not-assigned
     test_stream | 'square1' >> beam.Map(lambda x: x * x)
 
@@ -396,9 +394,6 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Watch but do not cache the PCollections.
     ib.watch(locals())
 
-    def cache_key_of(name, pcoll):
-      return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))
-
     # Make sure that sources without a user reference are still cached.
     instr.watch_sources(p_original)
 
@@ -424,7 +419,7 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Now, build the expected pipeline which replaces the unbounded source with
     # a TestStream.
     intermediate_source_pcoll_cache_key = \
-        cache_key_of('synthetic_var_' + str(id(intermediate_source_pcoll)),
+        self.cache_key_of('synthetic_var_' + str(id(intermediate_source_pcoll)),
                      intermediate_source_pcoll)
     p_expected = beam.Pipeline()
 
@@ -488,10 +483,7 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Watch but do not cache the PCollections.
     ib.watch(locals())
 
-    def cache_key_of(name, pcoll):
-      return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))
-
-    self._mock_write_cache([b''], cache_key_of('source_2', source_2))
+    self._mock_write_cache([b''], self.cache_key_of('source_2', source_2))
     ie.current_env().mark_pcollection_computed([source_2])
 
     # Instrument the original pipeline to create the pipeline the user will see.
@@ -507,8 +499,8 @@ class PipelineInstrumentTest(unittest.TestCase):
 
     # Now, build the expected pipeline which replaces the unbounded source with
     # a TestStream.
-    source_1_cache_key = cache_key_of('source_1', source_1)
-    source_2_cache_key = cache_key_of('source_2', source_2)
+    source_1_cache_key = self.cache_key_of('source_1', source_1)
+    source_2_cache_key = self.cache_key_of('source_2', source_2)
     p_expected = beam.Pipeline()
 
     test_stream = (
@@ -516,8 +508,8 @@ class PipelineInstrumentTest(unittest.TestCase):
         | TestStream(output_tags=[source_1_cache_key, source_2_cache_key]))
     # pylint: disable=expression-not-assigned
     ((
-        test_stream[cache_key_of('source_1', source_1)],
-        test_stream[cache_key_of('source_2', source_2)])
+        test_stream[self.cache_key_of('source_1', source_1)],
+        test_stream[self.cache_key_of('source_2', source_2)])
      | beam.Flatten()
      | 'square1' >> beam.Map(lambda x: x * x)
      | 'reify' >> beam.Map(lambda _: _)
@@ -565,9 +557,6 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Watch but do not cache the PCollections.
     ib.watch(locals())
 
-    def cache_key_of(name, pcoll):
-      return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))
-
     # Instrument the original pipeline to create the pipeline the user will see.
     p_copy = beam.Pipeline.from_runner_api(
         p_original.to_runner_api(),
@@ -581,13 +570,13 @@ class PipelineInstrumentTest(unittest.TestCase):
 
     # Now, build the expected pipeline which replaces the unbounded source with
     # a TestStream.
-    source_1_cache_key = cache_key_of('source_1', source_1)
+    source_1_cache_key = self.cache_key_of('source_1', source_1)
     p_expected = beam.Pipeline()
 
     # pylint: disable=unused-variable
     test_stream = (
         p_expected
-        | TestStream(output_tags=[cache_key_of('source_1', source_1)]))
+        | TestStream(output_tags=[self.cache_key_of('source_1', source_1)]))
 
     # Test that the TestStream is outputting to the correct PCollection.
     class TestStreamVisitor(PipelineVisitor):
@@ -632,9 +621,6 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Watch but do not cache the PCollections.
     ib.watch(locals())
 
-    def cache_key_of(name, pcoll):
-      return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))
-
     # Instrument the original pipeline to create the pipeline the user will see.
     p_copy = beam.Pipeline.from_runner_api(
         p_original.to_runner_api(),
@@ -648,11 +634,11 @@ class PipelineInstrumentTest(unittest.TestCase):
 
     # Now, build the expected pipeline which replaces the unbounded source with
     # a TestStream.
-    source_1_cache_key = cache_key_of('source_1', source_1)
+    source_1_cache_key = self.cache_key_of('source_1', source_1)
     p_expected = beam.Pipeline()
     test_stream = (
         p_expected
-        | TestStream(output_tags=[cache_key_of('source_1', source_1)]))
+        | TestStream(output_tags=[self.cache_key_of('source_1', source_1)]))
     # pylint: disable=expression-not-assigned
     (
         test_stream
@@ -705,13 +691,10 @@ class PipelineInstrumentTest(unittest.TestCase):
     # Mock as if cacheable PCollections are cached.
     ib.watch(locals())
 
-    def cache_key_of(name, pcoll):
-      return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))
-
     for name, pcoll in locals().items():
       if not isinstance(pcoll, beam.pvalue.PCollection):
         continue
-      cache_key = cache_key_of(name, pcoll)
+      cache_key = self.cache_key_of(name, pcoll)
       self._mock_write_cache([b''], cache_key)
 
     # Instrument the original pipeline to create the pipeline the user will see.
@@ -723,15 +706,15 @@ class PipelineInstrumentTest(unittest.TestCase):
 
     # Now, build the expected pipeline which replaces the unbounded source with
     # a TestStream.
-    source_1_cache_key = cache_key_of('source_1', source_1)
-    source_2_cache_key = cache_key_of('source_2', source_2)
+    source_1_cache_key = self.cache_key_of('source_1', source_1)
+    source_2_cache_key = self.cache_key_of('source_2', source_2)
     p_expected = beam.Pipeline()
     test_stream = (
         p_expected
         | TestStream(
             output_tags=[
-                cache_key_of('source_1', source_1),
-                cache_key_of('source_2', source_2)
+                self.cache_key_of('source_1', source_1),
+                self.cache_key_of('source_2', source_2)
             ]))
     # pylint: disable=expression-not-assigned
     test_stream[source_1_cache_key] | 'square1' >> beam.Map(lambda x: x * x)
@@ -771,8 +754,7 @@ class PipelineInstrumentTest(unittest.TestCase):
         None)
 
     # Mock as if init_pcoll is cached.
-    init_pcoll_cache_key = 'init_pcoll_' + str(id(init_pcoll)) + '_' + str(
-        id(init_pcoll.producer))
+    init_pcoll_cache_key = self.cache_key_of('init_pcoll', init_pcoll)
     self._mock_write_cache([b'1', b'2', b'3'], init_pcoll_cache_key)
     ie.current_env().mark_pcollection_computed([init_pcoll])
     # Build an instrument from the runner pipeline.
