yifanmai commented on a change in pull request #12787:
URL: https://github.com/apache/beam/pull/12787#discussion_r485943641



##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -704,6 +705,64 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def eliminate_common_key_with_none(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
+
+  """Runs common subexpression elimination for sibling KeyWithNone stages.
+
+  If multiple KeyWithNone stages share a common input, then all but one stages
+  will be eliminated along with their output PCollections. Transforms that
+  read input from the output of the eliminated KeyWithNone stages will be
+  remapped to read input from the output of the remaining KeyWithNone stage.
+  """
+  # Partition stages by whether they are eligible for common KeyWithNone
+  # elimination, and group eligible KeyWithNone stages by parent and
+  # environment.
+  grouped_eligible_stages = collections.defaultdict(list)
+  ineligible_stages = []
+  for stage in stages:
+    is_eligible = False
+    if len(stage.transforms) == 1:
+      transform = only_transform(stage.transforms)
+      if (transform.spec.urn == common_urns.primitives.PAR_DO.urn and
+          len(transform.inputs) == 1 and len(transform.outputs) == 1):
+        pardo_payload = proto_utils.parse_Bytes(
+            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+        if pardo_payload.do_fn.urn == python_urns.KEY_WITH_NONE_DOFN:
+          is_eligible = True
+
+    if is_eligible:
+      input_pcoll_id = only_element(transform.inputs.values())
+      stage_key = (input_pcoll_id, stage.environment)
+      grouped_eligible_stages[stage_key].append(stage)
+    else:
+      ineligible_stages.append(stage)
+
+  # Eliminate stages and build the PCollection remapping dictionary.
+  pcoll_id_remap = {}
+  for sibling_stages in grouped_eligible_stages.values():
+    output_pcoll_ids = [
+        only_element(stage.transforms[0].outputs.values())
+        for stage in sibling_stages
+    ]
+    for to_delete_pcoll_id in output_pcoll_ids[1:]:
+      pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
+      del context.components.pcollections[to_delete_pcoll_id]
+    del sibling_stages[1:]

Review comment:
       We need the loop below to remap the PCollections, so I made a new flat
list.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to