[ 
https://issues.apache.org/jira/browse/BEAM-3914?focusedWorklogId=95324&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-95324
 ]

ASF GitHub Bot logged work on BEAM-3914:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 26/Apr/18 02:04
            Start Date: 26/Apr/18 02:04
    Worklog Time Spent: 10m 
      Work Description: tgroh closed pull request #4977: [BEAM-3914] 
Deduplicate Unzipped Flattens after Pipeline Fusion
URL: https://github.com/apache/beam/pull/4977
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a foreign pull request is merged, the diff is
reproduced verbatim below for the sake of provenance:

diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/FusedPipeline.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/FusedPipeline.java
index 68da5c3961b..ddc03355a90 100644
--- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/FusedPipeline.java
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/FusedPipeline.java
@@ -98,28 +98,16 @@ static FusedPipeline of(
   private Map<String, PTransform> getEnvironmentExecutedTransforms() {
     Map<String, PTransform> topLevelTransforms = new HashMap<>();
     for (ExecutableStage stage : getFusedStages()) {
+      String baseName =
+          String.format(
+              "%s/%s",
+              stage.getInputPCollection().getPCollection().getUniqueName(),
+              stage.getEnvironment().getUrl());
+      Set<String> usedNames =
+          Sets.union(topLevelTransforms.keySet(), 
getComponents().getTransformsMap().keySet());
       topLevelTransforms.put(
-          generateStageId(
-              stage,
-              Sets.union(getComponents().getTransformsMap().keySet(), 
topLevelTransforms.keySet())),
-          stage.toPTransform());
+          SyntheticNodes.uniqueId(baseName, usedNames::contains), 
stage.toPTransform());
     }
     return topLevelTransforms;
   }
-
-  private String generateStageId(ExecutableStage stage, Set<String> 
existingIds) {
-    int i = 0;
-    String name;
-    do {
-      // Instead this could include the name of the root transforms
-      name =
-          String.format(
-              "%s/%s.%s",
-              stage.getInputPCollection().getPCollection().getUniqueName(),
-              stage.getEnvironment().getUrl(),
-              i);
-      i++;
-    } while (existingIds.contains(name));
-    return name;
-  }
 }
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
index abbba2a91a9..aaa44b67c00 100644
--- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
@@ -27,6 +27,7 @@
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Sets;
 import java.util.ArrayDeque;
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -42,6 +43,7 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import 
org.apache.beam.runners.core.construction.graph.OutputDeduplicator.DeduplicationResult;
 import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
 import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
 import org.slf4j.Logger;
@@ -54,19 +56,19 @@
   private static final Logger LOG = 
LoggerFactory.getLogger(GreedyPipelineFuser.class);
 
   private final QueryablePipeline pipeline;
-  private final Map<CollectionConsumer, ExecutableStage> 
consumedCollectionsAndTransforms =
-      new HashMap<>();
-  private final Set<PTransformNode> unfusedTransforms = new LinkedHashSet<>();
-  private final Set<ExecutableStage> stages = new LinkedHashSet<>();
+  private final FusedPipeline fusedPipeline;
 
   private GreedyPipelineFuser(Pipeline p) {
     this.pipeline = QueryablePipeline.forPrimitivesIn(p.getComponents());
+    Set<PTransformNode> unfusedRootNodes = new LinkedHashSet<>();
     NavigableSet<CollectionConsumer> rootConsumers = new TreeSet<>();
     for (PTransformNode pTransformNode : pipeline.getRootTransforms()) {
       // This will usually be a single node, the downstream of an Impulse, but 
may be of any size
-      rootConsumers.addAll(getRootEnvTransforms(pTransformNode));
+      DescendantConsumers descendants = getRootConsumers(pTransformNode);
+      unfusedRootNodes.addAll(descendants.getUnfusedNodes());
+      rootConsumers.addAll(descendants.getFusibleConsumers());
     }
-    fusePipeline(groupSiblings(rootConsumers));
+    this.fusedPipeline = fusePipeline(unfusedRootNodes, 
groupSiblings(rootConsumers));
   }
 
   /**
@@ -79,8 +81,7 @@ private GreedyPipelineFuser(Pipeline p) {
    * bounded pipelines using the Read primitive.
    */
   public static FusedPipeline fuse(Pipeline p) {
-    GreedyPipelineFuser fuser = new GreedyPipelineFuser(p);
-    return FusedPipeline.of(p.getComponents(), fuser.stages, 
fuser.unfusedTransforms);
+    return new GreedyPipelineFuser(p).fusedPipeline;
   }
 
   /**
@@ -102,17 +103,22 @@ public static FusedPipeline fuse(Pipeline p) {
    *       {@link PTransformNode} may only be present in a single stage rooted 
at a single {@link
    *       PCollectionNode}, otherwise it will process elements of that {@link 
PCollectionNode}
    *       multiple times.
-   *   <li>Create a {@link GreedyStageFuser} with those siblings as the initial
-   *       consuming transforms of the stage
+   *   <li>Create a {@link GreedyStageFuser} with those siblings as the 
initial consuming transforms
+   *       of the stage
    *   <li>For each materialized {@link PCollectionNode}, find all of the 
descendant in-environment
-   *       consumers. See {@link 
#getDescendantConsumersInEnv(PCollectionNode)} for details.
+   *       consumers. See {@link #getDescendantConsumers(PCollectionNode)} for 
details.
    *   <li>Construct all of the sibling sets from the descendant 
in-environment consumers, and add
    *       them to the queue of sibling sets.
    * </ul>
    */
