This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b7e5f45  [BEAM-8335] TestStreamService integration with DirectRunner
     new 23c4c20  Merge pull request #10994 from [BEAM-8335] TestStreamService 
integration with DirectRunner
b7e5f45 is described below

commit b7e5f454a74e89c033465ba12c3b0f26bca786fd
Author: Sam Rohde <[email protected]>
AuthorDate: Wed Feb 19 14:47:17 2020 -0800

    [BEAM-8335] TestStreamService integration with DirectRunner
    
    Change-Id: I83c7c9221e118d49a9005707650803c28630fe2a
---
 .../apache_beam/runners/direct/direct_runner.py    |  14 +--
 .../apache_beam/runners/direct/test_stream_impl.py |  74 ++++++++++--
 .../runners/direct/transform_evaluator.py          |  62 ++++++----
 .../interactive/background_caching_job_test.py     |   4 +-
 sdks/python/apache_beam/testing/test_stream.py     |  50 ++++++--
 .../apache_beam/testing/test_stream_service.py     |  18 ++-
 .../testing/test_stream_service_test.py            |  44 +++++--
 .../python/apache_beam/testing/test_stream_test.py | 126 ++++++++++++++++-----
 8 files changed, 296 insertions(+), 96 deletions(-)

diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py 
b/sdks/python/apache_beam/runners/direct/direct_runner.py
index f207779..5181ad7 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -376,25 +376,25 @@ class BundleBasedDirectRunner(PipelineRunner):
     from apache_beam.runners.direct.executor import Executor
     from apache_beam.runners.direct.transform_evaluator import \
       TransformEvaluatorRegistry
-    from apache_beam.runners.direct.test_stream_impl import _TestStream
-
-    # Performing configured PTransform overrides.
-    pipeline.replace_all(_get_transform_overrides(options))
+    from apache_beam.testing.test_stream import TestStream
 
     # If the TestStream I/O is used, use a mock test clock.
-    class _TestStreamUsageVisitor(PipelineVisitor):
+    class TestStreamUsageVisitor(PipelineVisitor):
       """Visitor determining whether a Pipeline uses a TestStream."""
       def __init__(self):
         self.uses_test_stream = False
 
       def visit_transform(self, applied_ptransform):
-        if isinstance(applied_ptransform.transform, _TestStream):
+        if isinstance(applied_ptransform.transform, TestStream):
           self.uses_test_stream = True
 
-    visitor = _TestStreamUsageVisitor()
+    visitor = TestStreamUsageVisitor()
     pipeline.visit(visitor)
     clock = TestClock() if visitor.uses_test_stream else RealClock()
 
+    # Performing configured PTransform overrides.
+    pipeline.replace_all(_get_transform_overrides(options))
+
     _LOGGER.info('Running pipeline with DirectRunner.')
     self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
     pipeline.visit(self.consumer_tracking_visitor)
diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py 
b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
index 5f325d5..e5ddc0a 100644
--- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py
+++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
@@ -27,14 +27,25 @@ tagged PCollection.
 
 from __future__ import absolute_import
 
+import itertools
+
+import grpc
+
 from apache_beam import ParDo
 from apache_beam import coders
 from apache_beam import pvalue
+from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api import beam_runner_api_pb2_grpc
+from apache_beam.testing.test_stream import ElementEvent
+from apache_beam.testing.test_stream import ProcessingTimeEvent
 from apache_beam.testing.test_stream import WatermarkEvent
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import core
 from apache_beam.transforms import window
+from apache_beam.transforms.window import TimestampedValue
 from apache_beam.utils import timestamp
+from apache_beam.utils.timestamp import Duration
+from apache_beam.utils.timestamp import Timestamp
 
 
 class _WatermarkController(PTransform):
