Repository: spark
Updated Branches:
  refs/heads/master 1b829ce13 -> 3e991dbc3


[SPARK-13674] [SQL] Add wholestage codegen support to Sample

JIRA: https://issues.apache.org/jira/browse/SPARK-13674

## What changes were proposed in this pull request?

The Sample operator does not currently support whole-stage codegen. This PR adds
support for it.

## How was this patch tested?

A test is added to `BenchmarkWholeStageCodegen`. In addition, all existing tests
should pass.

Author: Liang-Chi Hsieh <[email protected]>
Author: Liang-Chi Hsieh <[email protected]>

Closes #11517 from viirya/add-wholestage-sample.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e991dbc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e991dbc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e991dbc

Branch: refs/heads/master
Commit: 3e991dbc310a4a33eec7f3909adce50bf8268d04
Parents: 1b829ce
Author: Liang-Chi Hsieh <[email protected]>
Authored: Fri Apr 1 14:02:32 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Fri Apr 1 14:02:32 2016 -0700

----------------------------------------------------------------------
 .../spark/util/random/RandomSampler.scala       |  2 +-
 project/MimaExcludes.scala                      |  4 ++
 .../sql/execution/BufferedRowIterator.java      |  4 +-
 .../spark/sql/execution/WholeStageCodegen.scala | 12 ++--
 .../spark/sql/execution/basicOperators.scala    | 72 +++++++++++++++++---
 .../execution/BenchmarkWholeStageCodegen.scala  | 25 +++++++
 6 files changed, 104 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e991dbc/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala 
