Repository: incubator-hivemall
Updated Branches:
  refs/heads/master fdb4dd869 -> 52f05f43d


Close #37: [HIVEMALL-47][SPARK] Support codegen for top-K joins


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/52f05f43
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/52f05f43
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/52f05f43

Branch: refs/heads/master
Commit: 52f05f43d25256bb52f2201976b624fe5425e3da
Parents: fdb4dd8
Author: Takeshi Yamamuro <[email protected]>
Authored: Tue Feb 7 00:51:25 2017 +0900
Committer: Takeshi Yamamuro <[email protected]>
Committed: Tue Feb 7 00:51:25 2017 +0900

----------------------------------------------------------------------
 NOTICE                                          |   1 +
 .../joins/ShuffledHashJoinTopKExec.scala        | 307 +++++++++++++++++--
 .../sql/execution/benchmark/BenchmarkBase.scala |  55 ++++
 .../spark/sql/hive/HivemallOpsSuite.scala       | 105 ++++---
 .../sql/hive/benchmark/MiscBenchmark.scala      |  61 +++-
 .../hive/test/HivemallFeatureQueryTest.scala    |   6 +-
 6 files changed, 446 insertions(+), 89 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/NOTICE
----------------------------------------------------------------------
diff --git a/NOTICE b/NOTICE
index 699b055..0911f50 100644
--- a/NOTICE
+++ b/NOTICE
@@ -61,6 +61,7 @@ o 
hivemall/spark/spark-1.6/extra-src/hive/src/main/scala/org/apache/spark/sql/hi
   
hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
   
hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
   
hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+  
hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
 
     Copyright (C) 2014-2017 The Apache Software Foundation.
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
index c52cea1..caad646 100644
--- 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
+++ 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
@@ -25,11 +25,18 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.types._
 
