johnyangk closed pull request #142: [NEMO-257] Local combining only for 
BinaryCombineFn
URL: https://github.com/apache/incubator-nemo/pull/142
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, the diff is supplied below;
GitHub would not otherwise display it for a merged fork PR:

diff --git 
a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
 
b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
index 7dc7af65a..d596f9f1a 100644
--- 
a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
+++ 
b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
@@ -58,7 +58,6 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiFunction;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
 /**
  * Converts DAG of Beam root to Nemo IR DAG.
@@ -314,23 +313,12 @@ private static void topologicalTranslator(final 
TranslationContext ctx,
   private static void combineTranslator(final TranslationContext ctx,
                                         final CompositeTransformVertex 
transformVertex,
                                         final PTransform<?, ?> transform) {
-    // No optimization for BeamSQL that handles Beam 'Row's.
-    final boolean handlesBeamRow = Stream
-      .concat(transformVertex.getNode().getInputs().values().stream(),
-        transformVertex.getNode().getOutputs().values().stream())
-      .map(pValue -> (KvCoder) getCoder(pValue, ctx.root)) // Input and output 
of combine should be KV
-      .map(kvCoder -> kvCoder.getValueCoder().getEncodedTypeDescriptor()) // 
We're interested in the 'Value' of KV
-      .anyMatch(valueTypeDescriptor -> 
TypeDescriptor.of(Row.class).equals(valueTypeDescriptor));
-    if (handlesBeamRow) {
-      transformVertex.getDAG().topologicalDo(ctx::translate);
-      return; // return early and give up optimization - TODO #209: Enable 
Local Combiner for BeamSQL
-    }
-
-    // Local combiner optimization
     final List<TransformVertex> topologicalOrdering = 
transformVertex.getDAG().getTopologicalSort();
     final TransformVertex groupByKeyBeamTransform = topologicalOrdering.get(0);
     final TransformVertex last = 
topologicalOrdering.get(topologicalOrdering.size() - 1);
-    if (groupByKeyBeamTransform.getNode().getTransform() instanceof 
GroupByKey) {
+    if (groupByKeyBeamTransform.getNode().getTransform() instanceof GroupByKey
+      && isBinaryCombine(extractCombineFn(transform))) {
+      // Local combiner optimization (only for binary combiners that are 
guaranteed to be commutative/associative)
       // Translate the given CompositeTransform under OneToOneEdge-enforced 
context.
       final TranslationContext oneToOneEdgeContext = new 
TranslationContext(ctx,
           OneToOneCommunicationPatternSelector.INSTANCE);
@@ -349,10 +337,30 @@ private static void combineTranslator(final 
TranslationContext ctx,
       // Translate the remaining vertices.
       topologicalOrdering.stream().skip(1).forEach(ctx::translate);
     } else {
+      // No optimization
       transformVertex.getDAG().topologicalDo(ctx::translate);
     }
   }
 
+  private static boolean isBinaryCombine(final CombineFnBase.GlobalCombineFn 
combineFn) {
+    return combineFn instanceof Combine.BinaryCombineFn
+      || combineFn instanceof Combine.BinaryCombineDoubleFn
+      || combineFn instanceof Combine.BinaryCombineIntegerFn
+      || combineFn instanceof Combine.BinaryCombineLongFn;
+  }
+
+  private static CombineFnBase.GlobalCombineFn extractCombineFn(final 
PTransform<?, ?> combine) {
+    if (combine instanceof Combine.Globally) {
+      return ((Combine.Globally) combine).getFn();
+    } else if (combine instanceof Combine.PerKey) {
+      return ((Combine.PerKey) combine).getFn();
+    } else if (combine instanceof Combine.GroupedValues) {
+      return ((Combine.GroupedValues) combine).getFn();
+    } else {
+      throw new IllegalStateException(combine.toString());
+    }
+  }
+
   /**
    * Pushes the loop vertex to the stack before translating the inner DAG, and 
pops it after the translation.
    *


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to