@@ -79,7 +90,8 @@ class _ExpandableTestStream(PTransform):
           | _TestStream(
               self.test_stream.output_tags,
               events=self.test_stream._events,
-              coder=self.test_stream.coder)
+              coder=self.test_stream.coder,
+              endpoint=self.test_stream._endpoint)
           | _WatermarkController(list(self.test_stream.output_tags)[0]))
 
     # Multiplex to the correct PCollection based upon the event tag.
@@ -131,12 +143,17 @@ class _TestStream(PTransform):
   WATERMARK_CONTROL_TAG = '_TestStream_Watermark'
 
   def __init__(
-      self, output_tags, coder=coders.FastPrimitivesCoder(), events=None):
+      self,
+      output_tags,
+      coder=coders.FastPrimitivesCoder(),
+      events=None,
+      endpoint=None):
     assert coder is not None
     self.coder = coder
     self._raw_events = events
     self._events = self._add_watermark_advancements(output_tags, events)
     self.output_tags = output_tags
+    self.endpoint = endpoint
 
   def _watermark_starts(self, output_tags):
     """Sentinel values to hold the watermark of outputs to -inf.
@@ -226,17 +243,50 @@ class _TestStream(PTransform):
   def _infer_output_coder(self, input_type=None, input_coder=None):
     return self.coder
 
-  def _events_from_script(self, index):
-    yield self._events[index]
+  @staticmethod
+  def events_from_script(events):
+    """Yields the in-memory events.
+    """
+    return itertools.chain(events)
 
-  def events(self, index):
-    return self._events_from_script(index)
+  @staticmethod
+  def events_from_rpc(endpoint, output_tags, coder):
+    """Yields the events received from the given endpoint.
+    """
+    stub_channel = grpc.insecure_channel(endpoint)
+    stub = beam_runner_api_pb2_grpc.TestStreamServiceStub(stub_channel)
 
-  def begin(self):
-    return 0
+    # Request the PCollections that we are looking for from the service.
+    event_request = beam_runner_api_pb2.EventsRequest(
+        output_ids=[str(tag) for tag in output_tags])
 
-  def end(self, index):
-    return index >= len(self._events)
+    event_stream = stub.Events(event_request)
+    for e in event_stream:
+      yield _TestStream.test_stream_payload_to_events(e, coder)
 
-  def next(self, index):
-    return index + 1
+  @staticmethod
+  def test_stream_payload_to_events(payload, coder):
+    """Returns a TestStream Python event object from a TestStream event Proto.
+    """
+    if payload.HasField('element_event'):
+      element_event = payload.element_event
+      elements = [
+          TimestampedValue(
+              coder.decode(e.encoded_element), Timestamp(micros=e.timestamp))
+          for e in element_event.elements
+      ]
+      return ElementEvent(timestamped_values=elements, tag=element_event.tag)
+
+    if payload.HasField('watermark_event'):
+      watermark_event = payload.watermark_event
+      return WatermarkEvent(
+          Timestamp(micros=watermark_event.new_watermark),
+          tag=watermark_event.tag)
+
+    if payload.HasField('processing_time_event'):
+      processing_time_event = payload.processing_time_event
+      return ProcessingTimeEvent(
+          Duration(micros=processing_time_event.advance_duration))
+
+    raise RuntimeError(
+        'Received a proto without the specified fields: {}'.format(payload))
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py 
b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index acec67a..474ad0d 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -61,6 +61,7 @@ from apache_beam.testing.test_stream import PairWithTiming
 from apache_beam.testing.test_stream import ProcessingTimeEvent
 from apache_beam.testing.test_stream import TimingInfo
 from apache_beam.testing.test_stream import WatermarkEvent
+from apache_beam.testing.test_stream import WindowedValueHolder
 from apache_beam.transforms import core
 from apache_beam.transforms.trigger import InMemoryUnmergedState
 from apache_beam.transforms.trigger import TimeDomain
@@ -214,11 +215,19 @@ class _TestStreamRootBundleProvider(RootBundleProvider):
   """
   def get_root_bundles(self):
     test_stream = self._applied_ptransform.transform
