This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new b7e5f45 BEAM[8335] TestStreamService integration with DirectRunner
new 23c4c20 Merge pull request #10994 from [BEAM-8335] TeststreamService integration with DirectRunner
b7e5f45 is described below
commit b7e5f454a74e89c033465ba12c3b0f26bca786fd
Author: Sam Rohde <[email protected]>
AuthorDate: Wed Feb 19 14:47:17 2020 -0800
BEAM[8335] TestStreamService integration with DirectRunner
Change-Id: I83c7c9221e118d49a9005707650803c28630fe2a
---
.../apache_beam/runners/direct/direct_runner.py | 14 +--
.../apache_beam/runners/direct/test_stream_impl.py | 74 ++++++++++--
.../runners/direct/transform_evaluator.py | 62 ++++++----
.../interactive/background_caching_job_test.py | 4 +-
sdks/python/apache_beam/testing/test_stream.py | 50 ++++++--
.../apache_beam/testing/test_stream_service.py | 18 ++-
.../testing/test_stream_service_test.py | 44 +++++--
.../python/apache_beam/testing/test_stream_test.py | 126 ++++++++++++++++-----
8 files changed, 296 insertions(+), 96 deletions(-)
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index f207779..5181ad7 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -376,25 +376,25 @@ class BundleBasedDirectRunner(PipelineRunner):
from apache_beam.runners.direct.executor import Executor
from apache_beam.runners.direct.transform_evaluator import \
TransformEvaluatorRegistry
- from apache_beam.runners.direct.test_stream_impl import _TestStream
-
- # Performing configured PTransform overrides.
- pipeline.replace_all(_get_transform_overrides(options))
+ from apache_beam.testing.test_stream import TestStream
# If the TestStream I/O is used, use a mock test clock.
- class _TestStreamUsageVisitor(PipelineVisitor):
+ class TestStreamUsageVisitor(PipelineVisitor):
"""Visitor determining whether a Pipeline uses a TestStream."""
def __init__(self):
self.uses_test_stream = False
def visit_transform(self, applied_ptransform):
- if isinstance(applied_ptransform.transform, _TestStream):
+ if isinstance(applied_ptransform.transform, TestStream):
self.uses_test_stream = True
- visitor = _TestStreamUsageVisitor()
+ visitor = TestStreamUsageVisitor()
pipeline.visit(visitor)
clock = TestClock() if visitor.uses_test_stream else RealClock()
+ # Performing configured PTransform overrides.
+ pipeline.replace_all(_get_transform_overrides(options))
+
_LOGGER.info('Running pipeline with DirectRunner.')
self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
pipeline.visit(self.consumer_tracking_visitor)
diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
index 5f325d5..e5ddc0a 100644
--- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py
+++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
@@ -27,14 +27,25 @@ tagged PCollection.
from __future__ import absolute_import
+import itertools
+
+import grpc
+
from apache_beam import ParDo
from apache_beam import coders
from apache_beam import pvalue
+from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api import beam_runner_api_pb2_grpc
+from apache_beam.testing.test_stream import ElementEvent
+from apache_beam.testing.test_stream import ProcessingTimeEvent
from apache_beam.testing.test_stream import WatermarkEvent
from apache_beam.transforms import PTransform
from apache_beam.transforms import core
from apache_beam.transforms import window
+from apache_beam.transforms.window import TimestampedValue
from apache_beam.utils import timestamp
+from apache_beam.utils.timestamp import Duration
+from apache_beam.utils.timestamp import Timestamp
class _WatermarkController(PTransform):
@@ -79,7 +90,8 @@ class _ExpandableTestStream(PTransform):
| _TestStream(
self.test_stream.output_tags,
events=self.test_stream._events,
- coder=self.test_stream.coder)
+ coder=self.test_stream.coder,
+ endpoint=self.test_stream._endpoint)
| _WatermarkController(list(self.test_stream.output_tags)[0]))
# Multiplex to the correct PCollection based upon the event tag.
@@ -131,12 +143,17 @@ class _TestStream(PTransform):
WATERMARK_CONTROL_TAG = '_TestStream_Watermark'
def __init__(
- self, output_tags, coder=coders.FastPrimitivesCoder(), events=None):
+ self,
+ output_tags,
+ coder=coders.FastPrimitivesCoder(),
+ events=None,
+ endpoint=None):
assert coder is not None
self.coder = coder
self._raw_events = events
self._events = self._add_watermark_advancements(output_tags, events)
self.output_tags = output_tags
+ self.endpoint = endpoint
def _watermark_starts(self, output_tags):
"""Sentinel values to hold the watermark of outputs to -inf.
@@ -226,17 +243,50 @@ class _TestStream(PTransform):
def _infer_output_coder(self, input_type=None, input_coder=None):
return self.coder
- def _events_from_script(self, index):
- yield self._events[index]
+ @staticmethod
+ def events_from_script(events):
+ """Yields the in-memory events.
+ """
+ return itertools.chain(events)
- def events(self, index):
- return self._events_from_script(index)
+ @staticmethod
+ def events_from_rpc(endpoint, output_tags, coder):
+ """Yields the events received from the given endpoint.
+ """
+ stub_channel = grpc.insecure_channel(endpoint)
+ stub = beam_runner_api_pb2_grpc.TestStreamServiceStub(stub_channel)
- def begin(self):
- return 0
+ # Request the PCollections that we are looking for from the service.
+ event_request = beam_runner_api_pb2.EventsRequest(
+ output_ids=[str(tag) for tag in output_tags])
- def end(self, index):
- return index >= len(self._events)
+ event_stream = stub.Events(event_request)
+ for e in event_stream:
+ yield _TestStream.test_stream_payload_to_events(e, coder)
- def next(self, index):
- return index + 1
+ @staticmethod
+ def test_stream_payload_to_events(payload, coder):
+ """Returns a TestStream Python event object from a TestStream event Proto.
+ """
+ if payload.HasField('element_event'):
+ element_event = payload.element_event
+ elements = [
+ TimestampedValue(
+ coder.decode(e.encoded_element), Timestamp(micros=e.timestamp))
+ for e in element_event.elements
+ ]
+ return ElementEvent(timestamped_values=elements, tag=element_event.tag)
+
+ if payload.HasField('watermark_event'):
+ watermark_event = payload.watermark_event
+ return WatermarkEvent(
+ Timestamp(micros=watermark_event.new_watermark),
+ tag=watermark_event.tag)
+
+ if payload.HasField('processing_time_event'):
+ processing_time_event = payload.processing_time_event
+ return ProcessingTimeEvent(
+ Duration(micros=processing_time_event.advance_duration))
+
+ raise RuntimeError(
+ 'Received a proto without the specified fields: {}'.format(payload))
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index acec67a..474ad0d 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -61,6 +61,7 @@ from apache_beam.testing.test_stream import PairWithTiming
from apache_beam.testing.test_stream import ProcessingTimeEvent
from apache_beam.testing.test_stream import TimingInfo
from apache_beam.testing.test_stream import WatermarkEvent
+from apache_beam.testing.test_stream import WindowedValueHolder
from apache_beam.transforms import core
from apache_beam.transforms.trigger import InMemoryUnmergedState
from apache_beam.transforms.trigger import TimeDomain
@@ -214,11 +215,19 @@ class _TestStreamRootBundleProvider(RootBundleProvider):
"""
def get_root_bundles(self):
test_stream = self._applied_ptransform.transform
+
+ # If there was an endpoint defined then get the events from the
+ # TestStreamService.
+ if test_stream.endpoint:
+ _TestStreamEvaluator.event_stream = _TestStream.events_from_rpc(
+ test_stream.endpoint, test_stream.output_tags, test_stream.coder)
+ else:
+ _TestStreamEvaluator.event_stream = (
+ _TestStream.events_from_script(test_stream._events))
+
bundle = self._evaluation_context.create_bundle(
pvalue.PBegin(self._applied_ptransform.transform.pipeline))
- bundle.add(
- GlobalWindows.windowed_value(
- test_stream.begin(), timestamp=MIN_TIMESTAMP))
+ bundle.add(GlobalWindows.windowed_value(b'', timestamp=MIN_TIMESTAMP))
bundle.commit(None)
return [bundle]
@@ -424,9 +433,9 @@ class _WatermarkControllerEvaluator(_TransformEvaluator):
bundle = self._evaluation_context.create_bundle(main_output)
for tv in event.timestamped_values:
# Unreify the value into the correct window.
- try:
- bundle.output(WindowedValue(**tv.value))
- except TypeError:
+ if isinstance(tv.value, WindowedValueHolder):
+ bundle.output(tv.value.windowed_value)
+ else:
bundle.output(
GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp))
self.bundles.append(bundle)
@@ -470,8 +479,11 @@ class _PairWithTimingEvaluator(_TransformEvaluator):
self.timing_info = TimingInfo(now, output_watermark)
def process_element(self, element):
- element.value = (element.value, self.timing_info)
- self.bundle.output(element)
+ result = WindowedValue((element.value, self.timing_info),
+ element.timestamp,
+ element.windows,
+ element.pane_info)
+ self.bundle.output(result)
def finish_bundle(self):
return TransformResult(self, [self.bundle], [], None, {})
@@ -487,6 +499,9 @@ class _TestStreamEvaluator(_TransformEvaluator):
The _WatermarkController is in charge of emitting the elements to the
downstream consumers and setting its own output watermark.
"""
+
+ event_stream = None
+
def __init__(
self,
evaluation_context,
@@ -500,23 +515,16 @@ class _TestStreamEvaluator(_TransformEvaluator):
input_committed_bundle,
side_inputs)
self.test_stream = applied_ptransform.transform
+ self.is_done = False
def start_bundle(self):
- self.current_index = 0
self.bundles = []
self.watermark = MIN_TIMESTAMP
def process_element(self, element):
- # The index into the TestStream list of events.
- self.current_index = element.value
-
# The watermark of the _TestStream transform itself.
self.watermark = element.timestamp
- # We can either have the _TestStream or the _WatermarkController to emit
- # the elements. We chose to emit in the _WatermarkController so that the
- # element is emitted at the correct watermark value.
-
# Set up the correct watermark holds in the Watermark controllers and the
# TestStream so that the watermarks will not automatically advance to +inf
when elements start streaming. This can happen multiple times in the first
@@ -527,9 +535,20 @@ class _TestStreamEvaluator(_TransformEvaluator):
for event in self.test_stream._set_up(self.test_stream.output_tags):
events.append(event)
- events += [e for e in self.test_stream.events(self.current_index)]
+ # Retrieve the TestStream's event stream and read from it.
+ try:
+ events.append(next(self.event_stream))
+ except StopIteration:
+ # Advance the watermarks to +inf to cleanly stop the pipeline.
+ self.is_done = True
+ events += ([
+ e for e in self.test_stream._tear_down(self.test_stream.output_tags)
+ ])
for event in events:
+ # We can either have the _TestStream or the _WatermarkController to emit
+ # the elements. We chose to emit in the _WatermarkController so that the
+ # element is emitted at the correct watermark value.
if isinstance(event, (ElementEvent, WatermarkEvent)):
# The WATERMARK_CONTROL_TAG is used to hold the _TestStream's
# watermark to -inf, then +inf-1, then +inf. This watermark progression
@@ -550,12 +569,15 @@ class _TestStreamEvaluator(_TransformEvaluator):
def finish_bundle(self):
unprocessed_bundles = []
- next_index = self.test_stream.next(self.current_index)
- if not self.test_stream.end(next_index):
+
+ # Continue to send its own state to itself via an unprocessed bundle. This
+ # acts as a heartbeat, where each element will read the next event from the
+ # event stream.
+ if not self.is_done:
unprocessed_bundle = self._evaluation_context.create_bundle(
pvalue.PBegin(self._applied_ptransform.transform.pipeline))
unprocessed_bundle.add(
- GlobalWindows.windowed_value(next_index, timestamp=self.watermark))
+ GlobalWindows.windowed_value(b'', timestamp=self.watermark))
unprocessed_bundles.append(unprocessed_bundle)
# Returning the watermark in the dict here is used as a watermark hold.
diff --git a/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py b/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py
index 0f76874..e5b6b52 100644
--- a/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py
+++ b/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py
@@ -267,14 +267,14 @@ class BackgroundCachingJobTest(unittest.TestCase):
def test_determine_a_test_stream_service_running(self):
pipeline = _build_an_empty_stream_pipeline()
- test_stream_service = TestStreamServiceController(events=iter([]))
+ test_stream_service = TestStreamServiceController(reader=None)
ie.current_env().set_test_stream_service_controller(
pipeline, test_stream_service)
self.assertTrue(bcj.is_a_test_stream_service_running(pipeline))
def test_stop_a_running_test_stream_service(self):
pipeline = _build_an_empty_stream_pipeline()
- test_stream_service = TestStreamServiceController(events=iter([]))
+ test_stream_service = TestStreamServiceController(reader=None)
test_stream_service.start()
ie.current_env().set_test_stream_service_controller(
pipeline, test_stream_service)
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 6d95ede..7719b7e 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -39,6 +39,7 @@ from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.portability.api.endpoints_pb2 import ApiServiceDescriptor
from apache_beam.transforms import PTransform
from apache_beam.transforms import core
from apache_beam.transforms import window
@@ -92,16 +93,14 @@ class Event(with_metaclass(ABCMeta, object)): # type: ignore[misc]
return ElementEvent([
TimestampedValue(
element_coder.decode(tv.encoded_element),
- timestamp.Timestamp(micros=1000 * tv.timestamp))
+ Timestamp(micros=1000 * tv.timestamp))
for tv in proto.element_event.elements
], tag=tag) # yapf: disable
elif proto.HasField('watermark_event'):
event = proto.watermark_event
tag = None if event.tag == 'None' else event.tag
return WatermarkEvent(
- timestamp.Timestamp(
- micros=1000 * proto.watermark_event.new_watermark),
- tag=tag)
+ Timestamp(micros=1000 * proto.watermark_event.new_watermark),
tag=tag)
elif proto.HasField('processing_time_event'):
return ProcessingTimeEvent(
timestamp.Duration(
@@ -155,7 +154,7 @@ class ElementEvent(Event):
class WatermarkEvent(Event):
"""Watermark-advancing test stream event."""
def __init__(self, new_watermark, tag=None):
- self.new_watermark = timestamp.Timestamp.of(new_watermark)
+ self.new_watermark = Timestamp.of(new_watermark)
self.tag = tag
def __eq__(self, other):
@@ -217,6 +216,17 @@ class ProcessingTimeEvent(Event):
return 'ProcessingTimeEvent: <{}>'.format(self.advance_by)
+class WindowedValueHolder:
+ """A class that holds a WindowedValue.
+
+ This is a special class that can be used by the runner that implements the
+ TestStream as a signal that the underlying value should be unreified to the
+ specified window.
+ """
+ def __init__(self, windowed_value):
+ self.windowed_value = windowed_value
+
+
class TestStream(PTransform):
"""Test stream that generates events on an unbounded PCollection of elements.
@@ -236,20 +246,36 @@ class TestStream(PTransform):
"""
def __init__(
- self, coder=coders.FastPrimitivesCoder(), events=None, output_tags=None):
+ self,
+ coder=coders.FastPrimitivesCoder(),
+ events=None,
+ output_tags=None,
+ endpoint=None):
+ """TestStream constructor.
+
+ Args:
+ coder: (apache_beam.Coder) the element coder for any ElementEvents.
+ events: (List[Event]) a list of instructions for the TestStream to
+ execute. This must be a subset of the given output_tags.
+ output_tags: (List[str]) a list of PCollection output tags.
+ endpoint: (str) a URL locating a TestStreamService.
+ """
super(TestStream, self).__init__()
assert coder is not None
self.coder = coder
self.watermarks = {None: timestamp.MIN_TIMESTAMP}
- self._events = [] if events is None else list(events)
self.output_tags = set(output_tags) if output_tags else set()
+ self._events = [] if events is None else list(events)
+ self._endpoint = endpoint
event_tags = set(
e.tag for e in self._events
if isinstance(e, (WatermarkEvent, ElementEvent)))
assert event_tags.issubset(self.output_tags), \
'{} is not a subset of {}'.format(event_tags, output_tags)
+ assert not (self._events and self._endpoint), \
+ 'Only either events or an endpoint can be given at once.'
def get_windowing(self, unused_inputs):
return core.Windowing(window.GlobalWindows())
@@ -354,7 +380,8 @@ class TestStream(PTransform):
common_urns.primitives.TEST_STREAM.urn,
beam_runner_api_pb2.TestStreamPayload(
coder_id=context.coders.get_id(self.coder),
- events=[e.to_runner_api(self.coder) for e in self._events]))
+ events=[e.to_runner_api(self.coder) for e in self._events],
+ endpoint=ApiServiceDescriptor(url=self._endpoint)))
@staticmethod
@PTransform.register_urn(
@@ -367,13 +394,14 @@ class TestStream(PTransform):
return TestStream(
coder=coder,
events=[Event.from_runner_api(e, coder) for e in payload.events],
- output_tags=output_tags)
+ output_tags=output_tags,
+ endpoint=payload.endpoint.url)
class TimingInfo(object):
def __init__(self, processing_time, watermark):
- self._processing_time = timestamp.Timestamp.of(processing_time)
- self._watermark = timestamp.Timestamp.of(watermark)
+ self._processing_time = Timestamp.of(processing_time)
+ self._watermark = Timestamp.of(watermark)
@property
def processing_time(self):
diff --git a/sdks/python/apache_beam/testing/test_stream_service.py b/sdks/python/apache_beam/testing/test_stream_service.py
index 28ec9ca..694f61b 100644
--- a/sdks/python/apache_beam/testing/test_stream_service.py
+++ b/sdks/python/apache_beam/testing/test_stream_service.py
@@ -28,19 +28,23 @@ from apache_beam.portability.api.beam_runner_api_pb2_grpc import TestStreamServi
class TestStreamServiceController(TestStreamServiceServicer):
- def __init__(self, events, endpoint=None):
+ """A server that streams TestStreamPayload.Events from a single EventRequest.
+
+ This server is used as a way for TestStreams to receive events from file.
+ """
+ def __init__(self, reader, endpoint=None):
self._server = grpc.server(ThreadPoolExecutor(max_workers=10))
if endpoint:
self.endpoint = endpoint
self._server.add_insecure_port(self.endpoint)
else:
- port = self._server.add_insecure_port('[::]:0')
- self.endpoint = '[::]:{}'.format(port)
+ port = self._server.add_insecure_port('localhost:0')
+ self.endpoint = 'localhost:{}'.format(port)
beam_runner_api_pb2_grpc.add_TestStreamServiceServicer_to_server(
self, self._server)
- self._events = events
+ self._reader = reader
def start(self):
self._server.start()
@@ -52,5 +56,9 @@ class TestStreamServiceController(TestStreamServiceServicer):
def Events(self, request, context):
"""Streams back all of the events from the streaming cache."""
- for e in self._events:
+ # TODO(srohde): Once we get rid of the CacheManager, get rid of this 'full'
+ # label.
+ tags = [None if tag == 'None' else tag for tag in request.output_ids]
+ reader = self._reader.read_multiple([('full', tag) for tag in tags])
+ for e in reader:
yield e
diff --git a/sdks/python/apache_beam/testing/test_stream_service_test.py b/sdks/python/apache_beam/testing/test_stream_service_test.py
index e1aebb4..01b16a1 100644
--- a/sdks/python/apache_beam/testing/test_stream_service_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_service_test.py
@@ -38,18 +38,31 @@ TestStreamFileHeader.__test__ = False # type: ignore[attr-defined]
TestStreamFileRecord.__test__ = False # type: ignore[attr-defined]
-class TestStreamServiceTest(unittest.TestCase):
- def events(self):
- events = []
+class EventsReader:
+ def __init__(self, expected_key):
+ self._expected_key = expected_key
+
+ def read_multiple(self, keys):
+ if keys != self._expected_key:
+ raise ValueError(
+ 'Expected key ({}) is not argument({})'.format(
+ self._expected_key, keys))
+
for i in range(10):
e = TestStreamPayload.Event()
e.element_event.elements.append(
TestStreamPayload.TimestampedElement(timestamp=i))
- events.append(e)
- return events
+ yield e
+
+EXPECTED_KEY = 'key'
+EXPECTED_KEYS = [EXPECTED_KEY]
+
+
+class TestStreamServiceTest(unittest.TestCase):
def setUp(self):
- self.controller = TestStreamServiceController(self.events())
+ self.controller = TestStreamServiceController(
+ EventsReader(expected_key=[('full', EXPECTED_KEY)]))
self.controller.start()
channel = grpc.insecure_channel(self.controller.endpoint)
@@ -59,15 +72,21 @@ class TestStreamServiceTest(unittest.TestCase):
self.controller.stop()
def test_normal_run(self):
- r = self.stub.Events(beam_runner_api_pb2.EventsRequest())
+ r = self.stub.Events(
+ beam_runner_api_pb2.EventsRequest(output_ids=EXPECTED_KEYS))
events = [e for e in r]
- expected_events = [e for e in self.events()]
+ expected_events = [
+ e for e in EventsReader(
+ expected_key=[EXPECTED_KEYS]).read_multiple([EXPECTED_KEYS])
+ ]
self.assertEqual(events, expected_events)
def test_multiple_sessions(self):
- resp_a = self.stub.Events(beam_runner_api_pb2.EventsRequest())
- resp_b = self.stub.Events(beam_runner_api_pb2.EventsRequest())
+ resp_a = self.stub.Events(
+ beam_runner_api_pb2.EventsRequest(output_ids=EXPECTED_KEYS))
+ resp_b = self.stub.Events(
+ beam_runner_api_pb2.EventsRequest(output_ids=EXPECTED_KEYS))
events_a = []
events_b = []
@@ -88,7 +107,10 @@ class TestStreamServiceTest(unittest.TestCase):
done = a_is_done and b_is_done
- expected_events = [e for e in self.events()]
+ expected_events = [
+ e for e in EventsReader(
+ expected_key=[EXPECTED_KEYS]).read_multiple([EXPECTED_KEYS])
+ ]
self.assertEqual(events_a, expected_events)
self.assertEqual(events_b, expected_events)
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 650953f..e0ee91b 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -38,6 +38,8 @@ from apache_beam.testing.test_stream import ProcessingTimeEvent
from apache_beam.testing.test_stream import ReverseTestStream
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.test_stream import WatermarkEvent
+from apache_beam.testing.test_stream import WindowedValueHolder
+from apache_beam.testing.test_stream_service import TestStreamServiceController
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import equal_to_per_window
@@ -305,16 +307,15 @@ class TestStreamTest(unittest.TestCase):
]))
def test_windowed_values_interpreted_correctly(self):
- windowed_value_args = {
- 'value': 'a',
- 'timestamp': Timestamp(5),
- 'windows': [beam.window.IntervalWindow(5, 10)],
- 'pane_info': PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)
- }
+ windowed_value = WindowedValueHolder(
+ WindowedValue(
+ 'a',
+ Timestamp(5), [beam.window.IntervalWindow(5, 10)],
+ PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)))
test_stream = (TestStream()
.advance_processing_time(10)
.advance_watermark_to(10)
- .add_elements([windowed_value_args])
+ .add_elements([windowed_value])
.advance_watermark_to_infinity()) # yapf: disable
class RecordFn(beam.DoFn):
@@ -491,20 +492,23 @@ class TestStreamTest(unittest.TestCase):
def test_basic_execution_sideinputs(self):
options = PipelineOptions()
+ options.view_as(DebugOptions).add_experiment(
+ 'passthrough_pcollection_output_ids')
options.view_as(StandardOptions).streaming = True
with TestPipeline(options=options) as p:
- main_stream = (
- p
- | 'main TestStream' >>
- TestStream().advance_watermark_to(10).add_elements(['e']))
- side_stream = (
- p
- | 'side TestStream' >> TestStream().add_elements([
- window.TimestampedValue(2, 2)
- ]).add_elements([window.TimestampedValue(1, 1)]).add_elements([
- window.TimestampedValue(7, 7)
- ]).add_elements([window.TimestampedValue(4, 4)]))
+ test_stream = (p | TestStream()
+ .advance_watermark_to(0, tag='side')
+ .advance_watermark_to(10, tag='main')
+ .add_elements(['e'], tag='main')
+ .add_elements([window.TimestampedValue(2, 2)], tag='side')
+ .add_elements([window.TimestampedValue(1, 1)], tag='side')
+ .add_elements([window.TimestampedValue(7, 7)], tag='side')
+ .add_elements([window.TimestampedValue(4, 4)], tag='side')
+ ) # yapf: disable
+
+ main_stream = test_stream['main']
+ side_stream = test_stream['side']
class RecordFn(beam.DoFn):
def process(
@@ -564,22 +568,30 @@ class TestStreamTest(unittest.TestCase):
def test_basic_execution_sideinputs_fixed_windows(self):
options = PipelineOptions()
+ options.view_as(DebugOptions).add_experiment(
+ 'passthrough_pcollection_output_ids')
options.view_as(StandardOptions).streaming = True
p = TestPipeline(options=options)
+ test_stream = (p | TestStream()
+ .advance_watermark_to(12, tag='side')
+ .add_elements([window.TimestampedValue('s1', 10)], tag='side')
+ .advance_watermark_to(20, tag='side')
+ .add_elements([window.TimestampedValue('s2', 20)], tag='side')
+
+ .advance_watermark_to(9, tag='main')
+ .add_elements(['a1', 'a2', 'a3', 'a4'], tag='main')
+ .add_elements(['b'], tag='main')
+ .advance_watermark_to(18, tag='main')
+ .add_elements('c', tag='main')
+ ) # yapf: disable
+
main_stream = (
- p
- |
- 'main TestStream' >>
TestStream().advance_watermark_to(9).add_elements([
- 'a1', 'a2', 'a3', 'a4'
- ]).add_elements(['b']).advance_watermark_to(18).add_elements('c')
+ test_stream['main']
| 'main windowInto' >> beam.WindowInto(window.FixedWindows(1)))
+
side_stream = (
- p
- |
- 'side TestStream' >>
TestStream().advance_watermark_to(12).add_elements(
- [window.TimestampedValue('s1', 10)]).advance_watermark_to(
- 20).add_elements([window.TimestampedValue('s2', 20)])
+ test_stream['side']
| 'side windowInto' >> beam.WindowInto(window.FixedWindows(3)))
class RecordFn(beam.DoFn):
@@ -664,6 +676,64 @@ class TestStreamTest(unittest.TestCase):
test_stream.output_tags, roundtrip_test_stream.output_tags)
self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
+ def test_basic_execution_with_service(self):
+ """Tests that the TestStream can correctly read from an RPC service.
+ """
+ coder = beam.coders.FastPrimitivesCoder()
+
+ test_stream_events = (TestStream(coder=coder)
+ .advance_watermark_to(10000)
+ .add_elements(['a', 'b', 'c'])
+ .advance_watermark_to(20000)
+ .add_elements(['d'])
+ .add_elements(['e'])
+ .advance_processing_time(10)
+ .advance_watermark_to(300000)
+ .add_elements([TimestampedValue('late', 12000)])
+ .add_elements([TimestampedValue('last', 310000)])
+ .advance_watermark_to_infinity())._events # yapf: disable
+
+ test_stream_proto_events = [
+ e.to_runner_api(coder) for e in test_stream_events
+ ]
+
+ class InMemoryEventReader:
+ def read_multiple(self, unused_keys):
+ for e in test_stream_proto_events:
+ yield e
+
+ service = TestStreamServiceController(reader=InMemoryEventReader())
+ service.start()
+
+ test_stream = TestStream(coder=coder, endpoint=service.endpoint)
+
+ class RecordFn(beam.DoFn):
+ def process(
+ self,
+ element=beam.DoFn.ElementParam,
+ timestamp=beam.DoFn.TimestampParam):
+ yield (element, timestamp)
+
+ options = StandardOptions(streaming=True)
+
+ p = TestPipeline(options=options)
+ my_record_fn = RecordFn()
+ records = p | test_stream | beam.ParDo(my_record_fn)
+
+ assert_that(
+ records,
+ equal_to([
+ ('a', timestamp.Timestamp(10)),
+ ('b', timestamp.Timestamp(10)),
+ ('c', timestamp.Timestamp(10)),
+ ('d', timestamp.Timestamp(20)),
+ ('e', timestamp.Timestamp(20)),
+ ('late', timestamp.Timestamp(12)),
+ ('last', timestamp.Timestamp(310)),
+ ]))
+
+ p.run()
+
class ReverseTestStreamTest(unittest.TestCase):
def test_basic_execution(self):