-// TODO: Need to support codegen
+/**
+ * Minimal queue interface through which [[ShuffledHashJoinTopKExec]] (both its interpreted path
+ * and its generated Java code) collects scored rows and reads back the top-K result. It is a
+ * named class so that generated code can declare a field of this type via
+ * `classOf[PriorityQueueShim].getName`.
+ */
+abstract class PriorityQueueShim {
+
+  /** Offers `row` to the queue together with its `score`. */
+  def insert(score: Any, row: InternalRow): Unit
+  /** Returns an iterator over the rows currently held by the queue. */
+  def get(): Iterator[InternalRow]
+  /** Removes every entry; callers in this file invoke it right after `get()`. */
+  def clear(): Unit
+}
+
 case class ShuffledHashJoinTopKExec(
     k: Int,
     leftKeys: Seq[Expression],
@@ -39,7 +46,10 @@ case class ShuffledHashJoinTopKExec(
     right: SparkPlan)(
     scoreExpr: NamedExpression,
     rankAttr: Seq[Attribute])
-  extends BinaryExecNode with TopKHelper with HashJoin {
+  extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport {
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
 
   override val scoreType: DataType = scoreExpr.dataType
   override val joinType: JoinType = Inner
@@ -56,6 +66,34 @@ case class ShuffledHashJoinTopKExec(
 
   private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute
 
+  /**
+   * Adapter that exposes `queue` to both the interpreted path and the generated code through
+   * the [[PriorityQueueShim]] interface.
+   *
+   * Rows returned by `get()` are built on demand and all share one `UnsafeRow` (rank, score) and
+   * one `JoinedRow`, so each row is only valid until the iterator produces its next element; a
+   * consumer that needs to keep rows must copy them.
+   */
+  private lazy val _priorityQueue = new PriorityQueueShim {
+
+    private val q: InternalRowPriorityQueue = queue
+    private val joinedRow = new JoinedRow
+
+    override def insert(score: Any, row: InternalRow): Unit = {
+      q += Tuple2(score, row)
+    }
+
+    override def get(): Iterator[InternalRow] = {
+      // Snapshot and sort the entries strictly so that calling `clear()` right after `get()`
+      // cannot affect the returned iterator.
+      val ranked = q.iterator.toIndexedSeq.sortBy(_._1)(reverseScoreOrdering)
+      val topKRow = new UnsafeRow(2)
+      val bufferHolder = new BufferHolder(topKRow)
+      val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2)
+      val scoreWriter = ScoreWriter(unsafeRowWriter, 1)
+      // Map over an Iterator, not a strict collection: every element reuses `topKRow` and
+      // `joinedRow`, so rows must be produced one at a time as the consumer pulls them.
+      // (Previously this relied on `Iterator.toSeq` happening to return a lazy Stream.)
+      ranked.iterator.zipWithIndex.map { case ((score, row), index) =>
+        // Writes the 1-based rank and the score to an UnsafeRow directly
+        bufferHolder.reset()
+        unsafeRowWriter.write(0, 1 + index)
+        scoreWriter.write(score)
+        topKRow.setTotalSize(bufferHolder.totalSize())
+        joinedRow.apply(topKRow, row)
+      }
+    }
+
+    override def clear(): Unit = q.clear()
+  }
+
   override def output: Seq[Attribute] = joinType match {
     case Inner => topKAttr ++ left.output ++ right.output
   }
@@ -67,7 +105,7 @@ case class ShuffledHashJoinTopKExec(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation 
= {
+  def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val context = TaskContext.get()
     val relation = HashedRelation(iter, buildKeys, taskMemoryManager = 
context.taskMemoryManager())
     context.addTaskCompletionListener(_ => relation.close())
@@ -94,28 +132,16 @@ case class ShuffledHashJoinTopKExec(
       val matches = hashedRelation.get(joinKeys)
       if (matches != null) {
         matches.map(joinRow.withRight).filter(boundCondition).foreach { 
resultRow =>
-          queue += Tuple2(scoreProjection(resultRow).get(0, scoreType), 
resultRow)
+          _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), 
resultRow)
         }
-        val topKRow = new UnsafeRow(2)
-        val bufferHolder = new BufferHolder(topKRow)
-        val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2)
-        val scoreWriter = ScoreWriter(unsafeRowWriter, 1)
-        val iter = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
-          .zipWithIndex.map { case ((score, row), index) =>
-            // Writes to an UnsafeRow directly
-            bufferHolder.reset()
-            unsafeRowWriter.write(0, 1 + index)
-            scoreWriter.write(score)
-            topKRow.setTotalSize(bufferHolder.totalSize())
-            new JoinedRow(topKRow, row)
-          }
-        queue.clear
+        val iter = _priorityQueue.get()
+        _priorityQueue.clear()
         iter
       } else {
         Seq.empty
       }
     }
-    val resultProj = createResultProjection
+    val resultProj = createResultProjection()
     (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
       .map(_._2)).map { r =>
       resultProj(r)
@@ -128,4 +154,247 @@ case class ShuffledHashJoinTopKExec(
       InnerJoin(streamIter, hashed, null)
     }
   }
+
+  /**
+   * RDDs consumed by the generated code: `inputs[0]` is iterated as the left (streamed) side and
+   * `inputs[1]` is turned into the hashed relation of the right side.
+   */
+  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(left.execute(), right.execute())
+
+  /**
+   * Accessor for generated code: `doProduce` registers this operator via `addReferenceObj` and
+   * initializes a field of the generated class with `priorityQueue()`, because the backing
+   * `_priorityQueue` is private.
+   */
+  def priorityQueue(): PriorityQueueShim = _priorityQueue
+
+  /**
+   * Declares a class member holding the [[HashedRelation]] built from the right input
+   * (`inputs[1]`), records its estimated size as peak execution memory, and returns the
+   * member's name.
+   */
+  private def prepareHashedRelation(ctx: CodegenContext): String = {
+    val joinExec = ctx.addReferenceObj("joinExec", this)
+    // Unique name of the generated field that stores the relation
+    val relationTerm = ctx.freshName("relation")
+    val relationCls = HashedRelation.getClass.getName.replace("$", "")
+    val initCode =
+      s"""
+         | $relationTerm = ($relationCls) $joinExec.buildHashedRelation(inputs[1]);
+         | incPeakExecutionMemory($relationTerm.estimatedSize());
+       """.stripMargin
+    ctx.addMutableState(relationCls, relationTerm, initCode)
+    relationTerm
+  }
+
+  /**
+   * Creates variables for the left part of the joined row.
+   *
+   * Each column gets its own class-member variable (plus an `isNull` flag when nullable). The
+   * code that assigns it is kept separate from the declaration, so the caller decides where the
+   * column is actually read, e.g. deferring columns the join condition does not need. That is
+   * why the codegen of [[BoundReference]] is not used here.
+   */
+  private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = leftRow
+    for ((attr, ordinal) <- left.output.zipWithIndex) yield {
+      val valueTerm = ctx.freshName("value")
+      val readColumn = ctx.getValue(leftRow, attr.dataType, ordinal.toString)
+      // Class member, so the column can be accessed before or inside the match loop.
+      ctx.addMutableState(ctx.javaType(attr.dataType), valueTerm, "")
+      if (!attr.nullable) {
+        ExprCode(s"$valueTerm = $readColumn;", "false", valueTerm)
+      } else {
+        val isNullTerm = ctx.freshName("isNull")
+        ctx.addMutableState("boolean", isNullTerm, "")
+        val code =
+          s"""
+             |$isNullTerm = $leftRow.isNullAt($ordinal);
+             |$valueTerm = $isNullTerm ? ${ctx.defaultValue(attr.dataType)} : ($readColumn);
+           """.stripMargin
+        ExprCode(code, isNullTerm, valueTerm)
+      }
+    }
+  }
+
+  /**
+   * Creates the variables for the right part of the joined row with [[BoundReference]] codegen.
+   * The right columns are only read inside the match loop, so unlike the left columns they need
+   * no class-member declarations.
+   */
+  private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = rightRow
+    for ((attr, ordinal) <- right.output.zipWithIndex)
+      yield BoundReference(ordinal, attr.dataType, attr.nullable).genCode(ctx)
+  }
+
+  /**
+   * Generates the stream-side (left) join key read from `leftRow`.
+   *
+   * @return the key's generated code, and a Java boolean expression that is true when the key
+   *         contains a null (the generated loop then skips the lookup).
+   */
+  private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = {
+    ctx.INPUT_ROW = leftRow
+    streamedKeys match {
+      case Seq(singleKey) if singleKey.dataType == LongType =>
+        // A lone long key is used directly as the lookup key
+        val keyCode = singleKey.genCode(ctx)
+        (keyCode, keyCode.isNull)
+      case keys =>
+        // Otherwise the keys are packed into an UnsafeRow
+        val keyCode = GenerateUnsafeProjection.createCode(ctx, keys)
+        (keyCode, s"${keyCode.value}.anyNull()")
+    }
+  }
+
+  /** Generates code evaluating `scoreExpr` against the (left ++ right) columns of `row`. */
+  private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = {
+    ctx.INPUT_ROW = row
+    val joinedAttrs = left.output ++ right.output
+    BindReferences.bindReference(scoreExpr, joinedAttrs).genCode(ctx)
+  }
+
+  /**
+   * Creates one variable per attribute of `output`, read from the generated `resultRow`. Each
+   * value (and null flag, for nullable columns) is declared as a class member and assigned by
+   * the returned code.
+   */
+  private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = resultRow
+
+    // Declares the class member(s) for column `ordinal` and returns the code that loads it.
+    def columnVar(attr: Attribute, ordinal: Int): ExprCode = {
+      val valueTerm = ctx.freshName("value")
+      val readColumn = ctx.getValue(resultRow, attr.dataType, ordinal.toString)
+      ctx.addMutableState(ctx.javaType(attr.dataType), valueTerm, "")
+      if (attr.nullable) {
+        val isNullTerm = ctx.freshName("isNull")
+        ctx.addMutableState("boolean", isNullTerm, "")
+        val loadCode =
+          s"""
+             |$isNullTerm = $resultRow.isNullAt($ordinal);
+             |$valueTerm = $isNullTerm ? ${ctx.defaultValue(attr.dataType)} : ($readColumn);
+           """.stripMargin
+        ExprCode(loadCode, isNullTerm, valueTerm)
+      } else {
+        ExprCode(s"$valueTerm = $readColumn;", "false", valueTerm)
+      }
+    }
+
+    output.zipWithIndex.map { case (attr, ordinal) => columnVar(attr, ordinal) }
+  }
+
+  /**
+   * Partitions `variables` by whether their attribute is referenced by the join condition and
+   * returns `(code evaluated before the condition, code evaluated after it)`.
+   *
+   * Columns that the condition does not reference only need to be read for rows that pass it.
+   * Without a condition, everything is evaluated up front and the second part is empty.
+   */
+  private def splitVarsByCondition(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode]): (String, String) = condition match {
+    case Some(cond) =>
+      val condRefs = cond.references
+      val (used, notUsed) = attributes.zip(variables).partition { case (attr, _) =>
+        condRefs.contains(attr)
+      }
+      (evaluateVariables(used.map(_._2)), evaluateVariables(notUsed.map(_._2)))
+    case None =>
+      (evaluateVariables(variables), "")
+  }
+
+  /**
+   * Generates the whole-stage codegen loop for the top-K join.
+   *
+   * For every left (streamed) row, the generated code looks up the matching right rows in the
+   * hashed relation, evaluates `condition` (if any), scores each joined row with `scoreExpr`,
+   * and offers it to the top-K queue. After all matches of a left row are processed, the queue's
+   * rows are passed to the parent operator and the queue is cleared.
+   */
+  override def doProduce(ctx: CodegenContext): String = {
+    ctx.copyResult = true
+
+    val topKJoin = ctx.addReferenceObj("topKJoin", this)
+
+    // Prepare a priority queue for top-K computing
+    val pQueue = ctx.freshName("queue")
+    ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue,
+      s"$pQueue = $topKJoin.priorityQueue();")
+
+    // Prepare variables for a left side
+    val leftIter = ctx.freshName("leftIter")
+    ctx.addMutableState("scala.collection.Iterator", leftIter, s"$leftIter = inputs[0];")
+    val leftRow = ctx.freshName("leftRow")
+    ctx.addMutableState("InternalRow", leftRow, "")
+    val leftVars = createLeftVars(ctx, leftRow)
+
+    // Prepare variables for a right side
+    val rightRow = ctx.freshName("rightRow")
+    val rightVars = createRightVar(ctx, rightRow)
+
+    // Build a hashed relation from a right side
+    val buildRelation = prepareHashedRelation(ctx)
+
+    // Project join keys from a left side
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow)
+
+    // Prepare variables for joined rows
+    val joinedRow = ctx.freshName("joinedRow")
+    val joinedRowCls = classOf[JoinedRow].getName
+    ctx.addMutableState(joinedRowCls, joinedRow, s"$joinedRow = new $joinedRowCls();")
+
+    // Project score values from joined rows
+    val scoreVar = createScoreVar(ctx, joinedRow)
+
+    // Prepare variables for output rows
+    val resultRow = ctx.freshName("resultRow")
+    val resultVars = createResultVars(ctx, resultRow)
+
+    // `beforeLoop` runs once per left row, before its matches are iterated; `condCheck` runs
+    // for every match and skips it (`continue`) when the join condition does not hold.
+    val (beforeLoop, condCheck) = if (condition.isDefined) {
+      // Split the code of creating variables based on whether it's used by condition or not.
+      val loaded = ctx.freshName("loaded")
+      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+      // Generate code for condition. `ctx.currentVars` holds the left columns followed by the
+      // right columns, so the condition must be bound against exactly `left.output ++
+      // right.output`. Binding against `output` would shift every ordinal by the rank/score
+      // columns that prefix it.
+      ctx.currentVars = leftVars ++ rightVars
+      val cond = BindReferences.bindReference(condition.get, left.output ++ right.output)
+        .genCode(ctx)
+      // evaluate the columns those used by condition before loop
+      val before = s"""
+           |boolean $loaded = false;
+           |$leftBefore
+         """.stripMargin
+
+      val checking = s"""
+         |$rightBefore
+         |${cond.code}
+         |if (${cond.isNull} || !${cond.value}) continue;
+         |if (!$loaded) {
+         |  $loaded = true;
+         |  $leftAfter
+         |}
+         |$rightAfter
+     """.stripMargin
+      (before, checking)
+    } else {
+      (evaluateVariables(leftVars), "")
+    }
+
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val matches = ctx.freshName("matches")
+    // Fresh name instead of a hard-coded `row` to avoid clashing with enclosing generated code
+    val candidate = ctx.freshName("candidate")
+    val topKRows = ctx.freshName("topKRows")
+    val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+    s"""
+       |$leftRow = null;
+       |while ($leftIter.hasNext()) {
+       |  $leftRow = (InternalRow) $leftIter.next();
+       |
+       |  // Generate join key for stream side
+       |  ${keyEv.code}
+       |
+       |  // Find matches from HashedRelation
+       |  $iteratorCls $matches = $anyNull ? null :
+       |    ($iteratorCls)$buildRelation.get(${keyEv.value});
+       |  if ($matches == null) continue;
+       |
+       |  // Evaluate left-side columns once per left row, not once per match
+       |  ${beforeLoop.trim}
+       |
+       |  // Join top-K right rows
+       |  while ($matches.hasNext()) {
+       |    InternalRow $rightRow = (InternalRow) $matches.next();
+       |    ${condCheck.trim}
+       |    InternalRow $candidate = $joinedRow.apply($leftRow, $rightRow);
+       |    // Compute a score for the joined row
+       |    ${scoreVar.code}
+       |    $pQueue.insert(${scoreVar.value}, $candidate);
+       |  }
+       |
+       |  // Get top-K rows
+       |  $iteratorCls $topKRows = $pQueue.get();
+       |  $pQueue.clear();
+       |
+       |  // Output top-K rows
+       |  while ($topKRows.hasNext()) {
+       |    InternalRow $resultRow = (InternalRow) $topKRows.next();
+       |    $numOutput.add(1);
+       |    ${consume(ctx, resultVars)}
+       |  }
+       |
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
new file mode 100644
index 0000000..5bb7fbe
--- /dev/null
+++ 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.Benchmark
+
+/**
+ * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together
+ * with other test suites).
+ */
+private[sql] trait BenchmarkBase extends SparkFunSuite {
+
+  /**
+   * Local, single-threaded session with one shuffle partition. The broadcast threshold is set to
+   * 1 byte, presumably so that joins are planned as shuffle joins rather than broadcast joins.
+   */
+  lazy val sparkSession = SparkSession.builder
+    .master("local[1]")
+    .appName("microbenchmark")
+    .config("spark.sql.shuffle.partitions", 1)
+    .config("spark.sql.autoBroadcastJoinThreshold", 1)
+    .getOrCreate()
+
+  /**
+   * Runs `f` as two cases of one benchmark: first with whole-stage codegen off (2 iterations),
+   * then with it on (5 iterations).
+   *
+   * `f` is evaluated only for its side effects. Because its result type is `Unit`, a DataFrame
+   * produced by `f` is silently discarded, so `f` must trigger execution itself.
+   */
+  def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
+    val benchmark = new Benchmark(name, cardinality)
+    // (case label, value of spark.sql.codegen.wholeStage, number of iterations)
+    val cases = Seq(("off", false, 2), ("on", true, 5))
+    cases.foreach { case (label, wholeStageEnabled, iterations) =>
+      benchmark.addCase(s"$name wholestage $label", numIters = iterations) { _ =>
+        sparkSession.conf.set("spark.sql.codegen.wholeStage", value = wholeStageEnabled)
+        f
+      }
+    }
+    benchmark.run()
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index e5775b1..76195fd 100644
--- 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.hive.HivemallGroupedDataset._
 import org.apache.spark.sql.hive.HivemallOps._
 import org.apache.spark.sql.hive.HivemallUtils._
 import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.VectorQueryTest
 import org.apache.spark.sql.types._
 import org.apache.spark.test.TestFPWrapper._
@@ -349,56 +350,60 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
   }
 
   test("misc - join_top_k") {
-    import hiveContext.implicits._
-    val inputDf = Seq(
-      ("user1", 1, 0.3, 0.5),
-      ("user2", 2, 0.1, 0.1),
-      ("user3", 3, 0.8, 0.0),
-      ("user4", 1, 0.9, 0.9),
-      ("user5", 3, 0.7, 0.2),
-      ("user6", 1, 0.5, 0.4),
-      ("user7", 2, 0.6, 0.8)
-    ).toDF("userId", "group", "x", "y")
-
-    val masterDf = Seq(
-      (1, "pos1-1", 0.5, 0.1),
-      (1, "pos1-2", 0.0, 0.0),
-      (1, "pos1-3", 0.3, 0.3),
-      (2, "pos2-3", 0.1, 0.3),
-      (2, "pos2-3", 0.8, 0.8),
-      (3, "pos3-1", 0.1, 0.7),
-      (3, "pos3-1", 0.7, 0.1),
-      (3, "pos3-1", 0.9, 0.0),
-      (3, "pos3-1", 0.1, 0.3)
-    ).toDF("group", "position", "x", "y")
-
-    // Compute top-1 rows for each group
-    val distance = sqrt(
-      pow(inputDf("x") - masterDf("x"), lit(2.0)) +
-      pow(inputDf("y") - masterDf("y"), lit(2.0))
-    ).as("score")
-    val top1Df = inputDf.top_k_join(
-      lit(1), masterDf, inputDf("group") === masterDf("group"), distance)
-    assert(top1Df.schema.toSet === Set(
-      StructField("rank", IntegerType, nullable = true),
-      StructField("score", DoubleType, nullable = true),
-      StructField("group", IntegerType, nullable = false),
-      StructField("userId", StringType, nullable = true),
-      StructField("position", StringType, nullable = true),
-      StructField("x", DoubleType, nullable = false),
-      StructField("y", DoubleType, nullable = false)
-    ))
-    checkAnswer(
-      top1Df.select($"rank", inputDf("group"), $"userId", $"position"),
-      Row(1, 1, "user1", "pos1-2") ::
-      Row(1, 2, "user2", "pos2-3") ::
-      Row(1, 3, "user3", "pos3-1") ::
-      Row(1, 1, "user4", "pos1-2") ::
-      Row(1, 3, "user5", "pos3-1") ::
-      Row(1, 1, "user6", "pos1-2") ::
-      Row(1, 2, "user7", "pos2-3") ::
-      Nil
-    )
+    // Run the same query with whole-stage codegen enabled and disabled
+    Seq("true", "false").foreach { flag =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) {
+        import hiveContext.implicits._
+        val inputDf = Seq(
+          ("user1", 1, 0.3, 0.5),
+          ("user2", 2, 0.1, 0.1),
+          ("user3", 3, 0.8, 0.0),
+          ("user4", 1, 0.9, 0.9),
+          ("user5", 3, 0.7, 0.2),
+          ("user6", 1, 0.5, 0.4),
+          ("user7", 2, 0.6, 0.8)
+        ).toDF("userId", "group", "x", "y")
+
+        val masterDf = Seq(
+          (1, "pos1-1", 0.5, 0.1),
+          (1, "pos1-2", 0.0, 0.0),
+          (1, "pos1-3", 0.3, 0.3),
+          (2, "pos2-3", 0.1, 0.3),
+          (2, "pos2-3", 0.8, 0.8),
+          (3, "pos3-1", 0.1, 0.7),
+          (3, "pos3-1", 0.7, 0.1),
+          (3, "pos3-1", 0.9, 0.0),
+          (3, "pos3-1", 0.1, 0.3)
+        ).toDF("group", "position", "x", "y")
+
+        // Compute top-1 rows for each group
+        val distance = sqrt(
+          pow(inputDf("x") - masterDf("x"), lit(2.0)) +
+            pow(inputDf("y") - masterDf("y"), lit(2.0))
+        ).as("score")
+        val top1Df = inputDf.top_k_join(
+          lit(1), masterDf, inputDf("group") === masterDf("group"), distance)
+        assert(top1Df.schema.toSet === Set(
+          StructField("rank", IntegerType, nullable = true),
+          StructField("score", DoubleType, nullable = true),
+          StructField("group", IntegerType, nullable = false),
+          StructField("userId", StringType, nullable = true),
+          StructField("position", StringType, nullable = true),
+          StructField("x", DoubleType, nullable = false),
+          StructField("y", DoubleType, nullable = false)
+        ))
+        checkAnswer(
+          top1Df.select($"rank", inputDf("group"), $"userId", $"position"),
+          Row(1, 1, "user1", "pos1-2") ::
+          Row(1, 2, "user2", "pos2-3") ::
+          Row(1, 3, "user3", "pos3-1") ::
+          Row(1, 1, "user4", "pos1-2") ::
+          Row(1, 3, "user5", "pos3-1") ::
+          Row(1, 1, "user6", "pos1-2") ::
+          Row(1, 2, "user7", "pos2-3") ::
+          Nil
+        )
+      }
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
index 2f6f9b9..8da0776 100644
--- 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
+++ 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
@@ -18,12 +18,12 @@
  */
 package org.apache.spark.sql.hive.benchmark
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.benchmark.BenchmarkBase
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.HivemallOps._
@@ -79,23 +79,14 @@ object TestFuncWrapper {
   @inline private def withExpr(expr: Expression): Column = Column(expr)
 }
 
-class MiscBenchmark extends SparkFunSuite {
-
-  lazy val sparkSession = SparkSession.builder
-    .master("local[1]")
-    .appName("microbenchmark")
-    .config("spark.sql.shuffle.partitions", 1)
-    .config("spark.sql.codegen.wholeStage", true)
-    .getOrCreate()
+class MiscBenchmark extends BenchmarkBase {
 
   val numIters = 10
 
+  /**
+   * Registers `df` as a case of the implicit `benchmark`: each of the `numIters` iterations
+   * executes the physical plan of `df` and drains every partition.
+   */
   private def addBenchmarkCase(name: String, df: DataFrame)(implicit 
 benchmark: Benchmark): Unit = {
-    // TODO: This query below failed in `each_top_k`
-    // benchmark.addCase(name, numIters) {
-    //   _ => df.queryExecution.executedPlan.execute().foreach(x => {})
-    // }
-    benchmark.addCase(name, numIters) { _ => df.count }
+    benchmark.addCase(name, numIters) {
+      // Runs the RDD[InternalRow] directly; rows are dropped by `foreach`, not collected
+      _ => df.queryExecution.executedPlan.execute().foreach(x => {})
+    }
   }
 
   TestUtils.benchmark("closure/exprs/spark-udf/hive-udf") {
@@ -180,7 +171,7 @@ class MiscBenchmark extends SparkFunSuite {
     )
     addBenchmarkCase(
       "each_top_k (exprs)",
-      testDf.each_top_k(lit(topK), $"x".as("score"), $"key")
+      testDf.each_top_k(lit(topK), $"x".as("score"), $"key".as("group"))
     )
     benchmark.run()
   }
@@ -229,7 +220,7 @@ class MiscBenchmark extends SparkFunSuite {
     addBenchmarkCase(
       "join + each_top_k",
       inputDf.join(masterDf, inputDf("group") === masterDf("group"))
-        .each_top_k(lit(topK), distance, inputDf("group"))
+        .each_top_k(lit(topK), distance, inputDf("group").as("group"))
     )
     addBenchmarkCase(
       "top_k_join",
@@ -237,4 +228,40 @@ class MiscBenchmark extends SparkFunSuite {
     )
     benchmark.run()
   }
+
+  TestUtils.benchmark("codegen top-k join") {
+    /**
+     * NOTE(review): the figures below appear to have been recorded while the query was never
+     * executed (`runBenchmark` discards the DataFrame built inside its block; 8M rows in ~3 ms),
+     * and with N = 1 << 23, M = 1 << 22. They need to be re-measured.
+     *
+     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
+     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
+     *
+     * top_k_join:                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+     * -----------------------------------------------------------------------------------
+     * top_k_join wholestage off           3 /    5       2751.9           0.4       1.0X
+     * top_k_join wholestage on            1 /    1       6494.4           0.2       2.4X
+     */
+    val topK = 3
+    // Every left row is scored against all right rows of its group (about M / numGroup rows),
+    // so one iteration evaluates roughly N * M / numGroup pairs; both sides are kept small
+    // enough for that to finish now that the query is really executed.
+    val N = 1L << 18
+    val M = 1L << 10
+    val numGroup = 3
+    val inputDf = sparkSession.range(N).selectExpr(
+      s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y"
+    ).cache
+    val masterDf = sparkSession.range(M).selectExpr(
+      s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y"
+    ).cache
+
+    // First, cache data
+    inputDf.count
+    masterDf.count
+
+    // Define a score column
+    val distance = sqrt(
+      pow(inputDf("x") - masterDf("x"), lit(2.0)) +
+      pow(inputDf("y") - masterDf("y"), lit(2.0))
+    )
+    runBenchmark("top_k_join", N) {
+      val topKDf = inputDf.top_k_join(
+        lit(topK), masterDf, inputDf("group") === masterDf("group"), distance.as("score"))
+      // `runBenchmark` discards the value of this block, so execute the physical plan here
+      // (as `addBenchmarkCase` does); otherwise only the lazy DataFrame would be built.
+      topKDf.queryExecution.executedPlan.execute().foreach(_ => {})
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
index a4733f5..eef6e63 100644
--- 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
+++ 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
@@ -23,15 +23,15 @@ import scala.reflect.runtime.universe.TypeTag
 
 import hivemall.tools.RegressionDatagen
 
-import org.apache.spark.sql.Column
+import org.apache.spark.sql.{Column, QueryTest}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.SQLTestUtils
 
 /**
  * Base class for tests with Hivemall features.
  */
-abstract class HivemallFeatureQueryTest extends QueryTest with 
TestHiveSingleton {
+abstract class HivemallFeatureQueryTest extends QueryTest with SQLTestUtils 
with TestHiveSingleton {
 
   import hiveContext.implicits._
 

Reply via email to