This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 437c88f Move TestStream implementation to replacement transform
new 7a4cdec Merge pull request #10892 from rohdesamuel/teststream_merge
437c88f is described below
commit 437c88f1c9754fdac9769d734d08f281318fb343
Author: Sam Rohde <[email protected]>
AuthorDate: Thu Feb 13 16:38:43 2020 -0800
Move TestStream implementation to replacement transform
* This also moves the DirectRunner's TestStream implementation to a
replacement transform, because the TestStream relies on getting the
output_tags from the PTransform.
Change-Id: Ibd80b0d25cd8cc5ff5c28e127f7313638e6664da
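
For context, a minimal sketch of how a multi-tag TestStream is driven after
this change, patterned on the round-trip tests added in test_stream_test.py
below. The 'LabelA'/'LabelB' steps and beam.Map(print) are illustrative only
and not part of this commit; on the DirectRunner the TestStream is swapped
out via the new TestStreamOverride for _ExpandableTestStream, which expands
into _TestStream -> multiplexer -> one _WatermarkController per tag.

    import apache_beam as beam
    from apache_beam.options.pipeline_options import StandardOptions
    from apache_beam.testing.test_pipeline import TestPipeline
    from apache_beam.testing.test_stream import TestStream

    # A TestStream with two tagged outputs; the tags become the transform's
    # output_tags, which from_runner_api_parameter can now read back from the
    # PTransform proto on round trip.
    test_stream = (TestStream()
                   .advance_watermark_to(2, tag='a')
                   .advance_watermark_to(3, tag='b')
                   .add_elements([1, 2, 3], tag='a')
                   .add_elements([4, 5, 6], tag='b'))

    with TestPipeline(options=StandardOptions(streaming=True)) as p:
      # With more than one tag, applying the TestStream yields a dict of
      # PCollections keyed by tag; the single-tag case still returns one
      # PCollection for backwards compatibility.
      outputs = p | test_stream
      outputs['a'] | 'LabelA' >> beam.Map(print)
      outputs['b'] | 'LabelB' >> beam.Map(print)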
---
sdks/python/apache_beam/io/iobase.py | 2 +-
sdks/python/apache_beam/pipeline.py | 2 +-
.../apache_beam/runners/direct/direct_runner.py | 70 ++++++----------------
.../apache_beam/runners/direct/test_stream_impl.py | 61 ++++++++++++++++++-
.../runners/portability/expansion_service.py | 3 +-
.../runners/portability/expansion_service_test.py | 29 +++++----
sdks/python/apache_beam/testing/test_stream.py | 53 ++++++++++++----
.../python/apache_beam/testing/test_stream_test.py | 51 ++++++++++++++++
sdks/python/apache_beam/transforms/core.py | 17 +++---
.../apache_beam/transforms/external_it_test.py | 2 +-
sdks/python/apache_beam/transforms/ptransform.py | 16 ++---
sdks/python/apache_beam/transforms/util.py | 3 +-
12 files changed, 214 insertions(+), 95 deletions(-)
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index e320c4a..1edcac9 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -925,7 +925,7 @@ class Read(ptransform.PTransform):
beam_runner_api_pb2.IsBounded.UNBOUNDED))
@staticmethod
- def from_runner_api_parameter(parameter, context):
+ def from_runner_api_parameter(unused_ptransform, parameter, context):
# type: (beam_runner_api_pb2.ReadPayload, PipelineContext) -> Read
return Read(SourceBase.from_runner_api(parameter.source, context))
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 50c141a..22aa2a3 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -1099,7 +1099,7 @@ class AppliedPTransform(object):
id in proto.inputs.items() if is_side_input(tag)
]
side_inputs = [si for _, si in sorted(indexed_side_inputs)]
- transform = ptransform.PTransform.from_runner_api(proto.spec, context)
+ transform = ptransform.PTransform.from_runner_api(proto, context)
result = AppliedPTransform(
parent=None,
transform=transform,
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 1ddda51..f207779 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -73,60 +73,12 @@ class SwitchingDirectRunner(PipelineRunner):
def is_fnapi_compatible(self):
return BundleBasedDirectRunner.is_fnapi_compatible()
- def apply_TestStream(self, transform, pbegin, options):
- """Expands the TestStream into the DirectRunner implementation.
-
- Takes the TestStream transform and creates a _TestStream -> multiplexer ->
- _WatermarkController.
- """
-
- from apache_beam.runners.direct.test_stream_impl import _TestStream
- from apache_beam.runners.direct.test_stream_impl import _WatermarkController
- from apache_beam import pvalue
- assert isinstance(pbegin, pvalue.PBegin)
-
- # If there is only one tag there is no need to add the multiplexer.
- if len(transform.output_tags) == 1:
- return (
- pbegin
- | _TestStream(transform.output_tags, events=transform._events)
- | _WatermarkController())
-
- # This multiplexing the multiple output PCollections.
- def mux(event):
- if event.tag:
- yield pvalue.TaggedOutput(event.tag, event)
- else:
- yield event
-
- mux_output = (
- pbegin
- | _TestStream(transform.output_tags, events=transform._events)
- | 'TestStream Multiplexer' >> beam.ParDo(mux).with_outputs())
-
- # Apply a way to control the watermark per output. It is necessary to
- # have an individual _WatermarkController per PCollection because the
- # calculation of the input watermark of a transform is based on the event
- # timestamp of the elements flowing through it. Meaning, it is impossible
- # to control the output watermarks of the individual PCollections solely
- # on the event timestamps.
- outputs = {}
- for tag in transform.output_tags:
- label = '_WatermarkController[{}]'.format(tag)
- outputs[tag] = (mux_output[tag] | label >> _WatermarkController())
-
- return outputs
-
- # We must mark this method as not a test or else its name is a matcher for
- # nosetest tests.
- apply_TestStream.__test__ = False
-
def run_pipeline(self, pipeline, options):
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
- from apache_beam.runners.direct.test_stream_impl import _TestStream
+ from apache_beam.testing.test_stream import TestStream
class _FnApiRunnerSupportVisitor(PipelineVisitor):
"""Visitor determining if a Pipeline can be run on the FnApiRunner."""
@@ -138,7 +90,7 @@ class SwitchingDirectRunner(PipelineRunner):
def visit_transform(self, applied_ptransform):
transform = applied_ptransform.transform
# The FnApiRunner does not support streaming execution.
- if isinstance(transform, _TestStream):
+ if isinstance(transform, TestStream):
self.supported_by_fnapi_runner = False
# The FnApiRunner does not support reads from NativeSources.
if (isinstance(transform, beam.io.Read) and
@@ -195,7 +147,8 @@ class _StreamingGroupByKeyOnly(_GroupByKeyOnly):
@staticmethod
@PTransform.register_urn(urn, None)
- def from_runner_api_parameter(unused_payload, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_payload, unused_context):
return _StreamingGroupByKeyOnly()
@@ -214,7 +167,7 @@ class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow):
@staticmethod
@PTransform.register_urn(urn, wrappers_pb2.BytesValue)
- def from_runner_api_parameter(payload, context):
+ def from_runner_api_parameter(unused_ptransform, payload, context):
return _StreamingGroupAlsoByWindow(
context.windowing_strategies.get_by_id(payload.value))
@@ -271,10 +224,21 @@ def _get_transform_overrides(pipeline_options):
transform = _StreamingGroupAlsoByWindow(transform.dofn.windowing)
return transform
+ class TestStreamOverride(PTransformOverride):
+ def matches(self, applied_ptransform):
+ from apache_beam.testing.test_stream import TestStream
+ self.applied_ptransform = applied_ptransform
+ return isinstance(applied_ptransform.transform, TestStream)
+
+ def get_replacement_transform(self, transform):
+ from apache_beam.runners.direct.test_stream_impl import _ExpandableTestStream
+ return _ExpandableTestStream(transform)
+
overrides = [
SplittableParDoOverride(),
ProcessKeyedElementsViaKeyedWorkItemsOverride(),
- CombinePerKeyOverride()
+ CombinePerKeyOverride(),
+ TestStreamOverride(),
]
# Add streaming overrides, if necessary.
diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
index 9b91154..8cee1fc 100644
--- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py
+++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
@@ -27,6 +27,7 @@ tagged PCollection.
from __future__ import absolute_import
+from apache_beam import ParDo
from apache_beam import coders
from apache_beam import pvalue
from apache_beam.testing.test_stream import WatermarkEvent
@@ -45,11 +46,69 @@ class _WatermarkController(PTransform):
- If the instance receives an ElementEvent, it emits all specified elements
to the Global Window with the event time set to the element's timestamp.
"""
+ def __init__(self, output_tag):
+ self.output_tag = output_tag
+
def get_windowing(self, _):
return core.Windowing(window.GlobalWindows())
def expand(self, pcoll):
- return pvalue.PCollection.from_(pcoll)
+ ret = pvalue.PCollection.from_(pcoll)
+ ret.tag = self.output_tag
+ return ret
+
+
+class _ExpandableTestStream(PTransform):
+ def __init__(self, test_stream):
+ self.test_stream = test_stream
+
+ def expand(self, pbegin):
+ """Expands the TestStream into the DirectRunner implementation.
+
+
+ Takes the TestStream transform and creates a _TestStream -> multiplexer ->
+ _WatermarkController.
+ """
+
+ assert isinstance(pbegin, pvalue.PBegin)
+
+ # If there is only one tag there is no need to add the multiplexer.
+ if len(self.test_stream.output_tags) == 1:
+ return (
+ pbegin
+ | _TestStream(
+ self.test_stream.output_tags,
+ events=self.test_stream._events,
+ coder=self.test_stream.coder)
+ | _WatermarkController(list(self.test_stream.output_tags)[0]))
+
+ # Multiplex to the correct PCollection based upon the event tag.
+ def mux(event):
+ if event.tag:
+ yield pvalue.TaggedOutput(event.tag, event)
+ else:
+ yield event
+
+ mux_output = (
+ pbegin
+ | _TestStream(
+ self.test_stream.output_tags,
+ events=self.test_stream._events,
+ coder=self.test_stream.coder)
+ | 'TestStream Multiplexer' >> ParDo(mux).with_outputs())
+
+ # Apply a way to control the watermark per output. It is necessary to
+ # have an individual _WatermarkController per PCollection because the
+ # calculation of the input watermark of a transform is based on the event
+ # timestamp of the elements flowing through it. This means the output
+ # watermarks of the individual PCollections cannot be controlled from the
+ # event timestamps alone.
+ outputs = {}
+ for tag in self.test_stream.output_tags:
+ label = '_WatermarkController[{}]'.format(tag)
+ outputs[tag] = (mux_output[tag] | label >> _WatermarkController(tag))
+
+ return outputs
class _TestStream(PTransform):
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py
index 5a9601d..d2be037 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service.py
@@ -64,8 +64,7 @@ class ExpansionServiceServicer(
pcoll_id in t_proto.outputs.items()
}
transform = with_pipeline(
- ptransform.PTransform.from_runner_api(
- request.transform.spec, context))
+ ptransform.PTransform.from_runner_api(request.transform, context))
inputs = transform._pvaluish_from_dict({
tag:
with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id)
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 809a2c4..34acac1 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -62,7 +62,8 @@ class CountPerElementTransform(ptransform.PTransform):
return 'beam:transforms:xlang:count', None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return CountPerElementTransform()
@@ -82,7 +83,7 @@ class FilterLessThanTransform(ptransform.PTransform):
'beam:transforms:xlang:filter_less_than', self._payload.encode('utf8'))
@staticmethod
- def from_runner_api_parameter(payload, unused_context):
+ def from_runner_api_parameter(unused_ptransform, payload, unused_context):
return FilterLessThanTransform(payload.decode('utf8'))
@@ -101,7 +102,7 @@ class PrefixTransform(ptransform.PTransform):
{'data': self._payload}).payload()
@staticmethod
- def from_runner_api_parameter(payload, unused_context):
+ def from_runner_api_parameter(unused_ptransform, payload, unused_context):
return PrefixTransform(parse_string_payload(payload)['data'])
@@ -134,7 +135,8 @@ class GBKTransform(ptransform.PTransform):
return TEST_GBK_URN, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return GBKTransform()
@@ -155,7 +157,8 @@ class CoGBKTransform(ptransform.PTransform):
return TEST_CGBK_URN, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return CoGBKTransform()
@@ -169,7 +172,8 @@ class CombineGloballyTransform(ptransform.PTransform):
return TEST_COMGL_URN, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return CombineGloballyTransform()
@@ -184,7 +188,8 @@ class CombinePerKeyTransform(ptransform.PTransform):
return TEST_COMPK_URN, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return CombinePerKeyTransform()
@@ -197,7 +202,8 @@ class FlattenTransform(ptransform.PTransform):
return TEST_FLATTEN_URN, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return FlattenTransform()
@@ -214,7 +220,8 @@ class PartitionTransform(ptransform.PTransform):
return TEST_PARTITION_URN, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return PartitionTransform()
@@ -230,7 +237,7 @@ class PayloadTransform(ptransform.PTransform):
return b'payload', self._payload.encode('ascii')
@staticmethod
- def from_runner_api_parameter(payload, unused_context):
+ def from_runner_api_parameter(unused_ptransform, payload, unused_context):
return PayloadTransform(payload.decode('ascii'))
@@ -259,7 +266,7 @@ class FibTransform(ptransform.PTransform):
return 'fib', str(self._level).encode('ascii')
@staticmethod
- def from_runner_api_parameter(level, unused_context):
+ def from_runner_api_parameter(unused_ptransform, level, unused_context):
return FibTransform(int(level.decode('ascii')))
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 406f4da..967cbcf 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -76,16 +76,21 @@ class Event(with_metaclass(ABCMeta, object)): # type: ignore[misc]
@staticmethod
def from_runner_api(proto, element_coder):
if proto.HasField('element_event'):
+ event = proto.element_event
+ tag = None if event.tag == 'None' else event.tag
return ElementEvent([
TimestampedValue(
element_coder.decode(tv.encoded_element),
timestamp.Timestamp(micros=1000 * tv.timestamp))
for tv in proto.element_event.elements
- ])
+ ], tag=tag) # yapf: disable
elif proto.HasField('watermark_event'):
+ event = proto.watermark_event
+ tag = None if event.tag == 'None' else event.tag
return WatermarkEvent(
timestamp.Timestamp(
- micros=1000 * proto.watermark_event.new_watermark))
+ micros=1000 * proto.watermark_event.new_watermark),
+ tag=tag)
elif proto.HasField('processing_time_event'):
return ProcessingTimeEvent(
timestamp.Duration(
@@ -113,6 +118,7 @@ class ElementEvent(Event):
return self.timestamped_values < other.timestamped_values
def to_runner_api(self, element_coder):
+ tag = 'None' if self.tag is None else self.tag
return beam_runner_api_pb2.TestStreamPayload.Event(
element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
elements=[
@@ -120,7 +126,8 @@ class ElementEvent(Event):
encoded_element=element_coder.encode(tv.value),
timestamp=tv.timestamp.micros // 1000)
for tv in self.timestamped_values
- ]))
+ ],
+ tag=tag))
class WatermarkEvent(Event):
@@ -133,15 +140,21 @@ class WatermarkEvent(Event):
return self.new_watermark == other.new_watermark and self.tag == other.tag
def __hash__(self):
- return hash(self.new_watermark)
+ return hash(str(self.new_watermark) + str(self.tag))
def __lt__(self, other):
return self.new_watermark < other.new_watermark
def to_runner_api(self, unused_element_coder):
+ tag = 'None' if self.tag is None else self.tag
+
+ # Assert that no precision is lost.
+ assert 1000 * (
+ self.new_watermark.micros // 1000) == self.new_watermark.micros
return beam_runner_api_pb2.TestStreamPayload.Event(
watermark_event=beam_runner_api_pb2.TestStreamPayload.Event.
- AdvanceWatermark(new_watermark=self.new_watermark.micros // 1000))
+ AdvanceWatermark(
+ new_watermark=self.new_watermark.micros // 1000, tag=tag))
class ProcessingTimeEvent(Event):
@@ -171,13 +184,20 @@ class TestStream(PTransform):
time. After all of the specified elements are emitted, ceases to produce
output.
"""
- def __init__(self, coder=coders.FastPrimitivesCoder(), events=None):
+ def __init__(
+ self, coder=coders.FastPrimitivesCoder(), events=None, output_tags=None):
super(TestStream, self).__init__()
assert coder is not None
+
self.coder = coder
self.watermarks = {None: timestamp.MIN_TIMESTAMP}
self._events = [] if events is None else list(events)
- self.output_tags = set()
+ self.output_tags = set(output_tags) if output_tags else set()
+
+ event_tags = set(
+ e.tag for e in self._events
+ if isinstance(e, (WatermarkEvent, ElementEvent)))
+ assert event_tags.issubset(self.output_tags)
def get_windowing(self, unused_inputs):
return core.Windowing(window.GlobalWindows())
@@ -188,7 +208,17 @@ class TestStream(PTransform):
def expand(self, pbegin):
assert isinstance(pbegin, pvalue.PBegin)
self.pipeline = pbegin.pipeline
- return pvalue.PCollection(self.pipeline, is_bounded=False)
+ if not self.output_tags:
+ self.output_tags = set([None])
+
+ # For backwards compatibility return a single PCollection.
+ if len(self.output_tags) == 1:
+ return pvalue.PCollection(
+ self.pipeline, is_bounded=False, tag=list(self.output_tags)[0])
+ return {
+ tag: pvalue.PCollection(self.pipeline, is_bounded=False, tag=tag)
+ for tag in self.output_tags
+ }
def _add(self, event):
if isinstance(event, ElementEvent):
@@ -276,8 +306,11 @@ class TestStream(PTransform):
@PTransform.register_urn(
common_urns.primitives.TEST_STREAM.urn,
beam_runner_api_pb2.TestStreamPayload)
- def from_runner_api_parameter(payload, context):
+ def from_runner_api_parameter(ptransform, payload, context):
coder = context.coders.get_by_id(payload.coder_id)
+ output_tags = set(
+ None if k == 'None' else k for k in ptransform.outputs.keys())
return TestStream(
coder=coder,
- events=[Event.from_runner_api(e, coder) for e in payload.events])
+ events=[Event.from_runner_api(e, coder) for e in payload.events],
+ output_tags=output_tags)
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 05b1d59..cb98b14 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -26,6 +26,7 @@ import unittest
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.portability import common_urns
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import ElementEvent
from apache_beam.testing.test_stream import ProcessingTimeEvent
@@ -528,6 +529,56 @@ class TestStreamTest(unittest.TestCase):
p.run()
+ def test_roundtrip_proto(self):
+ test_stream = (TestStream()
+ .advance_processing_time(1)
+ .advance_watermark_to(2)
+ .add_elements([1, 2, 3])) # yapf: disable
+
+ p = TestPipeline(options=StandardOptions(streaming=True))
+ p | test_stream
+
+ pipeline_proto, context = p.to_runner_api(return_context=True)
+
+ for t in pipeline_proto.components.transforms.values():
+ if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
+ test_stream_proto = t
+
+ self.assertTrue(test_stream_proto)
+ roundtrip_test_stream = TestStream().from_runner_api(
+ test_stream_proto, context)
+
+ self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
+ self.assertSetEqual(
+ test_stream.output_tags, roundtrip_test_stream.output_tags)
+ self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
+
+ def test_roundtrip_proto_multi(self):
+ test_stream = (TestStream()
+ .advance_processing_time(1)
+ .advance_watermark_to(2, tag='a')
+ .advance_watermark_to(3, tag='b')
+ .add_elements([1, 2, 3], tag='a')
+ .add_elements([4, 5, 6], tag='b')) # yapf: disable
+
+ p = TestPipeline(options=StandardOptions(streaming=True))
+ p | test_stream
+
+ pipeline_proto, context = p.to_runner_api(return_context=True)
+
+ for t in pipeline_proto.components.transforms.values():
+ if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
+ test_stream_proto = t
+
+ self.assertTrue(test_stream_proto)
+ roundtrip_test_stream = TestStream().from_runner_api(
+ test_stream_proto, context)
+
+ self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
+ self.assertSetEqual(
+ test_stream.output_tags, roundtrip_test_stream.output_tags)
+ self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 57660c5..f5ccafa 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1339,7 +1339,7 @@ class ParDo(PTransformWithSideInputs):
@staticmethod
@PTransform.register_urn(
common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
- def from_runner_api_parameter(pardo_payload, context):
+ def from_runner_api_parameter(unused_ptransform, pardo_payload, context):
assert pardo_payload.do_fn.urn == python_urns.PICKLED_DOFN_INFO
fn, args, kwargs, si_tags_and_types, windowing = pickler.loads(
pardo_payload.do_fn.payload)
@@ -1932,7 +1932,7 @@ class CombinePerKey(PTransformWithSideInputs):
@PTransform.register_urn(
common_urns.composites.COMBINE_PER_KEY.urn,
beam_runner_api_pb2.CombinePayload)
- def from_runner_api_parameter(combine_payload, context):
+ def from_runner_api_parameter(unused_ptransform, combine_payload, context):
return CombinePerKey(
CombineFn.from_runner_api(combine_payload.combine_fn, context))
@@ -1975,7 +1975,7 @@ class CombineValues(PTransformWithSideInputs):
@PTransform.register_urn(
common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
beam_runner_api_pb2.CombinePayload)
- def from_runner_api_parameter(combine_payload, context):
+ def from_runner_api_parameter(unused_ptransform, combine_payload, context):
return CombineValues(
CombineFn.from_runner_api(combine_payload.combine_fn, context))
@@ -2203,7 +2203,8 @@ class GroupByKey(PTransform):
@staticmethod
@PTransform.register_urn(common_urns.primitives.GROUP_BY_KEY.urn, None)
- def from_runner_api_parameter(unused_payload, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_payload, unused_context):
return GroupByKey()
def runner_api_requires_keyed_input(self):
@@ -2494,7 +2495,7 @@ class WindowInto(ParDo):
self.windowing.to_runner_api(context))
@staticmethod
- def from_runner_api_parameter(proto, context):
+ def from_runner_api_parameter(unused_ptransform, proto, context):
windowing = Windowing.from_runner_api(proto, context)
return WindowInto(
windowing.windowfn,
@@ -2568,7 +2569,8 @@ class Flatten(PTransform):
return common_urns.primitives.FLATTEN.urn, None
@staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return Flatten()
@@ -2681,5 +2683,6 @@ class Impulse(PTransform):
@staticmethod
@PTransform.register_urn(common_urns.primitives.IMPULSE.urn, None)
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return Impulse()
diff --git a/sdks/python/apache_beam/transforms/external_it_test.py b/sdks/python/apache_beam/transforms/external_it_test.py
index 40ee910..d99c218 100644
--- a/sdks/python/apache_beam/transforms/external_it_test.py
+++ b/sdks/python/apache_beam/transforms/external_it_test.py
@@ -46,7 +46,7 @@ class ExternalTransformIT(unittest.TestCase):
return 'simple', None
@staticmethod
- def from_runner_api_parameter(_1, _2):
+ def from_runner_api_parameter(_0, _1, _2):
return SimpleTransform()
pipeline = TestPipeline(is_integration_test=True)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 66a953b..5e44ef7 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -669,23 +669,25 @@ class PTransform(WithTypeHints, HasDisplayData):
@classmethod
def from_runner_api(cls,
- proto, # type: Optional[beam_runner_api_pb2.FunctionSpec]
+ proto, # type: Optional[beam_runner_api_pb2.PTransform]
context # type: PipelineContext
):
# type: (...) -> Optional[PTransform]
- if proto is None or not proto.urn:
+ if proto is None or proto.spec is None or not proto.spec.urn:
return None
- parameter_type, constructor = cls._known_urns[proto.urn]
+ parameter_type, constructor = cls._known_urns[proto.spec.urn]
try:
return constructor(
- proto_utils.parse_Bytes(proto.payload, parameter_type), context)
+ proto,
+ proto_utils.parse_Bytes(proto.spec.payload, parameter_type),
+ context)
except Exception:
if context.allow_proto_holders:
# For external transforms we cannot build a Python ParDo object so
# we build a holder transform instead.
from apache_beam.transforms.core import RunnerAPIPTransformHolder
- return RunnerAPIPTransformHolder(proto, context)
+ return RunnerAPIPTransformHolder(proto.spec, context)
raise
def to_runner_api_parameter(
@@ -707,14 +709,14 @@ class PTransform(WithTypeHints, HasDisplayData):
@PTransform.register_urn(python_urns.GENERIC_COMPOSITE_TRANSFORM, None)
-def _create_transform(payload, unused_context):
+def _create_transform(unused_ptransform, payload, unused_context):
empty_transform = PTransform()
empty_transform._fn_api_payload = payload
return empty_transform
@PTransform.register_urn(python_urns.PICKLED_TRANSFORM, None)
-def _unpickle_transform(pickled_bytes, unused_context):
+def _unpickle_transform(unused_ptransform, pickled_bytes, unused_context):
return pickler.loads(pickled_bytes)
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 3ad99fe..361edbe 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -722,7 +722,8 @@ class Reshuffle(PTransform):
@staticmethod
@PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None)
- def from_runner_api_parameter(unused_parameter, unused_context):
+ def from_runner_api_parameter(
+ unused_ptransform, unused_parameter, unused_context):
return Reshuffle()