[ https://issues.apache.org/jira/browse/BEAM-2927?focusedWorklogId=82553&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-82553 ]

ASF GitHub Bot logged work on BEAM-2927:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 21/Mar/18 00:05
            Start Date: 21/Mar/18 00:05
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #4781: [BEAM-2927] Python support for portable side inputs over Fn API
URL: https://github.com/apache/beam/pull/4781
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 7a5b884b1af..82130d6e2b7 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -211,36 +211,8 @@ def visit_transform(self, transform_node):
         from apache_beam.transforms.core import GroupByKey, _GroupByKeyOnly
         if isinstance(transform_node.transform, (GroupByKey, _GroupByKeyOnly)):
           pcoll = transform_node.inputs[0]
-          input_type = pcoll.element_type
-          # If input_type is not specified, then treat it as `Any`.
-          if not input_type:
-            input_type = typehints.Any
-
-          def coerce_to_kv_type(element_type):
-            if isinstance(element_type, typehints.TupleHint.TupleConstraint):
-              if len(element_type.tuple_types) == 2:
-                return element_type
-              else:
-                raise ValueError(
-                    "Tuple input to GroupByKey must be have two components. "
-                    "Found %s for %s" % (element_type, pcoll))
-            elif isinstance(input_type, typehints.AnyTypeConstraint):
-              # `Any` type needs to be replaced with a KV[Any, Any] to
-              # force a KV coder as the main output coder for the pcollection
-              # preceding a GroupByKey.
-              return typehints.KV[typehints.Any, typehints.Any]
-            elif isinstance(element_type, typehints.UnionConstraint):
-              union_types = [
-                  coerce_to_kv_type(t) for t in element_type.union_types]
-              return typehints.KV[
-                  typehints.Union[tuple(t.tuple_types[0] for t in union_types)],
-                  typehints.Union[tuple(t.tuple_types[1] for t in union_types)]]
-            else:
-              # TODO: Possibly handle other valid types.
-              raise ValueError(
-                  "Input to GroupByKey must be of Tuple or Any type. "
-                  "Found %s for %s" % (element_type, pcoll))
-          pcoll.element_type = coerce_to_kv_type(input_type)
+          pcoll.element_type = typehints.coerce_to_kv_type(
+              pcoll.element_type, transform_node.full_label)
           key_type, value_type = pcoll.element_type.tuple_types
           if transform_node.outputs:
             transform_node.outputs[None].element_type = typehints.KV[
@@ -248,6 +220,59 @@ def coerce_to_kv_type(element_type):
 
     return GroupByKeyInputVisitor()
 
+  @staticmethod
+  def side_input_visitor():
+    # Imported here to avoid circular dependencies.
+    # pylint: disable=wrong-import-order, wrong-import-position
+    from apache_beam.pipeline import PipelineVisitor
+    from apache_beam.transforms.core import ParDo
+
+    class SideInputVisitor(PipelineVisitor):
+      """Ensures input `PCollection` used as a side inputs has a `KV` type.
+
+      TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
+      we could directly replace the coder instead of mutating the element type.
+      """
+      def visit_transform(self, transform_node):
+        if isinstance(transform_node.transform, ParDo):
+          new_side_inputs = []
+          for ix, side_input in enumerate(transform_node.side_inputs):
+            access_pattern = side_input._side_input_data().access_pattern
+            if access_pattern == common_urns.ITERABLE_SIDE_INPUT:
+              # Add a map to ('', value) as Dataflow currently only handles
+              # keyed side inputs.
+              pipeline = side_input.pvalue.pipeline
+              new_side_input = _DataflowIterableSideInput(side_input)
+              new_side_input.pvalue = beam.pvalue.PCollection(
+                  pipeline,
+                  element_type=typehints.KV[
+                      str, side_input.pvalue.element_type])
+              parent = transform_node.parent or pipeline._root_transform()
+              map_to_void_key = beam.pipeline.AppliedPTransform(
+                  pipeline,
+                  beam.Map(lambda x: ('', x)),
+                  transform_node.full_label + '/MapToVoidKey%s' % ix,
+                  (side_input.pvalue,))
+              new_side_input.pvalue.producer = map_to_void_key
+              map_to_void_key.add_output(new_side_input.pvalue)
+              parent.add_part(map_to_void_key)
+              transform_node.update_input_refcounts()
+            elif access_pattern == common_urns.MULTIMAP_SIDE_INPUT:
+              # Ensure the input coder is a KV coder and patch up the
+              # access pattern to appease Dataflow.
+              side_input.pvalue.element_type = typehints.coerce_to_kv_type(
+                  side_input.pvalue.element_type, transform_node.full_label)
+              new_side_input = _DataflowMultimapSideInput(side_input)
+            else:
+              raise ValueError(
+                  'Unsupported access pattern for %r: %r' %
+                  (transform_node.full_label, access_pattern))
+            new_side_inputs.append(new_side_input)
+          transform_node.side_inputs = new_side_inputs
+          transform_node.transform.side_inputs = new_side_inputs
+
+    return SideInputVisitor()
+
   @staticmethod
   def flatten_input_visitor():
     # Imported here to avoid circular dependencies.
@@ -280,10 +305,20 @@ def run_pipeline(self, pipeline):
           'Google Cloud Dataflow runner not available, '
           'please install apache_beam[gcp]')
 
+    # Convert all side inputs into a form acceptable to Dataflow.
+    pipeline.visit(self.side_input_visitor())
+
     # Snapshot the pipeline in a portable proto before mutating it
     proto_pipeline, self.proto_context = pipeline.to_runner_api(
         return_context=True)
 
+    # TODO(BEAM-2717): Remove once Coders are already in proto.
+    for pcoll in proto_pipeline.components.pcollections.values():
+      if pcoll.coder_id not in self.proto_context.coders:
+        coder = coders.registry.get_coder(pickler.loads(pcoll.coder_id))
+        pcoll.coder_id = self.proto_context.coders.get_id(coder)
+    self.proto_context.coders.populate_map(proto_pipeline.components.coders)
+
     # Performing configured PTransform overrides.
     pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
 
@@ -543,10 +578,10 @@ def run_ParDo(self, transform_node):
     si_labels = {}
     full_label_counts = defaultdict(int)
     lookup_label = lambda side_pval: si_labels[side_pval]
-    for side_pval in transform_node.side_inputs:
+    for ix, side_pval in enumerate(transform_node.side_inputs):
       assert isinstance(side_pval, AsSideInput)
-      step_number = self._get_unique_step_name()
-      si_label = 'SideInput-' + step_number
+      step_name = 'SideInput-' + self._get_unique_step_name()
+      si_label = 'side%d' % ix
       pcollection_label = '%s.%s' % (
           side_pval.pvalue.producer.full_label.split('/')[-1],
           side_pval.pvalue.tag if side_pval.pvalue.tag else 'out')
@@ -560,11 +595,11 @@ def run_ParDo(self, transform_node):
       full_label_counts[pcollection_label] += 1
 
       self._add_singleton_step(
-          si_label, si_full_label, side_pval.pvalue.tag,
+          step_name, si_full_label, side_pval.pvalue.tag,
           self._cache.get_pvalue(side_pval.pvalue))
       si_dict[si_label] = {
           '@type': 'OutputReference',
-          PropertyNames.STEP_NAME: si_label,
+          PropertyNames.STEP_NAME: step_name,
           PropertyNames.OUTPUT_NAME: PropertyNames.OUT}
       si_labels[side_pval] = si_label
 
@@ -776,8 +811,9 @@ def run_Read(self, transform_node):
     # PickleCoder because GlobalWindowCoder is known coder.
     # TODO(robertwb): Query the collection for the windowfn to extract the
     # correct coder.
-    coder = coders.WindowedValueCoder(transform._infer_output_coder(),
-                                      coders.coders.GlobalWindowCoder())  # pylint: disable=protected-access
+    coder = coders.WindowedValueCoder(
+        coders.registry.get_coder(transform_node.outputs[None].element_type),
+        coders.coders.GlobalWindowCoder())
 
     step.encoding = self._get_cloud_encoding(coder)
     step.add_property(
@@ -895,6 +931,54 @@ def json_string_to_byte_array(encoded_string):
     return urllib.unquote(encoded_string)
 
 
+class _DataflowSideInput(beam.pvalue.AsSideInput):
+  """Wraps a side input as a dataflow-compatible side input."""
+
+  # Dataflow does not yet accept the shared urn definition for access.
+  DATAFLOW_MULTIMAP_URN = 'urn:beam:sideinput:materialization:multimap:0.1'
+
+  def _view_options(self):
+    return {
+        'data': self._data,
+    }
+
+  def _side_input_data(self):
+    return self._data
+
+
+class _DataflowIterableSideInput(_DataflowSideInput):
+  """Wraps an iterable side input as dataflow-compatible side input."""
+
+  def __init__(self, iterable_side_input):
+    # pylint: disable=protected-access
+    side_input_data = iterable_side_input._side_input_data()
+    assert side_input_data.access_pattern == common_urns.ITERABLE_SIDE_INPUT
+    iterable_view_fn = side_input_data.view_fn
+    self._data = beam.pvalue.SideInputData(
+        self.DATAFLOW_MULTIMAP_URN,
+        side_input_data.window_mapping_fn,
+        lambda multimap: iterable_view_fn(multimap['']),
+        coders.WindowedValueCoder(
+            coders.TupleCoder((coders.BytesCoder(),
+                               side_input_data.coder.wrapped_value_coder)),
+            side_input_data.coder.window_coder))
+
+
+class _DataflowMultimapSideInput(_DataflowSideInput):
+  """Wraps a multimap side input as dataflow-compatible side input."""
+
+  def __init__(self, side_input):
+    # pylint: disable=protected-access
+    self.pvalue = side_input.pvalue
+    side_input_data = side_input._side_input_data()
+    assert side_input_data.access_pattern == common_urns.MULTIMAP_SIDE_INPUT
+    self._data = beam.pvalue.SideInputData(
+        self.DATAFLOW_MULTIMAP_URN,
+        side_input_data.window_mapping_fn,
+        side_input_data.view_fn,
+        self._input_element_coder())
+
+
 class DataflowPipelineResult(PipelineResult):
   """Represents the state of a pipeline run on the Dataflow service."""
 
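For orientation, the dataflow_runner.py hunks above make Dataflow treat every side input as a keyed (multimap) side input: an iterable side input is rewritten by keying each element with the empty string on the producer side (the beam.Map(lambda x: ('', x)) step added by SideInputVisitor) and by reading back multimap[''] on the consumer side (_DataflowIterableSideInput). Below is a minimal, self-contained sketch of that idea using a plain dict in place of the Dataflow state API; the helper names are illustrative and not code from the PR.

# Illustrative sketch only: emulate an iterable side input on top of a
# keyed (multimap) side input, as the SideInputVisitor rewrite does.

def to_keyed(elements):
    # Producer side: what the added beam.Map(lambda x: ('', x)) step does.
    return [('', e) for e in elements]

def view_as_iterable(multimap):
    # Consumer side: what lambda multimap: iterable_view_fn(multimap[''])
    # does in _DataflowIterableSideInput (with iterable_view_fn as identity).
    return multimap['']

multimap = {}
for k, v in to_keyed([1, 2, 3]):
    multimap.setdefault(k, []).append(v)
assert view_as_iterable(multimap) == [1, 2, 3]
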
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index b5300a4a9f6..c8790824bed 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -32,6 +32,7 @@
 from apache_beam.runners import DataflowRunner
 from apache_beam.runners import TestDataflowRunner
 from apache_beam.runners import create_runner
+from apache_beam.runners.dataflow import dataflow_runner
 from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult
 from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException
 from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
@@ -285,9 +286,11 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
     pcoll1 = PCollection(p)
     pcoll2 = PCollection(p)
     for transform in [_GroupByKeyOnly(), beam.GroupByKey()]:
-      pcoll1.element_type = typehints.TupleSequenceConstraint
+      pcoll1.element_type = str
       pcoll2.element_type = typehints.Set
-      err_msg = "Input to GroupByKey must be of Tuple or Any type"
+      err_msg = (
+          r"Input to 'label' must be compatible with KV\[Any, Any\]. "
+          "Found .*")
       for pcoll in [pcoll1, pcoll2]:
         with self.assertRaisesRegexp(ValueError, err_msg):
           DataflowRunner.group_by_key_input_visitor().visit_transform(
@@ -356,6 +359,22 @@ def test_serialize_windowing_strategy(self):
         DataflowRunner.deserialize_windowing_strategy(
             DataflowRunner.serialize_windowing_strategy(strategy)))
 
+  def test_side_input_visitor(self):
+    p = TestPipeline()
+    pc = p | beam.Create([])
+
+    transform = beam.Map(
+        lambda x, y, z: (x, y, z),
+        beam.pvalue.AsSingleton(pc),
+        beam.pvalue.AsMultiMap(pc))
+    applied_transform = AppliedPTransform(None, transform, "label", [pc])
+    DataflowRunner.side_input_visitor().visit_transform(applied_transform)
+    self.assertEqual(2, len(applied_transform.side_inputs))
+    for side_input in applied_transform.side_inputs:
+      self.assertEqual(
+          dataflow_runner._DataflowSideInput.DATAFLOW_MULTIMAP_URN,
+          side_input._side_input_data().access_pattern)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 5c1b1de12aa..9576550f0ee 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -18,6 +18,7 @@
 """A PipelineRunner using the SDK harness.
 """
 import collections
+import contextlib
 import copy
 import logging
 import Queue as queue
@@ -903,7 +904,7 @@ def extract_endpoints(stage):
                 side_input_id=tag,
                 window=window,
                 key=key))
-        controller.state_handler.blocking_append(state_key, elements_data, None)
+        controller.state_handler.blocking_append(state_key, elements_data)
 
     def get_buffer(pcoll_id):
       if pcoll_id.startswith('materialize:'):
@@ -945,15 +946,19 @@ def __init__(self):
       self._lock = threading.Lock()
       self._state = collections.defaultdict(list)
 
-    def blocking_get(self, state_key, instruction_reference=None):
+    @contextlib.contextmanager
+    def process_instruction_id(self, unused_instruction_id):
+      yield
+
+    def blocking_get(self, state_key):
       with self._lock:
         return ''.join(self._state[self._to_key(state_key)])
 
-    def blocking_append(self, state_key, data, instruction_reference=None):
+    def blocking_append(self, state_key, data):
       with self._lock:
         self._state[self._to_key(state_key)].append(data)
 
-    def blocking_clear(self, state_key, instruction_reference=None):
+    def blocking_clear(self, state_key):
       with self._lock:
         del self._state[self._to_key(state_key)]
 
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index a05f22d1d37..14b25a6035e 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -36,6 +36,7 @@
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import pipeline_context
+from apache_beam.runners.dataflow import dataflow_runner
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import operations
 from apache_beam.runners.worker import statesampler
@@ -147,7 +148,7 @@ def __init__(self, state_key, coder):
         def __iter__(self):
           # TODO(robertwb): Support pagination.
           input_stream = coder_impl.create_InputStream(
-              state_handler.blocking_get(self._state_key, None))
+              state_handler.blocking_get(self._state_key))
           while input_stream.size() > 0:
             yield self._coder_impl.decode_from_stream(input_stream, True)
 
@@ -157,7 +158,9 @@ def __reduce__(self):
       if access_pattern == common_urns.ITERABLE_SIDE_INPUT:
         raw_view = AllElements(state_key, self._element_coder)
 
-      elif access_pattern == common_urns.MULTIMAP_SIDE_INPUT:
+      elif (access_pattern == common_urns.MULTIMAP_SIDE_INPUT or
+            access_pattern ==
+            dataflow_runner._DataflowSideInput.DATAFLOW_MULTIMAP_URN):
         cache = {}
         key_coder_impl = self._element_coder.key_coder().get_impl()
         value_coder = self._element_coder.value_coder()
@@ -433,7 +436,7 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
   source = pickler.loads(base64.b64encode(parameter))
   spec = operation_specs.WorkerRead(
       iobase.SourceBundle(1.0, source, None, None),
-      [WindowedValueCoder(source.default_output_coder())])
+      [factory.get_only_output_coder(transform_proto)])
   return factory.augment_oldstyle_op(
       operations.ReadOperation(
           transform_proto.unique_name,
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 1988490013c..772f85e40f8 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -20,6 +20,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import logging
 import Queue as queue
 import sys
@@ -213,7 +214,8 @@ def process_bundle(self, request, instruction_id):
             self.fns[request.process_bundle_descriptor_reference],
             self.state_handler, self.data_channel_factory)
     try:
-      processor.process_bundle(instruction_id)
+      with self.state_handler.process_instruction_id(instruction_id):
+        processor.process_bundle(instruction_id)
     finally:
       del self.bundle_processors[instruction_id]
 
@@ -242,6 +244,18 @@ def __init__(self, state_stub):
     self._responses_by_id = {}
     self._last_id = 0
     self._exc_info = None
+    self._context = threading.local()
+
+  @contextlib.contextmanager
+  def process_instruction_id(self, bundle_id):
+    if getattr(self._context, 'process_instruction_id', None) is not None:
+      raise RuntimeError(
+          'Already bound to %r' % self._context.process_instruction_id)
+    self._context.process_instruction_id = bundle_id
+    try:
+      yield
+    finally:
+      self._context.process_instruction_id = None
 
   def start(self):
     self._done = False
@@ -273,32 +287,30 @@ def done(self):
     self._done = True
     self._requests.put(self._DONE)
 
-  def blocking_get(self, state_key, instruction_reference):
+  def blocking_get(self, state_key):
     response = self._blocking_request(
         beam_fn_api_pb2.StateRequest(
-            instruction_reference=instruction_reference,
             state_key=state_key,
             get=beam_fn_api_pb2.StateGetRequest()))
     if response.get.continuation_token:
       raise NotImplementedError
     return response.get.data
 
-  def blocking_append(self, state_key, data, instruction_reference):
+  def blocking_append(self, state_key, data):
     self._blocking_request(
         beam_fn_api_pb2.StateRequest(
-            instruction_reference=instruction_reference,
             state_key=state_key,
             append=beam_fn_api_pb2.StateAppendRequest(data=data)))
 
-  def blocking_clear(self, state_key, instruction_reference):
+  def blocking_clear(self, state_key):
     self._blocking_request(
         beam_fn_api_pb2.StateRequest(
-            instruction_reference=instruction_reference,
             state_key=state_key,
             clear=beam_fn_api_pb2.StateClearRequest()))
 
   def _blocking_request(self, request):
     request.id = self._next_id()
+    request.instruction_reference = self._context.process_instruction_id
     self._responses_by_id[request.id] = future = _Future()
     self._requests.put(request)
     while not future.wait(timeout=1):
@@ -308,7 +320,11 @@ def _blocking_request(self, request):
       elif self._done:
         raise RuntimeError()
     del self._responses_by_id[request.id]
-    return future.get()
+    response = future.get()
+    if response.error:
+      raise RuntimeError(response.error)
+    else:
+      return response
 
   def _next_id(self):
     self._last_id += 1
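The sdk_worker.py hunks above drop the explicit instruction_reference argument from every state call and instead bind it once per bundle: process_instruction_id stores the current bundle id in a thread-local, and _blocking_request reads it back when building each StateRequest. A stripped-down sketch of that pattern outside of Beam follows; the class and method names below are illustrative, not the real Beam classes.

# Illustrative sketch of the thread-local binding pattern used above.
import contextlib
import threading

class StateHandlerSketch(object):
    def __init__(self):
        self._context = threading.local()

    @contextlib.contextmanager
    def process_instruction_id(self, bundle_id):
        # Bind the bundle id for all state requests issued in this block.
        self._context.process_instruction_id = bundle_id
        try:
            yield
        finally:
            self._context.process_instruction_id = None

    def blocking_get(self, state_key):
        # Each request implicitly picks up the currently bound bundle id.
        return ('request', self._context.process_instruction_id, state_key)

handler = StateHandlerSketch()
with handler.process_instruction_id('bundle-1'):
    assert handler.blocking_get('key') == ('request', 'bundle-1', 'key')
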
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index 3455672e7a8..543641526f5 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1098,3 +1098,36 @@ def is_consistent_with(sub, base):
     # Nothing but object lives above any type constraints.
     return base == object
   return issubclass(sub, base)
+
+
+def coerce_to_kv_type(element_type, label=None):
+  """Attempts to coerce element_type to a compatible kv type.
+
+  Raises an error on failure.
+  """
+  # If element_type is not specified, then treat it as `Any`.
+  if not element_type:
+    return KV[Any, Any]
+  elif isinstance(element_type, TupleHint.TupleConstraint):
+    if len(element_type.tuple_types) == 2:
+      return element_type
+    else:
+      raise ValueError(
+          "Tuple input to %r must be have two components. "
+          "Found %s." % (label, element_type))
+  elif isinstance(element_type, AnyTypeConstraint):
+    # `Any` type needs to be replaced with a KV[Any, Any] to
+    # satisfy the KV form.
+    return KV[Any, Any]
+  elif isinstance(element_type, UnionConstraint):
+    union_types = [
+        coerce_to_kv_type(t) for t in element_type.union_types]
+    return KV[
+        Union[tuple(t.tuple_types[0] for t in union_types)],
+        Union[tuple(t.tuple_types[1] for t in union_types)]]
+  else:
+    # TODO: Possibly handle other valid types.
+    print "element_type", element_type
+    raise ValueError(
+        "Input to %r must be compatible with KV[Any, Any]. "
+        "Found %s." % (label, element_type))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 82553)
    Time Spent: 1h 40m  (was: 1.5h)

> Python SDK support for portable side input
> ------------------------------------------
>
>                 Key: BEAM-2927
>                 URL: https://issues.apache.org/jira/browse/BEAM-2927
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Henning Rohde
>            Assignee: Robert Bradshaw
>            Priority: Major
>              Labels: portability
>          Time Spent: 1h 40m
>  Remaining Estimate: 0h
>




--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