+
+    # If there was an endpoint defined then get the events from the
+    # TestStreamService.
+    if test_stream.endpoint:
+      _TestStreamEvaluator.event_stream = _TestStream.events_from_rpc(
+          test_stream.endpoint, test_stream.output_tags, test_stream.coder)
+    else:
+      _TestStreamEvaluator.event_stream = (
+          _TestStream.events_from_script(test_stream._events))
+
     bundle = self._evaluation_context.create_bundle(
         pvalue.PBegin(self._applied_ptransform.transform.pipeline))
-    bundle.add(
-        GlobalWindows.windowed_value(
-            test_stream.begin(), timestamp=MIN_TIMESTAMP))
+    bundle.add(GlobalWindows.windowed_value(b'', timestamp=MIN_TIMESTAMP))
     bundle.commit(None)
     return [bundle]
 
@@ -424,9 +433,9 @@ class _WatermarkControllerEvaluator(_TransformEvaluator):
       bundle = self._evaluation_context.create_bundle(main_output)
       for tv in event.timestamped_values:
         # Unreify the value into the correct window.
-        try:
-          bundle.output(WindowedValue(**tv.value))
-        except TypeError:
+        if isinstance(tv.value, WindowedValueHolder):
+          bundle.output(tv.value.windowed_value)
+        else:
           bundle.output(
               GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp))
       self.bundles.append(bundle)
@@ -470,8 +479,11 @@ class _PairWithTimingEvaluator(_TransformEvaluator):
     self.timing_info = TimingInfo(now, output_watermark)
 
   def process_element(self, element):
-    element.value = (element.value, self.timing_info)
-    self.bundle.output(element)
+    result = WindowedValue((element.value, self.timing_info),
+                           element.timestamp,
+                           element.windows,
+                           element.pane_info)
+    self.bundle.output(result)
 
   def finish_bundle(self):
     return TransformResult(self, [self.bundle], [], None, {})