-  private void fusePipeline(NavigableSet<NavigableSet<CollectionConsumer>> 
initialConsumers) {
-    Queue<Set<CollectionConsumer>> pendingSiblingSets = new ArrayDeque<>();
-    pendingSiblingSets.addAll(initialConsumers);
+  private FusedPipeline fusePipeline(
+      Collection<PTransformNode> initialUnfusedTransforms,
+      NavigableSet<NavigableSet<CollectionConsumer>> initialConsumers) {
+    Map<CollectionConsumer, ExecutableStage> consumedCollectionsAndTransforms 
= new HashMap<>();
+    Set<ExecutableStage> stages = new LinkedHashSet<>();
+    Set<PTransformNode> unfusedTransforms = new 
LinkedHashSet<>(initialUnfusedTransforms);
+
+    Queue<Set<CollectionConsumer>> pendingSiblingSets = new 
ArrayDeque<>(initialConsumers);
     while (!pendingSiblingSets.isEmpty()) {
       // Only introduce new PCollection consumers. Not performing this 
introduces potential
       // duplicate paths through the pipeline.
@@ -139,21 +145,39 @@ private void 
fusePipeline(NavigableSet<NavigableSet<CollectionConsumer>> initial
       for (PCollectionNode materializedOutput : stage.getOutputPCollections()) 
{
         // Get all of the descendant consumers of each materialized 
PCollection, and add them to the
         // queue of pending siblings.
-        NavigableSet<CollectionConsumer> materializedConsumers =
-            getDescendantConsumersInEnv(materializedOutput);
+        DescendantConsumers descendantConsumers = 
getDescendantConsumers(materializedOutput);
+        unfusedTransforms.addAll(descendantConsumers.getUnfusedNodes());
         NavigableSet<NavigableSet<CollectionConsumer>> siblings =
-            groupSiblings(materializedConsumers);
+            groupSiblings(descendantConsumers.getFusibleConsumers());
 
         pendingSiblingSets.addAll(siblings);
       }
     }
+    // TODO: Figure out where to store this.
+    DeduplicationResult deduplicated =
+        OutputDeduplicator.ensureSingleProducer(pipeline, stages, 
unfusedTransforms);
     // TODO: Stages can be fused with each other, if doing so does not 
introduce duplicate paths
     // for an element to take through the Pipeline. Compatible siblings can 
generally be fused,
     // as can compatible producers/consumers if a PCollection is only 
materialized once.
+    return FusedPipeline.of(
+        deduplicated.getDeduplicatedComponents(),
+        stages
+            .stream()
+            .map(stage -> 
deduplicated.getDeduplicatedStages().getOrDefault(stage, stage))
+            .collect(Collectors.toSet()),
+        Sets.union(
+            deduplicated.getIntroducedTransforms(),
+            unfusedTransforms
+                .stream()
+                .map(
+                    transform ->
+                        deduplicated
+                            .getDeduplicatedTransforms()
+                            .getOrDefault(transform.getId(), transform))
+                .collect(Collectors.toSet())));
   }
 
-  private Set<CollectionConsumer> getRootEnvTransforms(
-      PTransformNode rootNode) {
+  private DescendantConsumers getRootConsumers(PTransformNode rootNode) {
     checkArgument(
         rootNode.getTransform().getInputsCount() == 0,
         "%s is not at the root of the graph (consumes %s)",
@@ -164,13 +188,16 @@ private void 
fusePipeline(NavigableSet<NavigableSet<CollectionConsumer>> initial
         "%s requires all root nodes to be runner-implemented %s primitives",
         GreedyPipelineFuser.class.getSimpleName(),
         PTransformTranslation.IMPULSE_TRANSFORM_URN);
-    unfusedTransforms.add(rootNode);
-    Set<CollectionConsumer> environmentNodes = new HashSet<>();
+    Set<PTransformNode> unfused = new HashSet<>();
+    unfused.add(rootNode);
+    NavigableSet<CollectionConsumer> environmentNodes = new TreeSet<>();
     // Walk down until the first environments are found, and fuse them as 
appropriate.
     for (PCollectionNode output : pipeline.getOutputPCollections(rootNode)) {
-      environmentNodes.addAll(getDescendantConsumersInEnv(output));
+      DescendantConsumers descendants = getDescendantConsumers(output);
+      unfused.addAll(descendants.getUnfusedNodes());
+      environmentNodes.addAll(descendants.getFusibleConsumers());
     }
-    return environmentNodes;
+    return DescendantConsumers.of(unfused, environmentNodes);
   }
 
   /**
@@ -197,8 +224,8 @@ private void 
fusePipeline(NavigableSet<NavigableSet<CollectionConsumer>> initial
    * reachable only via a path including that node as an intermediate node 
cannot be returned as a
    * descendant consumer of the original {@link PCollectionNode}.
    */
-  private NavigableSet<CollectionConsumer> getDescendantConsumersInEnv(
-      PCollectionNode inputPCollection) {
+  private DescendantConsumers getDescendantConsumers(PCollectionNode 
inputPCollection) {
+    Set<PTransformNode> unfused = new HashSet<>();
     NavigableSet<CollectionConsumer> downstreamConsumers = new TreeSet<>();
     for (PTransformNode consumer : 
pipeline.getPerElementConsumers(inputPCollection)) {
       if (pipeline.getEnvironment(consumer).isPresent()) {
@@ -209,14 +236,28 @@ private void 
fusePipeline(NavigableSet<NavigableSet<CollectionConsumer>> initial
             "Adding {} {} to the set of runner-executed transforms",
             PTransformNode.class.getSimpleName(),
             consumer.getId());
-        unfusedTransforms.add(consumer);
+        unfused.add(consumer);
         for (PCollectionNode output : 
pipeline.getOutputPCollections(consumer)) {
           // Recurse to all of the ouput PCollections of this PTransform.
-          downstreamConsumers.addAll(getDescendantConsumersInEnv(output));
+          DescendantConsumers descendants = getDescendantConsumers(output);
+          unfused.addAll(descendants.getUnfusedNodes());
+          downstreamConsumers.addAll(descendants.getFusibleConsumers());
         }
       }
     }
-    return downstreamConsumers;
+    return DescendantConsumers.of(unfused, downstreamConsumers);
+  }
+
+  @AutoValue
+  abstract static class DescendantConsumers {
+    static DescendantConsumers of(
+        Set<PTransformNode> unfusible, NavigableSet<CollectionConsumer> 
fusible) {
+      return new AutoValue_GreedyPipelineFuser_DescendantConsumers(unfusible, 
fusible);
+    }
+
+    abstract Set<PTransformNode> getUnfusedNodes();
+
+    abstract NavigableSet<CollectionConsumer> getFusibleConsumers();
   }
 
   /**
@@ -228,6 +269,7 @@ private void 
fusePipeline(NavigableSet<NavigableSet<CollectionConsumer>> initial
   @AutoValue
   abstract static class SiblingKey {
     abstract PCollectionNode getInputCollection();
+
     abstract Environment getEnv();
   }
 
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java
new file mode 100644
index 00000000000..4419787ede1
--- /dev/null
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction.graph;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
+import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
+
+/**
+ * Utilities to insert synthetic {@link PCollectionNode PCollections} for 
{@link PCollection
+ * PCollections} which are produced by multiple independently executable 
stages.
+ */
+class OutputDeduplicator {
+
+  /**
+   * Ensure that no {@link PCollection} output by any of the {@code stages} or 
{@code
+   * unfusedTransforms} is produced by more than one of those stages or 
transforms.
+   *
+   * <p>For each {@link PCollection} output by multiple stages and/or 
transforms, each producer is
+   * rewritten to produce a partial {@link PCollection}, which are then 
flattened together via an
+   * introduced Flatten node which produces the original output.
+   */
+  static DeduplicationResult ensureSingleProducer(
+      QueryablePipeline pipeline,
+      Collection<ExecutableStage> stages,
+      Collection<PTransformNode> unfusedTransforms) {
+    RunnerApi.Components.Builder unzippedComponents = 
pipeline.getComponents().toBuilder();
+
+    Multimap<PCollectionNode, StageOrTransform> pcollectionProducers =
+        getProducers(pipeline, stages, unfusedTransforms);
+    Multimap<StageOrTransform, PCollectionNode> requiresNewOutput = 
HashMultimap.create();
+    // Create a synthetic PCollection for each of these nodes. The transforms 
in the runner
+    // portion of the graph that creates them should be replaced in the result 
components. The
+    // ExecutableStage must also be rewritten to have updated outputs and 
transforms.
+    for (Map.Entry<PCollectionNode, Collection<StageOrTransform>> 
collectionProducer :
+        pcollectionProducers.asMap().entrySet()) {
+      if (collectionProducer.getValue().size() > 1) {
+        for (StageOrTransform producer : collectionProducer.getValue()) {
+          requiresNewOutput.put(producer, collectionProducer.getKey());
+        }
+      }
+    }
+
+    Map<ExecutableStage, ExecutableStage> updatedStages = new 
LinkedHashMap<>();
+    Map<String, PTransformNode> updatedTransforms = new LinkedHashMap<>();
+    Multimap<String, PCollectionNode> originalToPartial = 
HashMultimap.create();
+    for (Map.Entry<StageOrTransform, Collection<PCollectionNode>> 
deduplicationTargets :
+        requiresNewOutput.asMap().entrySet()) {
+      if (deduplicationTargets.getKey().getStage() != null) {
+        StageDeduplication deduplication =
+            deduplicatePCollections(
+                deduplicationTargets.getKey().getStage(),
+                deduplicationTargets.getValue(),
+                unzippedComponents::containsPcollections);
+        for (Entry<String, PCollectionNode> originalToPartialReplacement :
+            deduplication.getOriginalToPartialPCollections().entrySet()) {
+          originalToPartial.put(
+              originalToPartialReplacement.getKey(), 
originalToPartialReplacement.getValue());
+          unzippedComponents.putPcollections(
+              originalToPartialReplacement.getValue().getId(),
+              originalToPartialReplacement.getValue().getPCollection());
+        }
+        updatedStages.put(
+            deduplicationTargets.getKey().getStage(), 
deduplication.getUpdatedStage());
+      } else if (deduplicationTargets.getKey().getTransform() != null) {
+        PTransformDeduplication deduplication =
+            deduplicatePCollections(
+                deduplicationTargets.getKey().getTransform(),
+                deduplicationTargets.getValue(),
+                unzippedComponents::containsPcollections);
+        for (Entry<String, PCollectionNode> originalToPartialReplacement :
+            deduplication.getOriginalToPartialPCollections().entrySet()) {
+          originalToPartial.put(
+              originalToPartialReplacement.getKey(), 
originalToPartialReplacement.getValue());
+          unzippedComponents.putPcollections(
+              originalToPartialReplacement.getValue().getId(),
+              originalToPartialReplacement.getValue().getPCollection());
+        }
+        updatedTransforms.put(
+            deduplicationTargets.getKey().getTransform().getId(),
+            deduplication.getUpdatedTransform());
+      } else {
+        throw new IllegalStateException(
+            String.format(
+                "%s with no %s or %s",
+                StageOrTransform.class.getSimpleName(),
+                ExecutableStage.class.getSimpleName(),
+                PTransformNode.class.getSimpleName()));
+      }
+    }
+
+    Set<PTransformNode> introducedFlattens = new LinkedHashSet<>();
+    for (Map.Entry<String, Collection<PCollectionNode>> partialFlattenTargets :
+        originalToPartial.asMap().entrySet()) {
+      PTransform flattenPartialPCollections =
+          createFlattenOfPartials(partialFlattenTargets.getKey(), 
partialFlattenTargets.getValue());
+      String flattenId =
+          SyntheticNodes.uniqueId("unzipped_flatten", 
unzippedComponents::containsTransforms);
+      unzippedComponents.putTransforms(flattenId, flattenPartialPCollections);
+      introducedFlattens.add(PipelineNode.pTransform(flattenId, 
flattenPartialPCollections));
+    }
+
+    Components components = unzippedComponents.build();
+    return DeduplicationResult.of(components, introducedFlattens, 
updatedStages, updatedTransforms);
+  }
+
+  @AutoValue
+  abstract static class DeduplicationResult {
+    private static DeduplicationResult of(
+        RunnerApi.Components components,
+        Set<PTransformNode> introducedTransforms,
+        Map<ExecutableStage, ExecutableStage> stages,
+        Map<String, PTransformNode> unfused) {
+      return new AutoValue_OutputDeduplicator_DeduplicationResult(
+          components, introducedTransforms, stages, unfused);
+    }
+
+    abstract RunnerApi.Components getDeduplicatedComponents();
+
+    abstract Set<PTransformNode> getIntroducedTransforms();
+
+    abstract Map<ExecutableStage, ExecutableStage> getDeduplicatedStages();
+
+    abstract Map<String, PTransformNode> getDeduplicatedTransforms();
+  }
+
+  private static PTransform createFlattenOfPartials(
+      String outputId, Collection<PCollectionNode> generatedInputs) {
+    PTransform.Builder newFlattenBuilder = PTransform.newBuilder();
+    int i = 0;
+    for (PCollectionNode generatedInput : generatedInputs) {
+      String localInputId = String.format("input_%s", i);
+      i++;
+      newFlattenBuilder.putInputs(localInputId, generatedInput.getId());
+    }
+    // Flatten all of the new partial nodes together.
+    return newFlattenBuilder
+        .putOutputs("output", outputId)
+        
.setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN))
+        .build();
+  }
+
+  /**
+   * Returns the map from each {@link PCollectionNode} produced by any of the 
{@link ExecutableStage
+   * stages} or {@link PTransformNode transforms} to all of the {@link 
ExecutableStage stages} or
+   * {@link PTransformNode transforms} that produce it.
+   */
+  private static Multimap<PCollectionNode, StageOrTransform> getProducers(
+      QueryablePipeline pipeline,
+      Iterable<ExecutableStage> stages,
+      Iterable<PTransformNode> unfusedTransforms) {
+    Multimap<PCollectionNode, StageOrTransform> pcollectionProducers = 
HashMultimap.create();
+    for (ExecutableStage stage : stages) {
+      for (PCollectionNode output : stage.getOutputPCollections()) {
+        pcollectionProducers.put(output, StageOrTransform.stage(stage));
+      }
+    }
+    for (PTransformNode unfused : unfusedTransforms) {
+      for (PCollectionNode output : pipeline.getOutputPCollections(unfused)) {
+        pcollectionProducers.put(output, StageOrTransform.transform(unfused));
+      }
+    }
+    return pcollectionProducers;
+  }
+
+  private static PTransformDeduplication deduplicatePCollections(
+      PTransformNode transform,
+      Collection<PCollectionNode> duplicates,
+      Predicate<String> existingPCollectionIds) {
+    Map<String, PCollectionNode> unzippedOutputs =
+        createPartialPCollections(duplicates, existingPCollectionIds);
+    PTransform pTransform = updateOutputs(transform.getTransform(), 
unzippedOutputs);
+    return PTransformDeduplication.of(
+        PipelineNode.pTransform(transform.getId(), pTransform), 
unzippedOutputs);
+  }
+
+  @AutoValue
+  abstract static class PTransformDeduplication {
+    public static PTransformDeduplication of(
+        PTransformNode updatedTransform, Map<String, PCollectionNode> 
originalToPartial) {
+      return new AutoValue_OutputDeduplicator_PTransformDeduplication(
+          updatedTransform, originalToPartial);
+    }
+
+    abstract PTransformNode getUpdatedTransform();
+
+    abstract Map<String, PCollectionNode> getOriginalToPartialPCollections();
+  }
+
+  private static StageDeduplication deduplicatePCollections(
+      ExecutableStage stage,
+      Collection<PCollectionNode> duplicates,
+      Predicate<String> existingPCollectionIds) {
+    Map<String, PCollectionNode> unzippedOutputs =
+        createPartialPCollections(duplicates, existingPCollectionIds);
+    ExecutableStage updatedStage = deduplicateStageOutput(stage, 
unzippedOutputs);
+    return StageDeduplication.of(updatedStage, unzippedOutputs);
+  }
+
+  @AutoValue
+  abstract static class StageDeduplication {
+    public static StageDeduplication of(
+        ExecutableStage updatedStage, Map<String, PCollectionNode> 
originalToPartial) {
+      return new AutoValue_OutputDeduplicator_StageDeduplication(updatedStage, 
originalToPartial);
+    }
+
+    abstract ExecutableStage getUpdatedStage();
+
+    abstract Map<String, PCollectionNode> getOriginalToPartialPCollections();
+  }
+
+  /**
+   * Returns a {@link Map} from the ID of a {@link PCollectionNode 
PCollection} to a {@link
+   * PCollectionNode} that contains part of that {@link PCollectionNode 
PCollection}.
+   */
+  private static Map<String, PCollectionNode> createPartialPCollections(
+      Collection<PCollectionNode> duplicates, Predicate<String> 
existingPCollectionIds) {
+    Map<String, PCollectionNode> unzippedOutputs = new LinkedHashMap<>();
+    Predicate<String> existingOrNewIds =
+        existingPCollectionIds.or(
+            id ->
+                
unzippedOutputs.values().stream().map(PCollectionNode::getId).anyMatch(id::equals));
+    for (PCollectionNode duplicateOutput : duplicates) {
+      String id = SyntheticNodes.uniqueId(duplicateOutput.getId(), 
existingOrNewIds);
+      PCollection partial = 
duplicateOutput.getPCollection().toBuilder().setUniqueName(id).build();
+      // Check to make sure there is only one duplicated output with the same 
id - which ensures we
+      // only introduce one 'partial output' per producer of that output.
+      PCollectionNode alreadyDeduplicated =
+          unzippedOutputs.put(duplicateOutput.getId(), 
PipelineNode.pCollection(id, partial));
+      checkArgument(alreadyDeduplicated == null, "a duplicate should only 
appear once per stage");
+    }
+    return unzippedOutputs;
+  }
+
+  /**
+   * Returns an {@link ExecutableStage} where all of the {@link 
PCollectionNode PCollections}
+   * matching the original are replaced with the introduced partial {@link 
PCollection} in all
+   * references made within the {@link ExecutableStage}.
+   */
+  private static ExecutableStage deduplicateStageOutput(
+      ExecutableStage stage, Map<String, PCollectionNode> originalToPartial) {
+    Collection<PTransformNode> updatedTransforms = new ArrayList<>();
+    for (PTransformNode transform : stage.getTransforms()) {
+      PTransform updatedTransform = updateOutputs(transform.getTransform(), 
originalToPartial);
+      updatedTransforms.add(PipelineNode.pTransform(transform.getId(), 
updatedTransform));
+    }
+    Collection<PCollectionNode> updatedOutputs = new ArrayList<>();
+    for (PCollectionNode output : stage.getOutputPCollections()) {
+      updatedOutputs.add(originalToPartial.getOrDefault(output.getId(), 
output));
+    }
+    RunnerApi.Components updatedStageComponents =
+        stage
+            .getComponents()
+            .toBuilder()
+            .clearTransforms()
+            .putAllTransforms(
+                updatedTransforms
+                    .stream()
+                    .collect(Collectors.toMap(PTransformNode::getId, 
PTransformNode::getTransform)))
+            .putAllPcollections(
+                originalToPartial
+                    .values()
+                    .stream()
+                    .collect(
+                        Collectors.toMap(PCollectionNode::getId, 
PCollectionNode::getPCollection)))
+            .build();
+    return ImmutableExecutableStage.of(
+        updatedStageComponents,
+        stage.getEnvironment(),
+        stage.getInputPCollection(),
+        stage.getSideInputs(),
+        updatedTransforms,
+        updatedOutputs);
+  }
+
+  /**
+   * Returns a {@link PTransform} like the input {@link PTransform}, but with 
each output to {@code
+   * originalPCollection} replaced with an output (with the same local name) 
to {@code
+   * newPCollection}.
+   */
+  private static PTransform updateOutputs(
+      PTransform transform, Map<String, PCollectionNode> originalToPartial) {
+    PTransform.Builder updatedTransformBuilder = transform.toBuilder();
+    for (Map.Entry<String, String> output : 
transform.getOutputsMap().entrySet()) {
+      if (originalToPartial.containsKey(output.getValue())) {
+        updatedTransformBuilder.putOutputs(
+            output.getKey(), originalToPartial.get(output.getValue()).getId());
+      }
+    }
+    return updatedTransformBuilder.build();
+  }
+
+  @AutoValue
+  abstract static class StageOrTransform {
+    public static StageOrTransform stage(ExecutableStage stage) {
+      return new AutoValue_OutputDeduplicator_StageOrTransform(stage, null);
+    }
+
+    public static StageOrTransform transform(PTransformNode transform) {
+      return new AutoValue_OutputDeduplicator_StageOrTransform(null, 
transform);
+    }
+
+    @Nullable
+    abstract ExecutableStage getStage();
+
+    @Nullable
+    abstract PTransformNode getTransform();
+  }
+}
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SyntheticNodes.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SyntheticNodes.java
new file mode 100644
index 00000000000..fc2cb3dc562
--- /dev/null
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SyntheticNodes.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction.graph;
+
+import java.util.function.Predicate;
+
+/**
+ * A utility class to interact with synthetic {@link PipelineNode Pipeline 
Nodes}.
+ */
+class SyntheticNodes {
+  private SyntheticNodes() {}
+
+  /**
+   * Generate an ID which does not collide with any existing ID, as determined 
by the input
+   * predicate.
+   *
+   * <p>The returned ID will be in the form "${baseName}:${number}".
+   */
+  public static String uniqueId(String baseName, Predicate<String> 
existingIds) {
+    int i = 0;
+    String name;
+    do {
+      name = String.format("%s:%s", baseName, i);
+      i++;
+    } while (existingIds.test(name));
+    return name;
+  }
+}
diff --git 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageMatcher.java
 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageMatcher.java
index b072c9c83cc..aed8a9fc7ef 100644
--- 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageMatcher.java
+++ 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageMatcher.java
@@ -30,6 +30,7 @@
 import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
 import org.hamcrest.Description;
 import org.hamcrest.Matcher;
+import org.hamcrest.Matchers;
 import org.hamcrest.TypeSafeMatcher;
 
 /**
@@ -39,13 +40,13 @@
 public class ExecutableStageMatcher extends TypeSafeMatcher<ExecutableStage> {
   private final String inputPCollectionId;
   private final Collection<SideInputId> sideInputIds;
-  private final Collection<String> materializedPCollection;
+  private final Matcher<Iterable<? extends String>> materializedPCollection;
   private final Collection<String> fusedTransforms;
 
   private ExecutableStageMatcher(
       String inputPCollectionId,
       Collection<SideInputId> sideInputIds,
-      Collection<String> materializedPCollection,
+      Matcher<Iterable<? extends String>> materializedPCollection,
       Collection<String> fusedTransforms) {
     this.inputPCollectionId = inputPCollectionId;
     this.sideInputIds = sideInputIds;
@@ -55,7 +56,7 @@ private ExecutableStageMatcher(
 
   public static ExecutableStageMatcher withInput(String inputId) {
     return new ExecutableStageMatcher(
-        inputId, ImmutableList.of(), ImmutableList.of(), ImmutableList.of());
+        inputId, ImmutableList.of(), Matchers.emptyIterable(), 
ImmutableList.of());
   }
 
   public ExecutableStageMatcher withSideInputs(SideInputId... sideInputs) {
@@ -68,12 +69,28 @@ public ExecutableStageMatcher withSideInputs(SideInputId... 
sideInputs) {
 
   public ExecutableStageMatcher withNoOutputs() {
     return new ExecutableStageMatcher(
-        inputPCollectionId, sideInputIds, ImmutableList.of(), fusedTransforms);
+        inputPCollectionId, sideInputIds, Matchers.emptyIterable(), 
fusedTransforms);
+  }
+
+  public ExecutableStageMatcher withOutputs(Matcher<String>... pCollections) {
+    return new ExecutableStageMatcher(
+        inputPCollectionId,
+        sideInputIds,
+        Matchers.containsInAnyOrder(pCollections),
+        fusedTransforms);
+  }
+
+  public ExecutableStageMatcher withOutputs(Matcher<Iterable<? extends 
String>> pCollections) {
+    return new ExecutableStageMatcher(
+        inputPCollectionId, sideInputIds, pCollections, fusedTransforms);
   }
 
   public ExecutableStageMatcher withOutputs(String... pCollections) {
     return new ExecutableStageMatcher(
-        inputPCollectionId, sideInputIds, ImmutableList.copyOf(pCollections), 
fusedTransforms);
+        inputPCollectionId,
+        sideInputIds,
+        Matchers.containsInAnyOrder(pCollections),
+        fusedTransforms);
   }
 
   public ExecutableStageMatcher withTransforms(String... transforms) {
@@ -98,12 +115,11 @@ protected boolean matchesSafely(ExecutableStage item) {
                                 .setLocalName(ref.localName())
                                 .build())
                     .collect(Collectors.toSet()))
-        && containsInAnyOrder(materializedPCollection.toArray(new String[0]))
-            .matches(
-                item.getOutputPCollections()
-                    .stream()
-                    .map(PCollectionNode::getId)
-                    .collect(Collectors.toSet()))
+        && materializedPCollection.matches(
+            item.getOutputPCollections()
+                .stream()
+                .map(PCollectionNode::getId)
+                .collect(Collectors.toSet()))
         && containsInAnyOrder(fusedTransforms.toArray(new String[0]))
             .matches(
                 item.getTransforms()
@@ -121,7 +137,7 @@ public void describeTo(Description description) {
                 ExecutableStage.class.getSimpleName(), 
PCollection.class.getSimpleName()))
         .appendText(inputPCollectionId)
         .appendText(String.format(", output %ss ", 
PCollection.class.getSimpleName()))
-        .appendValueList("[", ", ", "]", materializedPCollection)
+        .appendDescriptionOf(materializedPCollection)
         .appendText(String.format(" and fused %ss ", 
PTransform.class.getSimpleName()))
         .appendValueList("[", ", ", "]", fusedTransforms);
   }
diff --git 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
index cb494cbfcbb..91dcd0d2ebb 100644
--- 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
+++ 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
@@ -18,12 +18,22 @@
 
 package org.apache.beam.runners.core.construction.graph;
 
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.emptyIterable;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItems;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 import static org.junit.Assert.assertThat;
 
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
@@ -38,6 +48,10 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.TimerSpec;
 import org.apache.beam.model.pipeline.v1.RunnerApi.WindowIntoPayload;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
+import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
+import org.hamcrest.Matchers;
+import org.hamcrest.core.AnyOf;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -242,7 +256,7 @@ public void 
singleEnvironmentAcrossGroupByKeyMultipleStages() {
 
     assertThat(
         fused.getRunnerExecutedTransforms(),
-        contains(
+        containsInAnyOrder(
             PipelineNode.pTransform("impulse", 
components.getTransformsOrThrow("impulse")),
             PipelineNode.pTransform("groupByKey", 
components.getTransformsOrThrow("groupByKey"))));
     assertThat(
@@ -342,8 +356,9 @@ public void multipleEnvironmentsBecomesMultipleStages() {
    * pyImpulse -> .out -> pyRead -> .out /                    -> pyParDo -> 
.out
    *
    * becomes
-   * (goImpulse.out) -> goRead -> goRead.out -> flatten -> (flatten.out)
-   * (pyImpulse.out) -> pyRead -> pyRead.out -> flatten -> (flatten.out)
+   * (goImpulse.out) -> goRead -> goRead.out -> flatten -> 
(flatten.out_synthetic0)
+   * (pyImpulse.out) -> pyRead -> pyRead.out -> flatten -> 
(flatten.out_synthetic1)
+   * flatten.out_synthetic0 & flatten.out_synthetic1 -> synthetic_flatten -> 
flatten.out
    * (flatten.out) -> goParDo
    * (flatten.out) -> pyParDo
    */
@@ -453,19 +468,39 @@ public void 
flattenWithHeterogenousInputsAndOutputsEntirelyMaterialized() {
     FusedPipeline fused =
         
GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build());
 
+    assertThat(fused.getRunnerExecutedTransforms(), hasSize(3));
     assertThat(
+        "The runner should include the impulses for both languages, plus an 
introduced flatten",
         fused.getRunnerExecutedTransforms(),
-        containsInAnyOrder(
+        hasItems(
             PipelineNode.pTransform("pyImpulse", 
components.getTransformsOrThrow("pyImpulse")),
             PipelineNode.pTransform("goImpulse", 
components.getTransformsOrThrow("goImpulse"))));
+
+    PTransformNode flattenNode = null;
+    for (PTransformNode runnerTransform : fused.getRunnerExecutedTransforms()) 
{
+      if 
(getOnlyElement(runnerTransform.getTransform().getOutputsMap().values())
+          .equals("flatten.out")) {
+        flattenNode = runnerTransform;
+      }
+    }
+
+    assertThat(flattenNode, not(nullValue()));
+    assertThat(
+        flattenNode.getTransform().getSpec().getUrn(),
+        equalTo(PTransformTranslation.FLATTEN_TRANSFORM_URN));
+    assertThat(new 
HashSet<>(flattenNode.getTransform().getInputsMap().values()), hasSize(2));
+
+    Collection<String> introducedOutputs = 
flattenNode.getTransform().getInputsMap().values();
+    AnyOf<String> anyIntroducedPCollection =
+        
anyOf(introducedOutputs.stream().map(Matchers::equalTo).collect(Collectors.toSet()));
     assertThat(
         fused.getFusedStages(),
         containsInAnyOrder(
             ExecutableStageMatcher.withInput("goImpulse.out")
-                .withOutputs("flatten.out")
+                .withOutputs(anyIntroducedPCollection)
                 .withTransforms("goRead", "flatten"),
             ExecutableStageMatcher.withInput("pyImpulse.out")
-                .withOutputs("flatten.out")
+                .withOutputs(anyIntroducedPCollection)
                 .withTransforms("pyRead", "flatten"),
             ExecutableStageMatcher.withInput("flatten.out")
                 .withNoOutputs()
@@ -473,6 +508,19 @@ public void 
flattenWithHeterogenousInputsAndOutputsEntirelyMaterialized() {
             ExecutableStageMatcher.withInput("flatten.out")
                 .withNoOutputs()
                 .withTransforms("pyParDo")));
+    Set<String> materializedStageOutputs =
+        fused
+            .getFusedStages()
+            .stream()
+            .flatMap(executableStage -> 
executableStage.getOutputPCollections().stream())
+            .map(PCollectionNode::getId)
+            .collect(Collectors.toSet());
+
+    assertThat(
+        "All materialized stage outputs should be flattened, and no more",
+        materializedStageOutputs,
+        containsInAnyOrder(
+            flattenNode.getTransform().getInputsMap().values().toArray(new 
String[0])));
   }
 
   /*
diff --git 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java
 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java
new file mode 100644
index 00000000000..ea5b3ae9969
--- /dev/null
+++ 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java
@@ -0,0 +1,509 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction.graph;
+
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasItems;
+import static org.hamcrest.Matchers.hasSize;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import 
org.apache.beam.runners.core.construction.graph.OutputDeduplicator.DeduplicationResult;
+import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
+import 
org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link OutputDeduplicator}. */
+@RunWith(JUnit4.class)
+public class OutputDeduplicatorTest {
  @Test
  public void unchangedWithNoDuplicates() {
    /* When all the PCollections are produced by only one transform or stage, the result should be
     * empty/identical to the input.
     *
     * Pipeline:
     *              /-> one -> .out \
     * red -> .out ->                -> blue -> .out
     *              \-> two -> .out /
     */
    // Proto definitions: each primitive transform and the single PCollection it produces.
    PCollection redOut = PCollection.newBuilder().setUniqueName("red.out").build();
    PTransform red = PTransform.newBuilder().putOutputs("out", redOut.getUniqueName()).build();
    PCollection oneOut = PCollection.newBuilder().setUniqueName("one.out").build();
    PTransform one =
        PTransform.newBuilder()
            .putInputs("in", redOut.getUniqueName())
            .putOutputs("out", oneOut.getUniqueName())
            .build();
    PCollection twoOut = PCollection.newBuilder().setUniqueName("two.out").build();
    PTransform two =
        PTransform.newBuilder()
            .putInputs("in", redOut.getUniqueName())
            .putOutputs("out", twoOut.getUniqueName())
            .build();
    PCollection blueOut = PCollection.newBuilder().setUniqueName("blue.out").build();
    PTransform blue =
        PTransform.newBuilder()
            .putInputs("one", oneOut.getUniqueName())
            .putInputs("two", twoOut.getUniqueName())
            .putOutputs("out", blueOut.getUniqueName())
            .build();
    RunnerApi.Components components =
        Components.newBuilder()
            .putTransforms("one", one)
            .putPcollections(oneOut.getUniqueName(), oneOut)
            .putTransforms("two", two)
            .putPcollections(twoOut.getUniqueName(), twoOut)
            .putTransforms("red", red)
            .putPcollections(redOut.getUniqueName(), redOut)
            .putTransforms("blue", blue)
            .putPcollections(blueOut.getUniqueName(), blueOut)
            .build();
    // Two fused stages, each consuming red.out and materializing a distinct output — no
    // PCollection has more than one producer, so nothing needs deduplication.
    ExecutableStage oneStage =
        ImmutableExecutableStage.of(
            components,
            Environment.getDefaultInstance(),
            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
            ImmutableList.of(),
            ImmutableList.of(PipelineNode.pTransform("one", one)),
            ImmutableList.of(PipelineNode.pCollection(oneOut.getUniqueName(), oneOut)));
    ExecutableStage twoStage =
        ImmutableExecutableStage.of(
            components,
            Environment.getDefaultInstance(),
            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
            ImmutableList.of(),
            ImmutableList.of(PipelineNode.pTransform("two", two)),
            ImmutableList.of(PipelineNode.pCollection(twoOut.getUniqueName(), twoOut)));
    PTransformNode redTransform = PipelineNode.pTransform("red", red);
    PTransformNode blueTransform = PipelineNode.pTransform("blue", blue);
    QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);
    DeduplicationResult result =
        OutputDeduplicator.ensureSingleProducer(
            pipeline,
            ImmutableList.of(oneStage, twoStage),
            ImmutableList.of(redTransform, blueTransform));

    // No duplicates: components pass through untouched and no stage, transform, or synthetic
    // flatten is produced by deduplication.
    assertThat(result.getDeduplicatedComponents(), equalTo(components));
    assertThat(result.getDeduplicatedStages().keySet(), empty());
    assertThat(result.getDeduplicatedTransforms().keySet(), empty());
    assertThat(result.getIntroducedTransforms(), empty());
  }
+
  @Test
  public void duplicateOverStages() {
    /* When multiple stages and a runner-executed transform produce a PCollection, all should be
     * replaced with synthetic flattens.
     * original graph:
     *             --> one -> .out \
     * red -> .out |                -> shared -> .out -> blue -> .out
     *             --> two -> .out /
     *
     * fused graph:
     *             --> [one -> .out -> shared ->] .out
     * red -> .out |                                   (shared.out) -> blue -> .out
     *             --> [two -> .out -> shared ->] .out
     *
     * deduplicated graph:
     *             --> [one -> .out -> shared ->] .out:0 \
     * red -> .out |                                      -> shared -> .out -> blue ->.out
     *             --> [two -> .out -> shared ->] .out:1 /
     */
    // Proto definitions: 'shared' consumes both one.out and two.out and produces shared.out.
    PCollection redOut = PCollection.newBuilder().setUniqueName("red.out").build();
    PTransform red = PTransform.newBuilder().putOutputs("out", redOut.getUniqueName()).build();
    PCollection oneOut = PCollection.newBuilder().setUniqueName("one.out").build();
    PTransform one =
        PTransform.newBuilder()
            .putInputs("in", redOut.getUniqueName())
            .putOutputs("out", oneOut.getUniqueName())
            .build();
    PCollection twoOut = PCollection.newBuilder().setUniqueName("two.out").build();
    PTransform two =
        PTransform.newBuilder()
            .putInputs("in", redOut.getUniqueName())
            .putOutputs("out", twoOut.getUniqueName())
            .build();
    PCollection sharedOut = PCollection.newBuilder().setUniqueName("shared.out").build();
    PTransform shared =
        PTransform.newBuilder()
            .putInputs("one", oneOut.getUniqueName())
            .putInputs("two", twoOut.getUniqueName())
            .putOutputs("shared", sharedOut.getUniqueName())
            .build();
    PCollection blueOut = PCollection.newBuilder().setUniqueName("blue.out").build();
    PTransform blue =
        PTransform.newBuilder()
            .putInputs("in", sharedOut.getUniqueName())
            .putOutputs("out", blueOut.getUniqueName())
            .build();
    RunnerApi.Components components =
        Components.newBuilder()
            .putTransforms("one", one)
            .putPcollections(oneOut.getUniqueName(), oneOut)
            .putTransforms("two", two)
            .putPcollections(twoOut.getUniqueName(), twoOut)
            .putTransforms("shared", shared)
            .putPcollections(sharedOut.getUniqueName(), sharedOut)
            .putTransforms("red", red)
            .putPcollections(redOut.getUniqueName(), redOut)
            .putTransforms("blue", blue)
            .putPcollections(blueOut.getUniqueName(), blueOut)
            .build();
    // Both fused stages contain the 'shared' transform and materialize the SAME PCollection
    // (shared.out), which is the duplication the deduplicator must resolve.
    ExecutableStage oneStage =
        ImmutableExecutableStage.of(
            components,
            Environment.getDefaultInstance(),
            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
            ImmutableList.of(),
            ImmutableList.of(
                PipelineNode.pTransform("one", one), PipelineNode.pTransform("shared", shared)),
            ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut)));
    ExecutableStage twoStage =
        ImmutableExecutableStage.of(
            components,
            Environment.getDefaultInstance(),
            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
            ImmutableList.of(),
            ImmutableList.of(
                PipelineNode.pTransform("two", two), PipelineNode.pTransform("shared", shared)),
            ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut)));
    PTransformNode redTransform = PipelineNode.pTransform("red", red);
    PTransformNode blueTransform = PipelineNode.pTransform("blue", blue);
    QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);
    DeduplicationResult result =
        OutputDeduplicator.ensureSingleProducer(
            pipeline,
            ImmutableList.of(oneStage, twoStage),
            ImmutableList.of(redTransform, blueTransform));

    // Exactly one synthetic flatten is introduced; its single output is the original shared.out.
    assertThat(result.getIntroducedTransforms(), hasSize(1));
    PTransformNode introduced = getOnlyElement(result.getIntroducedTransforms());
    assertThat(introduced.getTransform().getOutputsMap().size(), equalTo(1));
    assertThat(
        getOnlyElement(introduced.getTransform().getOutputsMap().values()),
        equalTo(sharedOut.getUniqueName()));

    // The flatten's inputs are newly-minted partial PCollections registered in the components.
    assertThat(
        result.getDeduplicatedComponents().getPcollectionsMap().keySet(),
        hasItems(introduced.getTransform().getInputsMap().values().toArray(new String[0])));

    // Both stages were rewritten; their materialized outputs are exactly the flatten's inputs.
    assertThat(result.getDeduplicatedStages().keySet(), hasSize(2));
    List<String> stageOutputs =
        result
            .getDeduplicatedStages()
            .values()
            .stream()
            .flatMap(stage -> stage.getOutputPCollections().stream().map(PCollectionNode::getId))
            .collect(Collectors.toList());
    assertThat(
        stageOutputs,
        containsInAnyOrder(introduced.getTransform().getInputsMap().values().toArray()));
    // No runner-executed transform produced shared.out, so none needed rewriting.
    assertThat(result.getDeduplicatedTransforms().keySet(), empty());

    assertThat(
        result.getDeduplicatedComponents().getPcollectionsMap().keySet(),
        hasItems(stageOutputs.toArray(new String[0])));
    assertThat(
        result.getDeduplicatedComponents().getTransformsMap(),
        hasEntry(introduced.getId(), introduced.getTransform()));
  }
