robertwb commented on a change in pull request #12185:
URL: https://github.com/apache/beam/pull/12185#discussion_r451758742



##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -685,6 +687,264 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def eliminate_common_siblings(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
+  """Runs common subexpression elimination for common siblings.
+
+  If stages have common input, an identical transform, and one output each,
+  then all but one stages will be eliminated, and the output of the remaining
+  will be connected to the original output PCollections of the eliminated
+  stages. This elimination runs only once, not recursively, and will only
+  eliminate the first stage after a common input, rather than a chain of
+  stages.
+  """
+
+  SiblingKey = collections.namedtuple(
+      'SiblingKey', ['spec_urn', 'spec_payload', 'inputs', 'environment_id'])
+
+  def get_sibling_key(transform):
+    """Returns a key that will be identical for common siblings."""
+    transform_output_keys = list(transform.outputs.keys())
+    # Return None as the sibling key for ineligible transforms.
+    if len(transform_output_keys
+          ) != 1 or transform.spec.urn != common_urns.primitives.PAR_DO.urn:
+      return None
+    return SiblingKey(
+        spec_urn=transform.spec.urn,
+        spec_payload=transform.spec.payload,
+        inputs=tuple(transform.inputs.items()),
+        environment_id=transform.environment_id)
+
+  # Group stages by keys.
+  stages_by_sibling_key = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    stages_by_sibling_key[get_sibling_key(transform)].append(stage)
+
+  # Eliminate stages and build the output PCollection remapping dictionary.
+  pcoll_id_remap = {}
+  for sibling_key, sibling_stages in stages_by_sibling_key.items():
+    if sibling_key is None or len(sibling_stages) == 1:
+      continue
+    output_pcoll_ids = [
+        only_element(stage.transforms[0].outputs.values())
+        for stage in sibling_stages
+    ]
+    to_delete_pcoll_ids = output_pcoll_ids[1:]
+    for to_delete_pcoll_id in to_delete_pcoll_ids:
+      pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
+      del context.components.pcollections[to_delete_pcoll_id]
+    del sibling_stages[1:]
+
+  # Yield stages while remapping output PCollections if needed.
+  for sibling_key, sibling_stages in stages_by_sibling_key.items():
+    for stage in sibling_stages:
+      input_keys_to_remap = []
+      for input_key, input_pcoll_id in stage.transforms[0].inputs.items():
+        if input_pcoll_id in pcoll_id_remap:
+          input_keys_to_remap.append(input_key)
+      for input_key_to_remap in input_keys_to_remap:
+        stage.transforms[0].inputs[input_key_to_remap] = pcoll_id_remap[
+            stage.transforms[0].inputs[input_key_to_remap]]
+      yield stage
+
+
+def pack_combiners(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
+  """Packs sibling CombinePerKey stages into a single CombinePerKey.
+
+  If CombinePerKey stages have a common input, one input each, and one output
+  each, pack the stages into a single stage that runs all CombinePerKeys and
+  outputs resulting tuples to a new PCollection. A subsequent stage unpacks
+  tuples from this PCollection and sends them to the original output
+  PCollections.
+  """
+
+  class _UnpackFn(core.DoFn):
+    """A DoFn that unpacks a packed to multiple tagged outputs.
+
+    Example:
+      tags = (T1, T2, ...)
+      input = (K, (V1, V2, ...))
+      output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V1)), ...
+    """
+
+    def __init__(self, tags):
+      self._tags = tags
+
+    def process(self, element):
+      key, values = element
+      return [
+          core.pvalue.TaggedOutput(tag, (key, value))
+          for tag, value in zip(self._tags, values)
+      ]
+
+  def _get_fallback_coder_id():
+    return context.add_or_get_coder_id(
+        coders.registry.get_coder(object).to_runner_api(None))
+
+  def _get_component_coder_id_from_kv_coder(coder, index):
+    assert index < 2
+    if coder.spec.urn == common_urns.coders.KV.urn and len(
+        coder.component_coder_ids) == 2:
+      return coder.component_coder_ids[index]
+    return _get_fallback_coder_id()
+
+  def _get_key_coder_id_from_kv_coder(coder):
+    return _get_component_coder_id_from_kv_coder(coder, 0)
+
+  def _get_value_coder_id_from_kv_coder(coder):
+    return _get_component_coder_id_from_kv_coder(coder, 1)
+
+  def _try_fuse_stages(a, b):
+    if a.can_fuse(b, context):
+      return a.fuse(b)
+    else:
+      raise ValueError
+
+  def _try_merge_environments(env1, env2):

Review comment:
       Was this copied from above? If so, perhaps refactor it into a shared helper? (Similarly for `_try_fuse_stages`.)

##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -685,6 +687,264 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def eliminate_common_siblings(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
+  """Runs common subexpression elimination for common siblings.
+
+  If stages have common input, an identical transform, and one output each,
+  then all but one stages will be eliminated, and the output of the remaining
+  will be connected to the original output PCollections of the eliminated
+  stages. This elimination runs only once, not recursively, and will only
+  eliminate the first stage after a common input, rather than a chain of
+  stages.
+  """
+
+  SiblingKey = collections.namedtuple(
+      'SiblingKey', ['spec_urn', 'spec_payload', 'inputs', 'environment_id'])
+
+  def get_sibling_key(transform):
+    """Returns a key that will be identical for common siblings."""
+    transform_output_keys = list(transform.outputs.keys())
+    # Return None as the sibling key for ineligible transforms.
+    if len(transform_output_keys
+          ) != 1 or transform.spec.urn != common_urns.primitives.PAR_DO.urn:
+      return None
+    return SiblingKey(
+        spec_urn=transform.spec.urn,
+        spec_payload=transform.spec.payload,
+        inputs=tuple(transform.inputs.items()),
+        environment_id=transform.environment_id)
+
+  # Group stages by keys.
+  stages_by_sibling_key = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    stages_by_sibling_key[get_sibling_key(transform)].append(stage)
+
+  # Eliminate stages and build the output PCollection remapping dictionary.
+  pcoll_id_remap = {}

Review comment:
       (Just a thought:) I wonder whether this remapping should be tracked globally in the context, so that composites can be properly reconstructed later.

##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
##########
@@ -289,6 +289,8 @@ def create_stages(
         phases=[
             translations.annotate_downstream_side_inputs,
             translations.fix_side_input_pcoll_coders,
+            translations.eliminate_common_siblings,

Review comment:
       This transform is not safe to apply unconditionally, because of the possibility of side effects.

##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -685,6 +687,264 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def eliminate_common_siblings(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
+  """Runs common subexpression elimination for common siblings.
+
+  If stages have common input, an identical transform, and one output each,
+  then all but one stages will be eliminated, and the output of the remaining
+  will be connected to the original output PCollections of the eliminated
+  stages. This elimination runs only once, not recursively, and will only
+  eliminate the first stage after a common input, rather than a chain of
+  stages.
+  """
+
+  SiblingKey = collections.namedtuple(
+      'SiblingKey', ['spec_urn', 'spec_payload', 'inputs', 'environment_id'])
+
+  def get_sibling_key(transform):
+    """Returns a key that will be identical for common siblings."""
+    transform_output_keys = list(transform.outputs.keys())
+    # Return None as the sibling key for ineligible transforms.
+    if len(transform_output_keys
+          ) != 1 or transform.spec.urn != common_urns.primitives.PAR_DO.urn:
+      return None
+    return SiblingKey(
+        spec_urn=transform.spec.urn,
+        spec_payload=transform.spec.payload,
+        inputs=tuple(transform.inputs.items()),
+        environment_id=transform.environment_id)

Review comment:
       The environment can be omitted here: if two specs and payloads are identical, they are the same operation.

##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -685,6 +687,264 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def eliminate_common_siblings(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
+  """Runs common subexpression elimination for common siblings.
+
+  If stages have common input, an identical transform, and one output each,
+  then all but one stages will be eliminated, and the output of the remaining
+  will be connected to the original output PCollections of the eliminated
+  stages. This elimination runs only once, not recursively, and will only
+  eliminate the first stage after a common input, rather than a chain of
+  stages.
+  """
+
+  SiblingKey = collections.namedtuple(
+      'SiblingKey', ['spec_urn', 'spec_payload', 'inputs', 'environment_id'])
+
+  def get_sibling_key(transform):
+    """Returns a key that will be identical for common siblings."""
+    transform_output_keys = list(transform.outputs.keys())
+    # Return None as the sibling key for ineligible transforms.
+    if len(transform_output_keys
+          ) != 1 or transform.spec.urn != common_urns.primitives.PAR_DO.urn:
+      return None
+    return SiblingKey(
+        spec_urn=transform.spec.urn,
+        spec_payload=transform.spec.payload,
+        inputs=tuple(transform.inputs.items()),
+        environment_id=transform.environment_id)
+
+  # Group stages by keys.
+  stages_by_sibling_key = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    stages_by_sibling_key[get_sibling_key(transform)].append(stage)
+
+  # Eliminate stages and build the output PCollection remapping dictionary.
+  pcoll_id_remap = {}
+  for sibling_key, sibling_stages in stages_by_sibling_key.items():
+    if sibling_key is None or len(sibling_stages) == 1:
+      continue
+    output_pcoll_ids = [
+        only_element(stage.transforms[0].outputs.values())
+        for stage in sibling_stages
+    ]
+    to_delete_pcoll_ids = output_pcoll_ids[1:]
+    for to_delete_pcoll_id in to_delete_pcoll_ids:
+      pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
+      del context.components.pcollections[to_delete_pcoll_id]
+    del sibling_stages[1:]
+
+  # Yield stages while remapping output PCollections if needed.
+  for sibling_key, sibling_stages in stages_by_sibling_key.items():
+    for stage in sibling_stages:
+      input_keys_to_remap = []
+      for input_key, input_pcoll_id in stage.transforms[0].inputs.items():
+        if input_pcoll_id in pcoll_id_remap:
+          input_keys_to_remap.append(input_key)
+      for input_key_to_remap in input_keys_to_remap:
+        stage.transforms[0].inputs[input_key_to_remap] = pcoll_id_remap[
+            stage.transforms[0].inputs[input_key_to_remap]]
+      yield stage
+
+
+def pack_combiners(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
+  """Packs sibling CombinePerKey stages into a single CombinePerKey.
+
+  If CombinePerKey stages have a common input, one input each, and one output
+  each, pack the stages into a single stage that runs all CombinePerKeys and
+  outputs resulting tuples to a new PCollection. A subsequent stage unpacks
+  tuples from this PCollection and sends them to the original output
+  PCollections.
+  """
+
+  class _UnpackFn(core.DoFn):
+    """A DoFn that unpacks a packed to multiple tagged outputs.
+
+    Example:
+      tags = (T1, T2, ...)
+      input = (K, (V1, V2, ...))
+      output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V1)), ...
+    """
+
+    def __init__(self, tags):
+      self._tags = tags
+
+    def process(self, element):
+      key, values = element
+      return [
+          core.pvalue.TaggedOutput(tag, (key, value))
+          for tag, value in zip(self._tags, values)
+      ]
+
+  def _get_fallback_coder_id():
+    return context.add_or_get_coder_id(
+        coders.registry.get_coder(object).to_runner_api(None))
+
+  def _get_component_coder_id_from_kv_coder(coder, index):
+    assert index < 2
+    if coder.spec.urn == common_urns.coders.KV.urn and len(
+        coder.component_coder_ids) == 2:
+      return coder.component_coder_ids[index]
+    return _get_fallback_coder_id()
+
+  def _get_key_coder_id_from_kv_coder(coder):
+    return _get_component_coder_id_from_kv_coder(coder, 0)
+
+  def _get_value_coder_id_from_kv_coder(coder):
+    return _get_component_coder_id_from_kv_coder(coder, 1)
+
+  def _try_fuse_stages(a, b):
+    if a.can_fuse(b, context):
+      return a.fuse(b)
+    else:
+      raise ValueError
+
+  def _try_merge_environments(env1, env2):
+    if env1 is None:
+      return env2
+    elif env2 is None:
+      return env1
+    else:
+      if env1 != env2:
+        raise ValueError
+      return env1
+
+  # Group stages by parent, yielding ineligible stages.
+  combine_stages_by_input_pcoll_id = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and 
len(
+        transform.inputs) == 1 and len(transform.outputs) == 1:
+      input_pcoll_id = only_element(transform.inputs.values())
+      combine_stages_by_input_pcoll_id[input_pcoll_id].append(stage)
+    else:
+      yield stage
+
+  for input_pcoll_id, packable_stages in 
combine_stages_by_input_pcoll_id.items(
+  ):
+    # Yield stage and continue if it has no siblings.
+    if len(packable_stages) == 1:
+      yield packable_stages[0]
+      continue
+
+    transforms = [only_transform(stage.transforms) for stage in 
packable_stages]
+
+    # Yield stages and continue if they cannot be packed.
+    try:
+      # Fused stage is used as template and is not yielded.
+      fused_stage = functools.reduce(_try_fuse_stages, packable_stages)
+      merged_transform_environment_id = functools.reduce(
+          _try_merge_environments,
+          [transform.environment_id or None for transform in transforms])
+    except ValueError:
+      for stage in packable_stages:
+        yield stage
+      continue
+
+    output_pcoll_ids = [
+        only_element(transform.outputs.values()) for transform in transforms
+    ]
+    combine_payloads = [
+        proto_utils.parse_Bytes(transform.spec.payload,
+                                beam_runner_api_pb2.CombinePayload)
+        for transform in transforms
+    ]
+
+    # Build accumulator coder for (acc1, acc2, ...)
+    accumulator_coder_ids = [
+        combine_payload.accumulator_coder_id
+        for combine_payload in combine_payloads
+    ]
+    tuple_accumulator_coder_id = context.add_or_get_coder_id(
+        beam_runner_api_pb2.Coder(
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.coders.KV.urn),
+            component_coder_ids=accumulator_coder_ids))
+
+    # Build packed output coder for (key, (out1, out2, ...))
+    input_kv_coder_id = 
context.components.pcollections[input_pcoll_id].coder_id
+    key_coder_id = _get_key_coder_id_from_kv_coder(
+        context.components.coders[input_kv_coder_id])
+    output_kv_coder_ids = [
+        context.components.pcollections[output_pcoll_id].coder_id
+        for output_pcoll_id in output_pcoll_ids
+    ]
+    output_value_coder_ids = [
+        _get_value_coder_id_from_kv_coder(
+            context.components.coders[output_kv_coder_id])
+        for output_kv_coder_id in output_kv_coder_ids
+    ]
+    pack_output_value_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
+        component_coder_ids=output_value_coder_ids)
+    pack_output_value_coder_id = context.add_or_get_coder_id(
+        pack_output_value_coder)
+    pack_output_kv_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
+        component_coder_ids=[key_coder_id, pack_output_value_coder_id])
+    pack_output_kv_coder_id = context.add_or_get_coder_id(pack_output_kv_coder)
+
+    # Set up packed PCollection
+    pack_combine_name = fused_stage.name
+    pack_pcoll_id = unique_name(context.components.pcollections, 'pcollection')
+    input_pcoll = context.components.pcollections[input_pcoll_id]
+    context.components.pcollections[pack_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=pack_combine_name + '.out',
+            coder_id=pack_output_kv_coder_id,
+            windowing_strategy_id=input_pcoll.windowing_strategy_id,
+            is_bounded=input_pcoll.is_bounded))
+
+    # Set up Pack stage.
+    pack_combine_fn = combiners.SingleInputTupleCombineFn(*[
+        core.CombineFn.from_runner_api(combine_payload.combine_fn, context)
+        for combine_payload in combine_payloads
+    ]).to_runner_api(context)
+    pack_transform = beam_runner_api_pb2.PTransform(
+        unique_name=pack_combine_name + '/Pack',
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.composites.COMBINE_PER_KEY.urn,
+            payload=beam_runner_api_pb2.CombinePayload(
+                combine_fn=pack_combine_fn,
+                accumulator_coder_id=tuple_accumulator_coder_id)
+            .SerializeToString()),
+        inputs={'in': input_pcoll_id},
+        outputs={'out': pack_pcoll_id},
+        environment_id=merged_transform_environment_id)
+    pack_stage = Stage(
+        pack_combine_name + '/Pack', [pack_transform],
+        downstream_side_inputs=fused_stage.downstream_side_inputs,
+        must_follow=fused_stage.must_follow,
+        parent=fused_stage,
+        environment=fused_stage.environment)
+    yield pack_stage
+
+    # Set up Unpack stage
+    tags = [str(i) for i in range(len(output_pcoll_ids))]
+    pickled_do_fn_data = pickler.dumps((_UnpackFn(tags), (), {}, [], None))
+    unpack_transform = beam_runner_api_pb2.PTransform(
+        unique_name=pack_combine_name + '/Unpack',
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.primitives.PAR_DO.urn,
+            payload=beam_runner_api_pb2.ParDoPayload(
+                do_fn=beam_runner_api_pb2.FunctionSpec(
+                    urn=python_urns.PICKLED_DOFN_INFO,
+                    payload=pickled_do_fn_data)).SerializeToString()),
+        inputs={'in': pack_pcoll_id},
+        outputs=dict(zip(tags, output_pcoll_ids)),
+        environment_id=merged_transform_environment_id)