@@ -487,6 +499,9 @@ class _TestStreamEvaluator(_TransformEvaluator):
   The _WatermarkController is in charge of emitting the elements to the
   downstream consumers and setting its own output watermark.
   """
+
+  event_stream = None
+
   def __init__(
       self,
       evaluation_context,
@@ -500,23 +515,16 @@ class _TestStreamEvaluator(_TransformEvaluator):
         input_committed_bundle,
         side_inputs)
     self.test_stream = applied_ptransform.transform
+    self.is_done = False
 
   def start_bundle(self):
-    self.current_index = 0
     self.bundles = []
     self.watermark = MIN_TIMESTAMP
 
   def process_element(self, element):
-    # The index into the TestStream list of events.
-    self.current_index = element.value
-
     # The watermark of the _TestStream transform itself.
     self.watermark = element.timestamp
 
-    # We can either have the _TestStream or the _WatermarkController to emit
-    # the elements. We chose to emit in the _WatermarkController so that the
-    # element is emitted at the correct watermark value.
-
     # Set up the correct watermark holds in the Watermark controllers and the
     # TestStream so that the watermarks will not automatically advance to +inf
     # when elements start streaming. This can happen multiple times in the 
first
@@ -527,9 +535,20 @@ class _TestStreamEvaluator(_TransformEvaluator):
       for event in self.test_stream._set_up(self.test_stream.output_tags):
         events.append(event)
 
-    events += [e for e in self.test_stream.events(self.current_index)]
+    # Retrieve the TestStream's event stream and read from it.
+    try:
+      events.append(next(self.event_stream))
+    except StopIteration:
+      # Advance the watermarks to +inf to cleanly stop the pipeline.
+      self.is_done = True
+      events += ([
+          e for e in self.test_stream._tear_down(self.test_stream.output_tags)
+      ])
 
     for event in events:
+      # We can either have the _TestStream or the _WatermarkController to emit
+      # the elements. We chose to emit in the _WatermarkController so that the
+      # element is emitted at the correct watermark value.
       if isinstance(event, (ElementEvent, WatermarkEvent)):
         # The WATERMARK_CONTROL_TAG is used to hold the _TestStream's
         # watermark to -inf, then +inf-1, then +inf. This watermark progression
@@ -550,12 +569,15 @@ class _TestStreamEvaluator(_TransformEvaluator):
 
   def finish_bundle(self):
     unprocessed_bundles = []
-    next_index = self.test_stream.next(self.current_index)
-    if not self.test_stream.end(next_index):
+
+    # Continue to send its own state to itself via an unprocessed bundle. This
+    # acts as a heartbeat, where each element will read the next event from the
+    # event stream.
+    if not self.is_done:
       unprocessed_bundle = self._evaluation_context.create_bundle(
           pvalue.PBegin(self._applied_ptransform.transform.pipeline))
       unprocessed_bundle.add(
-          GlobalWindows.windowed_value(next_index, timestamp=self.watermark))
+          GlobalWindows.windowed_value(b'', timestamp=self.watermark))
       unprocessed_bundles.append(unprocessed_bundle)
 
     # Returning the watermark in the dict here is used as a watermark hold.
diff --git 
a/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py 
b/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py
index 0f76874..e5b6b52 100644
--- a/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py
+++ b/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py
@@ -267,14 +267,14 @@ class BackgroundCachingJobTest(unittest.TestCase):
 
   def test_determine_a_test_stream_service_running(self):
     pipeline = _build_an_empty_stream_pipeline()
-    test_stream_service = TestStreamServiceController(events=iter([]))
+    test_stream_service = TestStreamServiceController(reader=None)
     ie.current_env().set_test_stream_service_controller(
         pipeline, test_stream_service)
     self.assertTrue(bcj.is_a_test_stream_service_running(pipeline))
 
   def test_stop_a_running_test_stream_service(self):
     pipeline = _build_an_empty_stream_pipeline()
-    test_stream_service = TestStreamServiceController(events=iter([]))
+    test_stream_service = TestStreamServiceController(reader=None)
     test_stream_service.start()
     ie.current_env().set_test_stream_service_controller(
         pipeline, test_stream_service)
diff --git a/sdks/python/apache_beam/testing/test_stream.py 
b/sdks/python/apache_beam/testing/test_stream.py
index 6d95ede..7719b7e 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -39,6 +39,7 @@ from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.portability.api.beam_interactive_api_pb2 import 
TestStreamFileHeader
 from apache_beam.portability.api.beam_interactive_api_pb2 import 
TestStreamFileRecord
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.portability.api.endpoints_pb2 import ApiServiceDescriptor
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import core
 from apache_beam.transforms import window
@@ -92,16 +93,14 @@ class Event(with_metaclass(ABCMeta, object)):  # type: 
ignore[misc]
       return ElementEvent([
           TimestampedValue(
               element_coder.decode(tv.encoded_element),
-              timestamp.Timestamp(micros=1000 * tv.timestamp))
+              Timestamp(micros=1000 * tv.timestamp))
           for tv in proto.element_event.elements
       ], tag=tag) # yapf: disable
     elif proto.HasField('watermark_event'):
       event = proto.watermark_event
       tag = None if event.tag == 'None' else event.tag
       return WatermarkEvent(
-          timestamp.Timestamp(
-              micros=1000 * proto.watermark_event.new_watermark),
-          tag=tag)
+          Timestamp(micros=1000 * proto.watermark_event.new_watermark), 
tag=tag)
     elif proto.HasField('processing_time_event'):
       return ProcessingTimeEvent(
           timestamp.Duration(
@@ -155,7 +154,7 @@ class ElementEvent(Event):
 class WatermarkEvent(Event):
   """Watermark-advancing test stream event."""
   def __init__(self, new_watermark, tag=None):
-    self.new_watermark = timestamp.Timestamp.of(new_watermark)
+    self.new_watermark = Timestamp.of(new_watermark)
     self.tag = tag
 
   def __eq__(self, other):
@@ -217,6 +216,17 @@ class ProcessingTimeEvent(Event):
     return 'ProcessingTimeEvent: <{}>'.format(self.advance_by)
 
 
+class WindowedValueHolder:
+  """A class that holds a WindowedValue.
+
+  This is a special class that can be used by the runner that implements the
+  TestStream as a signal that the underlying value should be unreified to the
+  specified window.
+  """
+  def __init__(self, windowed_value):
+    self.windowed_value = windowed_value
+
+
 class TestStream(PTransform):
   """Test stream that generates events on an unbounded PCollection of elements.
 