+
  @Test
  public void duplicateOverStagesAndTransforms() {
    /* When both a stage and a runner-executed transform produce a PCollection, all should be
     * replaced with synthetic flattens.
     * original graph:
     *             --> one -> .out \
     * red -> .out |                -> shared -> .out
     *             --------------> /
     *
     * fused graph:
     *             --> [one -> .out -> shared ->] .out
     * red -> .out |
     *             ------------------> shared --> .out
     *
     * deduplicated graph:
     *             --> [one -> .out -> shared ->] .out:0 \
     * red -> .out |                                      -> shared -> .out
     *             -----------------> shared:0 -> .out:1 /
     */
    // Proto definitions: 'shared' consumes one.out and red.out directly.
    PCollection redOut = PCollection.newBuilder().setUniqueName("red.out").build();
    PTransform red = PTransform.newBuilder().putOutputs("out", redOut.getUniqueName()).build();
    PCollection oneOut = PCollection.newBuilder().setUniqueName("one.out").build();
    PTransform one =
        PTransform.newBuilder()
            .putInputs("in", redOut.getUniqueName())
            .putOutputs("out", oneOut.getUniqueName())
            .build();
    PCollection sharedOut = PCollection.newBuilder().setUniqueName("shared.out").build();
    PTransform shared =
        PTransform.newBuilder()
            .putInputs("one", oneOut.getUniqueName())
            .putInputs("red", redOut.getUniqueName())
            .putOutputs("shared", sharedOut.getUniqueName())
            .build();
    PCollection blueOut = PCollection.newBuilder().setUniqueName("blue.out").build();
    PTransform blue =
        PTransform.newBuilder()
            .putInputs("in", sharedOut.getUniqueName())
            .putOutputs("out", blueOut.getUniqueName())
            .build();
    RunnerApi.Components components =
        Components.newBuilder()
            .putTransforms("one", one)
            .putPcollections(oneOut.getUniqueName(), oneOut)
            .putTransforms("red", red)
            .putPcollections(redOut.getUniqueName(), redOut)
            .putTransforms("shared", shared)
            .putPcollections(sharedOut.getUniqueName(), sharedOut)
            .putTransforms("blue", blue)
            .putPcollections(blueOut.getUniqueName(), blueOut)
            .build();
    // shared.out is produced twice: once inside the fused stage, once by 'shared' run directly
    // by the runner — both producers must be rewritten.
    PTransformNode sharedTransform = PipelineNode.pTransform("shared", shared);
    ExecutableStage oneStage =
        ImmutableExecutableStage.of(
            components,
            Environment.getDefaultInstance(),
            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
            ImmutableList.of(),
            ImmutableList.of(PipelineNode.pTransform("one", one), sharedTransform),
            ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut)));
    PTransformNode redTransform = PipelineNode.pTransform("red", red);
    PTransformNode blueTransform = PipelineNode.pTransform("blue", blue);
    QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);
    DeduplicationResult result =
        OutputDeduplicator.ensureSingleProducer(
            pipeline,
            ImmutableList.of(oneStage),
            ImmutableList.of(redTransform, blueTransform, sharedTransform));

    // A single synthetic flatten restores shared.out as the unique merged output.
    assertThat(result.getIntroducedTransforms(), hasSize(1));
    PTransformNode introduced = getOnlyElement(result.getIntroducedTransforms());
    assertThat(introduced.getTransform().getOutputsMap().size(), equalTo(1));
    assertThat(
        getOnlyElement(introduced.getTransform().getOutputsMap().values()),
        equalTo(sharedOut.getUniqueName()));

    assertThat(
        result.getDeduplicatedComponents().getPcollectionsMap().keySet(),
        hasItems(introduced.getTransform().getInputsMap().values().toArray(new String[0])));

    // One stage and one runner transform ('shared') were rewritten to emit partial outputs.
    assertThat(result.getDeduplicatedStages().keySet(), hasSize(1));
    assertThat(result.getDeduplicatedTransforms().keySet(), containsInAnyOrder("shared"));

    // Collect the partial outputs from both rewritten producers; together they must be exactly
    // the inputs of the introduced flatten and be registered in the deduplicated components.
    List<String> introducedOutputs = new ArrayList<>();
    introducedOutputs.addAll(
        result.getDeduplicatedTransforms().get("shared").getTransform().getOutputsMap().values());
    introducedOutputs.addAll(
        result
            .getDeduplicatedStages()
            .get(oneStage)
            .getOutputPCollections()
            .stream()
            .map(PCollectionNode::getId)
            .collect(Collectors.toList()));
    assertThat(
        introduced.getTransform().getInputsMap().values(),
        containsInAnyOrder(introducedOutputs.toArray(new String[0])));
    assertThat(
        result.getDeduplicatedComponents().getPcollectionsMap().keySet(),
        hasItems(introducedOutputs.toArray(new String[0])));
    assertThat(
        result.getDeduplicatedComponents().getTransformsMap(),
        hasEntry(introduced.getId(), introduced.getTransform()));
  }