Review comment:
       Similarly, this environment must be a Python environment for the pickled Unpack DoFn to work.

##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -685,6 +687,264 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def eliminate_common_siblings(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
+  """Runs common subexpression elimination for common siblings.
+
+  If stages have common input, an identical transform, and one output each,
+  then all but one stages will be eliminated, and the output of the remaining
+  will be connected to the original output PCollections of the eliminated
+  stages. This elimination runs only once, not recursively, and will only
+  eliminate the first stage after a common input, rather than a chain of
+  stages.
+  """
+
+  SiblingKey = collections.namedtuple(
+      'SiblingKey', ['spec_urn', 'spec_payload', 'inputs', 'environment_id'])
+
+  def get_sibling_key(transform):
+    """Returns a key that will be identical for common siblings."""
+    transform_output_keys = list(transform.outputs.keys())
+    # Return None as the sibling key for ineligible transforms.
+    if len(transform_output_keys
+          ) != 1 or transform.spec.urn != common_urns.primitives.PAR_DO.urn:
+      return None
+    return SiblingKey(
+        spec_urn=transform.spec.urn,
+        spec_payload=transform.spec.payload,
+        inputs=tuple(transform.inputs.items()),
+        environment_id=transform.environment_id)
+
+  # Group stages by keys.
+  stages_by_sibling_key = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    stages_by_sibling_key[get_sibling_key(transform)].append(stage)
+
+  # Eliminate stages and build the output PCollection remapping dictionary.
+  pcoll_id_remap = {}
+  for sibling_key, sibling_stages in stages_by_sibling_key.items():
+    if sibling_key is None or len(sibling_stages) == 1:
+      continue
+    output_pcoll_ids = [
+        only_element(stage.transforms[0].outputs.values())
+        for stage in sibling_stages
+    ]
+    to_delete_pcoll_ids = output_pcoll_ids[1:]
+    for to_delete_pcoll_id in to_delete_pcoll_ids:
+      pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
+      del context.components.pcollections[to_delete_pcoll_id]
+    del sibling_stages[1:]
+
+  # Yield stages while remapping output PCollections if needed.
+  for sibling_key, sibling_stages in stages_by_sibling_key.items():
+    for stage in sibling_stages:
+      input_keys_to_remap = []
+      for input_key, input_pcoll_id in stage.transforms[0].inputs.items():
+        if input_pcoll_id in pcoll_id_remap:
+          input_keys_to_remap.append(input_key)
+      for input_key_to_remap in input_keys_to_remap:
+        stage.transforms[0].inputs[input_key_to_remap] = pcoll_id_remap[
+            stage.transforms[0].inputs[input_key_to_remap]]
+      yield stage
+
+
+def pack_combiners(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
+  """Packs sibling CombinePerKey stages into a single CombinePerKey.
+
+  If CombinePerKey stages have a common input, one input each, and one output
+  each, pack the stages into a single stage that runs all CombinePerKeys and
+  outputs resulting tuples to a new PCollection. A subsequent stage unpacks
+  tuples from this PCollection and sends them to the original output
+  PCollections.
+  """
+
+  class _UnpackFn(core.DoFn):
+    """A DoFn that unpacks a packed to multiple tagged outputs.
+
+    Example:
+      tags = (T1, T2, ...)
+      input = (K, (V1, V2, ...))
+      output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V1)), ...
+    """
+
+    def __init__(self, tags):
+      self._tags = tags
+
+    def process(self, element):
+      key, values = element
+      return [
+          core.pvalue.TaggedOutput(tag, (key, value))
+          for tag, value in zip(self._tags, values)
+      ]
+
+  def _get_fallback_coder_id():
+    return context.add_or_get_coder_id(
+        coders.registry.get_coder(object).to_runner_api(None))
+
+  def _get_component_coder_id_from_kv_coder(coder, index):
+    assert index < 2
+    if coder.spec.urn == common_urns.coders.KV.urn and len(
+        coder.component_coder_ids) == 2:
+      return coder.component_coder_ids[index]
+    return _get_fallback_coder_id()
+
+  def _get_key_coder_id_from_kv_coder(coder):
+    return _get_component_coder_id_from_kv_coder(coder, 0)
+
+  def _get_value_coder_id_from_kv_coder(coder):
+    return _get_component_coder_id_from_kv_coder(coder, 1)
+
+  def _try_fuse_stages(a, b):
+    if a.can_fuse(b, context):
+      return a.fuse(b)
+    else:
+      raise ValueError
+
+  def _try_merge_environments(env1, env2):
+    if env1 is None:
+      return env2
+    elif env2 is None:
+      return env1
+    else:
+      if env1 != env2:
+        raise ValueError
+      return env1
+
+  # Group stages by parent, yielding ineligible stages.
+  combine_stages_by_input_pcoll_id = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and 
len(
+        transform.inputs) == 1 and len(transform.outputs) == 1:
+      input_pcoll_id = only_element(transform.inputs.values())
+      combine_stages_by_input_pcoll_id[input_pcoll_id].append(stage)
+    else:
+      yield stage
+
+  for input_pcoll_id, packable_stages in 
combine_stages_by_input_pcoll_id.items(
+  ):
+    # Yield stage and continue if it has no siblings.
+    if len(packable_stages) == 1:
+      yield packable_stages[0]
+      continue
+
+    transforms = [only_transform(stage.transforms) for stage in 
packable_stages]
+
+    # Yield stages and continue if they cannot be packed.
+    try:
+      # Fused stage is used as template and is not yielded.
+      fused_stage = functools.reduce(_try_fuse_stages, packable_stages)
+      merged_transform_environment_id = functools.reduce(
+          _try_merge_environments,
+          [transform.environment_id or None for transform in transforms])
+    except ValueError:
+      for stage in packable_stages:
+        yield stage
+      continue
+
+    output_pcoll_ids = [
+        only_element(transform.outputs.values()) for transform in transforms
+    ]
+    combine_payloads = [
+        proto_utils.parse_Bytes(transform.spec.payload,
+                                beam_runner_api_pb2.CombinePayload)
+        for transform in transforms
+    ]
+
+    # Build accumulator coder for (acc1, acc2, ...)
+    accumulator_coder_ids = [
+        combine_payload.accumulator_coder_id
+        for combine_payload in combine_payloads
+    ]
+    tuple_accumulator_coder_id = context.add_or_get_coder_id(
+        beam_runner_api_pb2.Coder(
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.coders.KV.urn),
+            component_coder_ids=accumulator_coder_ids))
+
+    # Build packed output coder for (key, (out1, out2, ...))
+    input_kv_coder_id = 
context.components.pcollections[input_pcoll_id].coder_id
+    key_coder_id = _get_key_coder_id_from_kv_coder(
+        context.components.coders[input_kv_coder_id])
+    output_kv_coder_ids = [
+        context.components.pcollections[output_pcoll_id].coder_id
+        for output_pcoll_id in output_pcoll_ids
+    ]
+    output_value_coder_ids = [
+        _get_value_coder_id_from_kv_coder(
+            context.components.coders[output_kv_coder_id])
+        for output_kv_coder_id in output_kv_coder_ids
+    ]
+    pack_output_value_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
+        component_coder_ids=output_value_coder_ids)
+    pack_output_value_coder_id = context.add_or_get_coder_id(
+        pack_output_value_coder)
+    pack_output_kv_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
+        component_coder_ids=[key_coder_id, pack_output_value_coder_id])
+    pack_output_kv_coder_id = context.add_or_get_coder_id(pack_output_kv_coder)
+
+    # Set up packed PCollection
+    pack_combine_name = fused_stage.name
+    pack_pcoll_id = unique_name(context.components.pcollections, 'pcollection')
+    input_pcoll = context.components.pcollections[input_pcoll_id]
+    context.components.pcollections[pack_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=pack_combine_name + '.out',
+            coder_id=pack_output_kv_coder_id,
+            windowing_strategy_id=input_pcoll.windowing_strategy_id,
+            is_bounded=input_pcoll.is_bounded))
+
+    # Set up Pack stage.
+    pack_combine_fn = combiners.SingleInputTupleCombineFn(*[
+        core.CombineFn.from_runner_api(combine_payload.combine_fn, context)

Review comment:
       This only works for Python CombineFns, and should be guarded so that non-Python CombineFns are not packed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to