@@ -236,20 +246,36 @@ class TestStream(PTransform):
 
   """
   def __init__(
-      self, coder=coders.FastPrimitivesCoder(), events=None, output_tags=None):
+      self,
+      coder=coders.FastPrimitivesCoder(),
+      events=None,
+      output_tags=None,
+      endpoint=None):
+    """TestStream constructor.
+
+    Args:
+      coder: (apache_beam.Coder) the element coder for any ElementEvents.
+      events: (List[Event]) a list of instructions for the TestStream to
+        execute. This must be a subset of the given output_tags.
+      output_tags: (List[str]) a list of PCollection output tags.
+      endpoint: (str) a URL locating a TestStreamService.
+    """
     super(TestStream, self).__init__()
     assert coder is not None
 
     self.coder = coder
     self.watermarks = {None: timestamp.MIN_TIMESTAMP}
-    self._events = [] if events is None else list(events)
     self.output_tags = set(output_tags) if output_tags else set()
+    self._events = [] if events is None else list(events)
+    self._endpoint = endpoint
 
     event_tags = set(
         e.tag for e in self._events
         if isinstance(e, (WatermarkEvent, ElementEvent)))
     assert event_tags.issubset(self.output_tags), \
         '{} is not a subset of {}'.format(event_tags, output_tags)
+    assert not (self._events and self._endpoint), \
+        'Only either events or an endpoint can be given at once.'
 
   def get_windowing(self, unused_inputs):
     return core.Windowing(window.GlobalWindows())
@@ -354,7 +380,8 @@ class TestStream(PTransform):
         common_urns.primitives.TEST_STREAM.urn,
         beam_runner_api_pb2.TestStreamPayload(
             coder_id=context.coders.get_id(self.coder),
-            events=[e.to_runner_api(self.coder) for e in self._events]))
+            events=[e.to_runner_api(self.coder) for e in self._events],
+            endpoint=ApiServiceDescriptor(url=self._endpoint)))
 
   @staticmethod
   @PTransform.register_urn(
@@ -367,13 +394,14 @@ class TestStream(PTransform):
     return TestStream(
         coder=coder,
         events=[Event.from_runner_api(e, coder) for e in payload.events],
-        output_tags=output_tags)
+        output_tags=output_tags,
+        endpoint=payload.endpoint.url)
 
 
 class TimingInfo(object):
   def __init__(self, processing_time, watermark):
-    self._processing_time = timestamp.Timestamp.of(processing_time)
-    self._watermark = timestamp.Timestamp.of(watermark)
+    self._processing_time = Timestamp.of(processing_time)
+    self._watermark = Timestamp.of(watermark)
 
   @property
   def processing_time(self):
diff --git a/sdks/python/apache_beam/testing/test_stream_service.py 
b/sdks/python/apache_beam/testing/test_stream_service.py
index 28ec9ca..694f61b 100644
--- a/sdks/python/apache_beam/testing/test_stream_service.py
+++ b/sdks/python/apache_beam/testing/test_stream_service.py
@@ -28,19 +28,23 @@ from apache_beam.portability.api.beam_runner_api_pb2_grpc 
import TestStreamServi
 
 
 class TestStreamServiceController(TestStreamServiceServicer):
-  def __init__(self, events, endpoint=None):
+  """A server that streams TestStreamPayload.Events from a single EventRequest.
+
+  This server is used as a way for TestStreams to receive events from file.
+  """
+  def __init__(self, reader, endpoint=None):
     self._server = grpc.server(ThreadPoolExecutor(max_workers=10))
 
     if endpoint:
       self.endpoint = endpoint
       self._server.add_insecure_port(self.endpoint)
     else:
-      port = self._server.add_insecure_port('[::]:0')
-      self.endpoint = '[::]:{}'.format(port)
+      port = self._server.add_insecure_port('localhost:0')
+      self.endpoint = 'localhost:{}'.format(port)
 
     beam_runner_api_pb2_grpc.add_TestStreamServiceServicer_to_server(
         self, self._server)
-    self._events = events
+    self._reader = reader
 
   def start(self):
     self._server.start()
@@ -52,5 +56,9 @@ class TestStreamServiceController(TestStreamServiceServicer):
   def Events(self, request, context):
     """Streams back all of the events from the streaming cache."""
 
-    for e in self._events:
+    # TODO(srohde): Once we get rid of the CacheManager, get rid of this 'full'
+    # label.
+    tags = [None if tag == 'None' else tag for tag in request.output_ids]
+    reader = self._reader.read_multiple([('full', tag) for tag in tags])
+    for e in reader:
       yield e
diff --git a/sdks/python/apache_beam/testing/test_stream_service_test.py 
b/sdks/python/apache_beam/testing/test_stream_service_test.py
index e1aebb4..01b16a1 100644
--- a/sdks/python/apache_beam/testing/test_stream_service_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_service_test.py
@@ -38,18 +38,31 @@ TestStreamFileHeader.__test__ = False  # type: 
ignore[attr-defined]
 TestStreamFileRecord.__test__ = False  # type: ignore[attr-defined]
 
 
-class TestStreamServiceTest(unittest.TestCase):
-  def events(self):
-    events = []
+class EventsReader:
+  def __init__(self, expected_key):
+    self._expected_key = expected_key
+
+  def read_multiple(self, keys):
+    if keys != self._expected_key:
+      raise ValueError(
+          'Expected key ({}) is not argument({})'.format(
+              self._expected_key, keys))
+
     for i in range(10):
       e = TestStreamPayload.Event()
       e.element_event.elements.append(
           TestStreamPayload.TimestampedElement(timestamp=i))
-      events.append(e)
-    return events
+      yield e
+
 
+EXPECTED_KEY = 'key'
+EXPECTED_KEYS = [EXPECTED_KEY]
+
+
+class TestStreamServiceTest(unittest.TestCase):
   def setUp(self):
-    self.controller = TestStreamServiceController(self.events())
+    self.controller = TestStreamServiceController(
+        EventsReader(expected_key=[('full', EXPECTED_KEY)]))
     self.controller.start()
 
     channel = grpc.insecure_channel(self.controller.endpoint)
@@ -59,15 +72,21 @@ class TestStreamServiceTest(unittest.TestCase):
     self.controller.stop()
 
   def test_normal_run(self):
-    r = self.stub.Events(beam_runner_api_pb2.EventsRequest())
+    r = self.stub.Events(
+        beam_runner_api_pb2.EventsRequest(output_ids=EXPECTED_KEYS))
     events = [e for e in r]
-    expected_events = [e for e in self.events()]
+    expected_events = [
+        e for e in EventsReader(
+            expected_key=[EXPECTED_KEYS]).read_multiple([EXPECTED_KEYS])
+    ]
 
     self.assertEqual(events, expected_events)
 
   def test_multiple_sessions(self):
-    resp_a = self.stub.Events(beam_runner_api_pb2.EventsRequest())
-    resp_b = self.stub.Events(beam_runner_api_pb2.EventsRequest())
+    resp_a = self.stub.Events(
+        beam_runner_api_pb2.EventsRequest(output_ids=EXPECTED_KEYS))
+    resp_b = self.stub.Events(
+        beam_runner_api_pb2.EventsRequest(output_ids=EXPECTED_KEYS))
 
     events_a = []
     events_b = []
@@ -88,7 +107,10 @@ class TestStreamServiceTest(unittest.TestCase):
 
       done = a_is_done and b_is_done
 
-    expected_events = [e for e in self.events()]
+    expected_events = [
+        e for e in EventsReader(
+            expected_key=[EXPECTED_KEYS]).read_multiple([EXPECTED_KEYS])
+    ]
 
     self.assertEqual(events_a, expected_events)
     self.assertEqual(events_b, expected_events)
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py 
b/sdks/python/apache_beam/testing/test_stream_test.py
index 650953f..e0ee91b 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -38,6 +38,8 @@ from apache_beam.testing.test_stream import 
ProcessingTimeEvent
 from apache_beam.testing.test_stream import ReverseTestStream
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.test_stream import WatermarkEvent
+from apache_beam.testing.test_stream import WindowedValueHolder
+from apache_beam.testing.test_stream_service import TestStreamServiceController
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.testing.util import equal_to_per_window
@@ -305,16 +307,15 @@ class TestStreamTest(unittest.TestCase):
           ]))
 
   def test_windowed_values_interpreted_correctly(self):
-    windowed_value_args = {
-        'value': 'a',
-        'timestamp': Timestamp(5),
-        'windows': [beam.window.IntervalWindow(5, 10)],
-        'pane_info': PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)
-    }
+    windowed_value = WindowedValueHolder(
+        WindowedValue(
+            'a',
+            Timestamp(5), [beam.window.IntervalWindow(5, 10)],
+            PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)))
     test_stream = (TestStream()
                    .advance_processing_time(10)
                    .advance_watermark_to(10)
-                   .add_elements([windowed_value_args])
+                   .add_elements([windowed_value])
                    .advance_watermark_to_infinity())  # yapf: disable
 
     class RecordFn(beam.DoFn):
@@ -491,20 +492,23 @@ class TestStreamTest(unittest.TestCase):
 
   def test_basic_execution_sideinputs(self):
     options = PipelineOptions()
+    options.view_as(DebugOptions).add_experiment(
+        'passthrough_pcollection_output_ids')
     options.view_as(StandardOptions).streaming = True
     with TestPipeline(options=options) as p:
 
-      main_stream = (
-          p
-          | 'main TestStream' >>
-          TestStream().advance_watermark_to(10).add_elements(['e']))
-      side_stream = (
-          p
-          | 'side TestStream' >> TestStream().add_elements([
-              window.TimestampedValue(2, 2)
-          ]).add_elements([window.TimestampedValue(1, 1)]).add_elements([
-              window.TimestampedValue(7, 7)
-          ]).add_elements([window.TimestampedValue(4, 4)]))
+      test_stream = (p | TestStream()
+          .advance_watermark_to(0, tag='side')
+          .advance_watermark_to(10, tag='main')
+          .add_elements(['e'], tag='main')
+          .add_elements([window.TimestampedValue(2, 2)], tag='side')
+          .add_elements([window.TimestampedValue(1, 1)], tag='side')
+          .add_elements([window.TimestampedValue(7, 7)], tag='side')
+          .add_elements([window.TimestampedValue(4, 4)], tag='side')
+          ) # yapf: disable
+
+      main_stream = test_stream['main']
+      side_stream = test_stream['side']
 
       class RecordFn(beam.DoFn):
         def process(
@@ -564,22 +568,30 @@ class TestStreamTest(unittest.TestCase):
 
   def test_basic_execution_sideinputs_fixed_windows(self):
     options = PipelineOptions()
+    options.view_as(DebugOptions).add_experiment(
+        'passthrough_pcollection_output_ids')
     options.view_as(StandardOptions).streaming = True
     p = TestPipeline(options=options)
 
+    test_stream = (p | TestStream()
+        .advance_watermark_to(12, tag='side')
+        .add_elements([window.TimestampedValue('s1', 10)], tag='side')
+        .advance_watermark_to(20, tag='side')
+        .add_elements([window.TimestampedValue('s2', 20)], tag='side')
+
+        .advance_watermark_to(9, tag='main')
+        .add_elements(['a1', 'a2', 'a3', 'a4'], tag='main')
+        .add_elements(['b'], tag='main')
+        .advance_watermark_to(18, tag='main')
+        .add_elements('c', tag='main')
+        ) # yapf: disable
+
     main_stream = (
-        p
-        |
-        'main TestStream' >> 
TestStream().advance_watermark_to(9).add_elements([
-            'a1', 'a2', 'a3', 'a4'
-        ]).add_elements(['b']).advance_watermark_to(18).add_elements('c')
+        test_stream['main']
         | 'main windowInto' >> beam.WindowInto(window.FixedWindows(1)))
+
     side_stream = (
-        p
-        |
-        'side TestStream' >> 
TestStream().advance_watermark_to(12).add_elements(
-            [window.TimestampedValue('s1', 10)]).advance_watermark_to(
-                20).add_elements([window.TimestampedValue('s2', 20)])
+        test_stream['side']
         | 'side windowInto' >> beam.WindowInto(window.FixedWindows(3)))
 
     class RecordFn(beam.DoFn):
@@ -664,6 +676,64 @@ class TestStreamTest(unittest.TestCase):
         test_stream.output_tags, roundtrip_test_stream.output_tags)
     self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
 
+  def test_basic_execution_with_service(self):
+    """Tests that the TestStream can correctly read from an RPC service.
+    """
+    coder = beam.coders.FastPrimitivesCoder()
+
+    test_stream_events = (TestStream(coder=coder)
+        .advance_watermark_to(10000)
+        .add_elements(['a', 'b', 'c'])
+        .advance_watermark_to(20000)
+        .add_elements(['d'])
+        .add_elements(['e'])
+        .advance_processing_time(10)
+        .advance_watermark_to(300000)
+        .add_elements([TimestampedValue('late', 12000)])
+        .add_elements([TimestampedValue('last', 310000)])
+        .advance_watermark_to_infinity())._events  # yapf: disable
+
+    test_stream_proto_events = [
+        e.to_runner_api(coder) for e in test_stream_events
+    ]
+
+    class InMemoryEventReader:
+      def read_multiple(self, unused_keys):
+        for e in test_stream_proto_events:
+          yield e
+
+    service = TestStreamServiceController(reader=InMemoryEventReader())
+    service.start()
+
+    test_stream = TestStream(coder=coder, endpoint=service.endpoint)
+
+    class RecordFn(beam.DoFn):
+      def process(
+          self,
+          element=beam.DoFn.ElementParam,
+          timestamp=beam.DoFn.TimestampParam):
+        yield (element, timestamp)
+
+    options = StandardOptions(streaming=True)
+
+    p = TestPipeline(options=options)
+    my_record_fn = RecordFn()
+    records = p | test_stream | beam.ParDo(my_record_fn)
+
+    assert_that(
+        records,
+        equal_to([
+            ('a', timestamp.Timestamp(10)),
+            ('b', timestamp.Timestamp(10)),
+            ('c', timestamp.Timestamp(10)),
+            ('d', timestamp.Timestamp(20)),
+            ('e', timestamp.Timestamp(20)),
+            ('late', timestamp.Timestamp(12)),
+            ('last', timestamp.Timestamp(310)),
+        ]))
+
+    p.run()
+
 
 class ReverseTestStreamTest(unittest.TestCase):
   def test_basic_execution(self):

Reply via email to