+
+  @Test
+  public void multipleDuplicatesInStages() {
+    /* A stage that produces multiple duplicates should have them all 
synthesized.
+     *
+     * Original Pipeline:
+     * red -> .out ---> one -> .out -----\
+     *             \                      -> shared.out
+     *              \--> two -> .out ----|
+     *               \                    -> otherShared -> .out
+     *                \-> three --> .out /
+     *
+     * Fused Pipeline:
+     *      -> .out [-> one -> .out -> shared -> .out] \
+     *     /                                            -> blue -> .out
+     *     |                        -> shared -> .out] /
+     * red -> .out [-> two -> .out |
+     *     |                        -> otherShared -> .out]
+     *     \
+     *      -> .out [-> three -> .out -> otherShared -> .out]
+     *
+     * Deduplicated Pipeline:
+     *           [-> one -> .out -> shared -> .out:0] --\
+     *           |                                       -> shared -> .out -> 
blue -> .out
+     *           |                 -> shared -> .out:1] /
+     * red -> .out [-> two -> .out |
+     *           |                  -> otherShared -> .out:0] --\
+     *           |                                               -> 
otherShared -> .out
+     *           [-> three -> .out -> otherShared -> .out:1] ---/
+     */
+    PCollection redOut = 
PCollection.newBuilder().setUniqueName("red.out").build();
+    PTransform red = PTransform.newBuilder().putOutputs("out", 
redOut.getUniqueName()).build();
+    PCollection threeOut = 
PCollection.newBuilder().setUniqueName("three.out").build();
+    PTransform three =
+        PTransform.newBuilder()
+            .putInputs("in", redOut.getUniqueName())
+            .putOutputs("out", threeOut.getUniqueName())
+            .build();
+    PCollection oneOut = 
PCollection.newBuilder().setUniqueName("one.out").build();
+    PTransform one =
+        PTransform.newBuilder()
+            .putInputs("in", redOut.getUniqueName())
+            .putOutputs("out", oneOut.getUniqueName())
+            .build();
+    PCollection twoOut = 
PCollection.newBuilder().setUniqueName("two.out").build();
+    PTransform two =
+        PTransform.newBuilder()
+            .putInputs("in", redOut.getUniqueName())
+            .putOutputs("out", twoOut.getUniqueName())
+            .build();
+    PCollection sharedOut = 
PCollection.newBuilder().setUniqueName("shared.out").build();
+    PTransform shared =
+        PTransform.newBuilder()
+            .putInputs("one", oneOut.getUniqueName())
+            .putInputs("two", twoOut.getUniqueName())
+            .putOutputs("shared", sharedOut.getUniqueName())
+            .build();
+    PCollection otherSharedOut = 
PCollection.newBuilder().setUniqueName("shared.out2").build();
+    PTransform otherShared =
+        PTransform.newBuilder()
+            .putInputs("multi", threeOut.getUniqueName())
+            .putInputs("two", twoOut.getUniqueName())
+            .putOutputs("out", otherSharedOut.getUniqueName())
+            .build();
+    PCollection blueOut = 
PCollection.newBuilder().setUniqueName("blue.out").build();
+    PTransform blue =
+        PTransform.newBuilder()
+            .putInputs("in", sharedOut.getUniqueName())
+            .putOutputs("out", blueOut.getUniqueName())
+            .build();
+    RunnerApi.Components components =
+        Components.newBuilder()
+            .putTransforms("one", one)
+            .putPcollections(oneOut.getUniqueName(), oneOut)
+            .putTransforms("two", two)
+            .putPcollections(twoOut.getUniqueName(), twoOut)
+            .putTransforms("multi", three)
+            .putPcollections(threeOut.getUniqueName(), threeOut)
+            .putTransforms("shared", shared)
+            .putPcollections(sharedOut.getUniqueName(), sharedOut)
+            .putTransforms("otherShared", otherShared)
+            .putPcollections(otherSharedOut.getUniqueName(), otherSharedOut)
+            .putTransforms("red", red)
+            .putPcollections(redOut.getUniqueName(), redOut)
+            .putTransforms("blue", blue)
+            .putPcollections(blueOut.getUniqueName(), blueOut)
+            .build();
+    ExecutableStage multiStage =
+        ImmutableExecutableStage.of(
+            components,
+            Environment.getDefaultInstance(),
+            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
+            ImmutableList.of(),
+            ImmutableList.of(
+                PipelineNode.pTransform("multi", three),
+                PipelineNode.pTransform("shared", shared),
+                PipelineNode.pTransform("otherShared", otherShared)),
+            ImmutableList.of(
+                PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut),
+                PipelineNode.pCollection(otherSharedOut.getUniqueName(), 
otherSharedOut)));
+    ExecutableStage oneStage =
+        ImmutableExecutableStage.of(
+            components,
+            Environment.getDefaultInstance(),
+            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
+            ImmutableList.of(),
+            ImmutableList.of(
+                PipelineNode.pTransform("one", one), 
PipelineNode.pTransform("shared", shared)),
+            
ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), 
sharedOut)));
+    ExecutableStage twoStage =
+        ImmutableExecutableStage.of(
+            components,
+            Environment.getDefaultInstance(),
+            PipelineNode.pCollection(redOut.getUniqueName(), redOut),
+            ImmutableList.of(),
+            ImmutableList.of(
+                PipelineNode.pTransform("two", two),
+                PipelineNode.pTransform("otherShared", otherShared)),
+            ImmutableList.of(
+                PipelineNode.pCollection(otherSharedOut.getUniqueName(), 
otherSharedOut)));
+    PTransformNode redTransform = PipelineNode.pTransform("red", red);
+    PTransformNode blueTransform = PipelineNode.pTransform("blue", blue);
+    QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);
+    DeduplicationResult result =
+        OutputDeduplicator.ensureSingleProducer(
+            pipeline,
+            ImmutableList.of(oneStage, twoStage, multiStage),
+            ImmutableList.of(redTransform, blueTransform));
+
+    assertThat(result.getIntroducedTransforms(), hasSize(2));
+    assertThat(
+        result.getDeduplicatedStages().keySet(),
+        containsInAnyOrder(multiStage, oneStage, twoStage));
+    assertThat(result.getDeduplicatedTransforms().keySet(), empty());
+
+    Collection<String> introducedIds =
+        result
+            .getIntroducedTransforms()
+            .stream()
+            .flatMap(pt -> pt.getTransform().getInputsMap().values().stream())
+            .collect(Collectors.toList());
+    String[] stageOutputs =
+        result
+            .getDeduplicatedStages()
+            .values()
+            .stream()
+            .flatMap(s -> 
s.getOutputPCollections().stream().map(PCollectionNode::getId))
+            .toArray(String[]::new);
+    assertThat(introducedIds, containsInAnyOrder(stageOutputs));
+
+    assertThat(
+        result.getDeduplicatedComponents().getPcollectionsMap().keySet(),
+        hasItems(introducedIds.toArray(new String[0])));
+    assertThat(
+        result.getDeduplicatedComponents().getTransformsMap().entrySet(),
+        hasItems(
+            result
+                .getIntroducedTransforms()
+                .stream()
+                .collect(Collectors.toMap(PTransformNode::getId, 
PTransformNode::getTransform))
+                .entrySet()
+                .toArray(new Map.Entry[0])));
+  }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


Issue Time Tracking
-------------------

    Worklog Id:     (was: 95324)
    Time Spent: 4h  (was: 3h 50m)

> 'Unzip' flattens before performing fusion
> -----------------------------------------
>
>                 Key: BEAM-3914
>                 URL: https://issues.apache.org/jira/browse/BEAM-3914
>             Project: Beam
>          Issue Type: Improvement
>          Components: runner-core
>            Reporter: Thomas Groh
>            Assignee: Thomas Groh
>            Priority: Major
>              Labels: portability
>          Time Spent: 4h
>  Remaining Estimate: 0h
>
> This consists of duplicating nodes downstream of a flatten that exist within 
> an environment, and reintroducing the flatten immediately upstream of a 
> runner-executed transform (the flatten should be executed within the runner)



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to