b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 2921b93..d397cca 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -186,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) 
extends RandomSampler[T, T
  * @tparam T item type
  */
 @DeveloperApi
-class PoissonSampler[T: ClassTag](
+class PoissonSampler[T](
     fraction: Double,
     useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3e991dbc/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index ff11775..2be490b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -597,6 +597,10 @@ object MimaExcludes {
         // for multilayer perceptron.
         // This class is marked as `private`.
         
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction")
+      ) ++ Seq(
+        // [SPARK-13674][SQL] Add wholestage codegen support to Sample
+        
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
+        
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
       )
     case v if v.startsWith("1.6") =>
       Seq(

http://git-wip-us.apache.org/repos/asf/spark/blob/3e991dbc/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index dbea852..c2633a9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -36,6 +36,8 @@ public abstract class BufferedRowIterator {
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
   private long startTimeNs = System.nanoTime();
 
+  protected int partitionIndex = -1;
+
   public boolean hasNext() throws IOException {
     if (currentRows.isEmpty()) {
       processNext();
@@ -58,7 +60,7 @@ public abstract class BufferedRowIterator {
   /**
    * Initializes from array of iterators of InternalRow.
    */
-  public abstract void init(Iterator<InternalRow> iters[]);
+  public abstract void init(int index, Iterator<InternalRow> iters[]);
 
   /**
    * Append a row to currentRows.

http://git-wip-us.apache.org/repos/asf/spark/blob/3e991dbc/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6a779ab..9bdf611 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -323,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends 
UnaryNode with CodegenSup
           this.references = references;
         }
 
-        public void init(scala.collection.Iterator inputs[]) {
+        public void init(int index, scala.collection.Iterator inputs[]) {
+          partitionIndex = index;
           ${ctx.initMutableStates()}
         }
 
@@ -351,10 +352,10 @@ case class WholeStageCodegen(child: SparkPlan) extends 
UnaryNode with CodegenSup
     val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
-      rdds.head.mapPartitions { iter =>
+      rdds.head.mapPartitionsWithIndex { (index, iter) =>
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = 
clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(Array(iter))
+        buffer.init(index, Array(iter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
@@ -367,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends 
UnaryNode with CodegenSup
     } else {
       // Right now, we support up to two upstreams.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+        val partitionIndex = TaskContext.getPartitionId()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = 
clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(Array(leftIter, rightIter))
+        buffer.init(partitionIndex, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext

http://git-wip-us.apache.org/repos/asf/spark/blob/3e991dbc/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fca6627..a6a14df 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.random.PoissonSampler
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
   extends UnaryNode with CodegenSupport {
@@ -223,9 +223,12 @@ case class Sample(
     upperBound: Double,
     withReplacement: Boolean,
     seed: Long,
-    child: SparkPlan) extends UnaryNode {
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
   override def output: Seq[Attribute] = child.output
 
+  private[sql] override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
+
   protected override def doExecute(): RDD[InternalRow] = {
     if (withReplacement) {
       // Disable gap sampling since the gap sampling method buffers two rows 
internally,
@@ -239,6 +242,63 @@ case class Sample(
       child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val sampler = ctx.freshName("sampler")
+
+    if (withReplacement) {
+      val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
+      val initSampler = ctx.freshName("initSampler")
+      ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+        s"$initSampler();")
+
+      ctx.addNewFunction(initSampler,
+        s"""
+          | private void $initSampler() {
+          |   $sampler = new $samplerClass<UnsafeRow>($upperBound - 
$lowerBound, false);
+          |   java.util.Random random = new java.util.Random(${seed}L);
+          |   long randomSeed = random.nextLong();
+          |   int loopCount = 0;
+          |   while (loopCount < partitionIndex) {
+          |     randomSeed = random.nextLong();
+          |     loopCount += 1;
+          |   }
+          |   $sampler.setSeed(randomSeed);
+          | }
+         """.stripMargin.trim)
+
+      val samplingCount = ctx.freshName("samplingCount")
+      s"""
+         | int $samplingCount = $sampler.sample();
+         | while ($samplingCount-- > 0) {
+         |   $numOutput.add(1);
+         |   ${consume(ctx, input)}
+         | }
+       """.stripMargin.trim
+    } else {
+      val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
+      ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+        s"""
+          | $sampler = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, 
false);
+          | $sampler.setSeed(${seed}L + partitionIndex);
+         """.stripMargin.trim)
+
+      s"""
+         | if ($sampler.sample() == 0) continue;
+         | $numOutput.add(1);
+         | ${consume(ctx, input)}
+       """.stripMargin.trim
+    }
+  }
 }
 
 case class Range(
@@ -320,11 +380,7 @@ case class Range(
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
-      |   if ($input.hasNext()) {
-      |     initRange(((InternalRow) $input.next()).getInt(0));
-      |   } else {
-      |     return;
-      |   }
+      |   initRange(partitionIndex);
       | }
       |
       | while (!$overflow && $checkEnd) {

http://git-wip-us.apache.org/repos/asf/spark/blob/3e991dbc/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 003d3e0..5590679 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -85,6 +85,31 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
   }
 
+  ignore("range/sample/sum") {
+    val N = 500 << 20
+    runBenchmark("range/sample/sum", N) {
+      sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()
+    }
+    /*
+    Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    range/sample/sum:                   Best/Avg Time(ms)    Rate(M/s)   Per 
Row(ns)   Relative
+    
-------------------------------------------------------------------------------------------
+    range/sample/sum codegen=false         53888 / 56592          9.7         
102.8       1.0X
+    range/sample/sum codegen=true          41614 / 42607         12.6          
79.4       1.3X
+    */
+
+    runBenchmark("range/sample/sum", N) {
+      sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect()
+    }
+    /*
+    Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    range/sample/sum:                   Best/Avg Time(ms)    Rate(M/s)   Per 
Row(ns)   Relative
+    
-------------------------------------------------------------------------------------------
+    range/sample/sum codegen=false         12982 / 13384         40.4          
24.8       1.0X
+    range/sample/sum codegen=true            7074 / 7383         74.1          
13.5       1.8X
+    */
+  }
+
   ignore("stat functions") {
     val N = 100L << 20
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to