Repository: incubator-hivemall Updated Branches: refs/heads/master fdb4dd869 -> 52f05f43d
Close #37: [HIVEMALL-47][SPARK] Support codegen for top-K joins Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/52f05f43 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/52f05f43 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/52f05f43 Branch: refs/heads/master Commit: 52f05f43d25256bb52f2201976b624fe5425e3da Parents: fdb4dd8 Author: Takeshi Yamamuro <[email protected]> Authored: Tue Feb 7 00:51:25 2017 +0900 Committer: Takeshi Yamamuro <[email protected]> Committed: Tue Feb 7 00:51:25 2017 +0900 ---------------------------------------------------------------------- NOTICE | 1 + .../joins/ShuffledHashJoinTopKExec.scala | 307 +++++++++++++++++-- .../sql/execution/benchmark/BenchmarkBase.scala | 55 ++++ .../spark/sql/hive/HivemallOpsSuite.scala | 105 ++++--- .../sql/hive/benchmark/MiscBenchmark.scala | 61 +++- .../hive/test/HivemallFeatureQueryTest.scala | 6 +- 6 files changed, 446 insertions(+), 89 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/NOTICE ---------------------------------------------------------------------- diff --git a/NOTICE b/NOTICE index 699b055..0911f50 100644 --- a/NOTICE +++ b/NOTICE @@ -61,6 +61,7 @@ o hivemall/spark/spark-1.6/extra-src/hive/src/main/scala/org/apache/spark/sql/hi hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala + hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala Copyright (C) 2014-2017 The Apache Software Foundation. http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala index c52cea1..caad646 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala @@ -25,11 +25,18 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric._ import org.apache.spark.sql.types._ -// TODO: Need to support codegen +abstract class PriorityQueueShim { + + def insert(score: Any, row: InternalRow): Unit + def get(): Iterator[InternalRow] + def clear(): Unit +} + case class ShuffledHashJoinTopKExec( k: Int, leftKeys: Seq[Expression], @@ -39,7 +46,10 @@ case class ShuffledHashJoinTopKExec( right: SparkPlan)( scoreExpr: NamedExpression, rankAttr: Seq[Attribute]) - extends BinaryExecNode with TopKHelper with HashJoin { + extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override val scoreType: DataType = scoreExpr.dataType override val joinType: JoinType = Inner @@ -56,6 +66,34 @@ case class ShuffledHashJoinTopKExec( private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute + private lazy val _priorityQueue = new PriorityQueueShim { + + private val q: InternalRowPriorityQueue = queue + private val joinedRow = new JoinedRow + + override def insert(score: Any, row: InternalRow): Unit = { + q += Tuple2(score, row) + } + + override def get(): Iterator[InternalRow] = { + val topKRow = new UnsafeRow(2) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2) + val scoreWriter = ScoreWriter(unsafeRowWriter, 1) + q.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering).zipWithIndex.map { + case ((score, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, 1 + index) + scoreWriter.write(score) + topKRow.setTotalSize(bufferHolder.totalSize()) + joinedRow.apply(topKRow, row) + }.iterator + } + + override def clear(): Unit = q.clear() + } + override def output: Seq[Attribute] = joinType match { case Inner => topKAttr ++ left.output ++ right.output } @@ -67,7 +105,7 @@ case class ShuffledHashJoinTopKExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val context = TaskContext.get() val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) context.addTaskCompletionListener(_ => relation.close()) @@ -94,28 +132,16 @@ case class ShuffledHashJoinTopKExec( val matches = hashedRelation.get(joinKeys) if (matches != null) { matches.map(joinRow.withRight).filter(boundCondition).foreach { resultRow => - queue += Tuple2(scoreProjection(resultRow).get(0, scoreType), resultRow) + _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow) } - val topKRow = new UnsafeRow(2) - val bufferHolder = new BufferHolder(topKRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2) - val scoreWriter = ScoreWriter(unsafeRowWriter, 1) - val iter = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) - .zipWithIndex.map { case ((score, row), index) => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, 1 + index) - scoreWriter.write(score) - topKRow.setTotalSize(bufferHolder.totalSize()) - new JoinedRow(topKRow, row) - } - queue.clear + val iter = _priorityQueue.get() + _priorityQueue.clear() iter } else { Seq.empty } } - val resultProj = createResultProjection + val resultProj = createResultProjection() (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) .map(_._2)).map { r => resultProj(r) @@ -128,4 +154,247 @@ case class ShuffledHashJoinTopKExec( InnerJoin(streamIter, hashed, null) } } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + // Accessor for generated code + def priorityQueue(): PriorityQueueShim = _priorityQueue + + /** + * Add a state of HashedRelation and return the variable name for it. + */ + private def prepareHashedRelation(ctx: CodegenContext): String = { + // create a name for HashedRelation + val joinExec = ctx.addReferenceObj("joinExec", this) + val relationTerm = ctx.freshName("relation") + val clsName = HashedRelation.getClass.getName.replace("$", "") + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = ($clsName) $joinExec.buildHashedRelation(inputs[1]); + | incPeakExecutionMemory($relationTerm.estimatedSize()); + """.stripMargin) + relationTerm + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = { + ctx.INPUT_ROW = leftRow + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = { + ctx.INPUT_ROW = row + BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx) + } + + private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = resultRow + output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(resultRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $resultRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. + */ + private def splitVarsByCondition( + attributes: Seq[Attribute], + variables: Seq[ExprCode]): (String, String) = { + if (condition.isDefined) { + val condRefs = condition.get.references + val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => + condRefs.contains(a) + } + val beforeCond = evaluateVariables(used.map(_._2)) + val afterCond = evaluateVariables(notUsed.map(_._2)) + (beforeCond, afterCond) + } else { + (evaluateVariables(variables), "") + } + } + + override def doProduce(ctx: CodegenContext): String = { + ctx.copyResult = true + + val topKJoin = ctx.addReferenceObj("topKJoin", this) + + // Prepare a priority queue for top-K computing + val pQueue = ctx.freshName("queue") + ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue, + s"$pQueue = $topKJoin.priorityQueue();") + + // Prepare variables for a left side + val leftIter = ctx.freshName("leftIter") + ctx.addMutableState("scala.collection.Iterator", leftIter, s"$leftIter = inputs[0];") + val leftRow = ctx.freshName("leftRow") + ctx.addMutableState("InternalRow", leftRow, "") + val leftVars = createLeftVars(ctx, leftRow) + + // Prepare variables for a right side + val rightRow = ctx.freshName("rightRow") + val rightVars = createRightVar(ctx, rightRow) + + // Build a hashed relation from a right side + val buildRelation = prepareHashedRelation(ctx) + + // Project join keys from a left side + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow) + + // Prepare variables for joined rows + val joinedRow = ctx.freshName("joinedRow") + val joinedRowCls = classOf[JoinedRow].getName + ctx.addMutableState(joinedRowCls, joinedRow, s"$joinedRow = new $joinedRowCls();") + + // Project score values from joined rows + val scoreVar = createScoreVar(ctx, joinedRow) + + // Prepare variables for output rows + val resultRow = ctx.freshName("resultRow") + val resultVars = createResultVars(ctx, resultRow) + + val (beforeLoop, condCheck) = if (condition.isDefined) { + // Split the code of creating variables based on whether it's used by condition or not. + val loaded = ctx.freshName("loaded") + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + // Generate code for condition + ctx.currentVars = leftVars ++ rightVars + val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) + // evaluate the columns those used by condition before loop + val before = s""" + |boolean $loaded = false; + |$leftBefore + """.stripMargin + + val checking = s""" + |$rightBefore + |${cond.code} + |if (${cond.isNull} || !${cond.value}) continue; + |if (!$loaded) { + | $loaded = true; + | $leftAfter + |} + |$rightAfter + """.stripMargin + (before, checking) + } else { + (evaluateVariables(leftVars), "") + } + + val numOutput = metricTerm(ctx, "numOutputRows") + + val matches = ctx.freshName("matches") + val topKRows = ctx.freshName("topKRows") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + + s""" + |$leftRow = null; + |while ($leftIter.hasNext()) { + | $leftRow = (InternalRow) $leftIter.next(); + | + | // Generate join key for stream side + | ${keyEv.code} + | + | // Find matches from HashedRelation + | $iteratorCls $matches = $anyNull? null : ($iteratorCls)$buildRelation.get(${keyEv.value}); + | if ($matches == null) continue; + | + | // Join top-K right rows + | while ($matches.hasNext()) { + | ${beforeLoop.trim} + | InternalRow $rightRow = (InternalRow) $matches.next(); + | ${condCheck.trim} + | InternalRow row = $joinedRow.apply($leftRow, $rightRow); + | // Compute a score for the `row` + | ${scoreVar.code} + | $pQueue.insert(${scoreVar.value}, row); + | } + | + | // Get top-K rows + | $iteratorCls $topKRows = $pQueue.get(); + | $pQueue.clear(); + | + | // Output top-K rows + | while ($topKRows.hasNext()) { + | InternalRow $resultRow = (InternalRow) $topKRows.next(); + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | + | if (shouldStop()) return; + |} + """.stripMargin + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala new file mode 100644 index 0000000..5bb7fbe --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.Benchmark + +/** + * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together + * with other test suites). + */ +private[sql] trait BenchmarkBase extends SparkFunSuite { + + lazy val sparkSession = SparkSession.builder + .master("local[1]") + .appName("microbenchmark") + .config("spark.sql.shuffle.partitions", 1) + .config("spark.sql.autoBroadcastJoinThreshold", 1) + .getOrCreate() + + /** Runs function `f` with whole stage codegen on and off. */ + def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, cardinality) + + benchmark.addCase(s"$name wholestage off", numIters = 2) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) + f + } + + benchmark.addCase(s"$name wholestage on", numIters = 5) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) + f + } + + benchmark.run() + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index e5775b1..76195fd 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.hive.HivemallGroupedDataset._ import org.apache.spark.sql.hive.HivemallOps._ import org.apache.spark.sql.hive.HivemallUtils._ import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.VectorQueryTest import org.apache.spark.sql.types._ import org.apache.spark.test.TestFPWrapper._ @@ -349,56 +350,60 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { } test("misc - join_top_k") { - import hiveContext.implicits._ - val inputDf = Seq( - ("user1", 1, 0.3, 0.5), - ("user2", 2, 0.1, 0.1), - ("user3", 3, 0.8, 0.0), - ("user4", 1, 0.9, 0.9), - ("user5", 3, 0.7, 0.2), - ("user6", 1, 0.5, 0.4), - ("user7", 2, 0.6, 0.8) - ).toDF("userId", "group", "x", "y") - - val masterDf = Seq( - (1, "pos1-1", 0.5, 0.1), - (1, "pos1-2", 0.0, 0.0), - (1, "pos1-3", 0.3, 0.3), - (2, "pos2-3", 0.1, 0.3), - (2, "pos2-3", 0.8, 0.8), - (3, "pos3-1", 0.1, 0.7), - (3, "pos3-1", 0.7, 0.1), - (3, "pos3-1", 0.9, 0.0), - (3, "pos3-1", 0.1, 0.3) - ).toDF("group", "position", "x", "y") - - // Compute top-1 rows for each group - val distance = sqrt( - pow(inputDf("x") - masterDf("x"), lit(2.0)) + - pow(inputDf("y") - masterDf("y"), lit(2.0)) - ).as("score") - val top1Df = inputDf.top_k_join( - lit(1), masterDf, inputDf("group") === masterDf("group"), distance) - assert(top1Df.schema.toSet === Set( - StructField("rank", IntegerType, nullable = true), - StructField("score", DoubleType, nullable = true), - StructField("group", IntegerType, nullable = false), - StructField("userId", StringType, nullable = true), - StructField("position", StringType, nullable = true), - StructField("x", DoubleType, nullable = false), - StructField("y", DoubleType, nullable = false) - )) - checkAnswer( - top1Df.select($"rank", inputDf("group"), $"userId", $"position"), - Row(1, 1, "user1", "pos1-2") :: - Row(1, 2, "user2", "pos2-3") :: - Row(1, 3, "user3", "pos3-1") :: - Row(1, 1, "user4", "pos1-2") :: - Row(1, 3, "user5", "pos3-1") :: - Row(1, 1, "user6", "pos1-2") :: - Row(1, 2, "user7", "pos2-3") :: - Nil - ) + Seq("true", "false").map { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + import hiveContext.implicits._ + val inputDf = Seq( + ("user1", 1, 0.3, 0.5), + ("user2", 2, 0.1, 0.1), + ("user3", 3, 0.8, 0.0), + ("user4", 1, 0.9, 0.9), + ("user5", 3, 0.7, 0.2), + ("user6", 1, 0.5, 0.4), + ("user7", 2, 0.6, 0.8) + ).toDF("userId", "group", "x", "y") + + val masterDf = Seq( + (1, "pos1-1", 0.5, 0.1), + (1, "pos1-2", 0.0, 0.0), + (1, "pos1-3", 0.3, 0.3), + (2, "pos2-3", 0.1, 0.3), + (2, "pos2-3", 0.8, 0.8), + (3, "pos3-1", 0.1, 0.7), + (3, "pos3-1", 0.7, 0.1), + (3, "pos3-1", 0.9, 0.0), + (3, "pos3-1", 0.1, 0.3) + ).toDF("group", "position", "x", "y") + + // Compute top-1 rows for each group + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + val top1Df = inputDf.top_k_join( + lit(1), masterDf, inputDf("group") === masterDf("group"), distance) + assert(top1Df.schema.toSet === Set( + StructField("rank", IntegerType, nullable = true), + StructField("score", DoubleType, nullable = true), + StructField("group", IntegerType, nullable = false), + StructField("userId", StringType, nullable = true), + StructField("position", StringType, nullable = true), + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false) + )) + checkAnswer( + top1Df.select($"rank", inputDf("group"), $"userId", $"position"), + Row(1, 1, "user1", "pos1-2") :: + Row(1, 2, "user2", "pos2-3") :: + Row(1, 3, "user3", "pos3-1") :: + Row(1, 1, "user4", "pos1-2") :: + Row(1, 3, "user5", "pos3-1") :: + Row(1, 1, "user6", "pos1-2") :: + Row(1, 2, "user7", "pos2-3") :: + Nil + ) + } + } } /** http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala index 2f6f9b9..8da0776 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala @@ -18,12 +18,12 @@ */ package org.apache.spark.sql.hive.benchmark -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.benchmark.BenchmarkBase import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HivemallOps._ @@ -79,23 +79,14 @@ object TestFuncWrapper { @inline private def withExpr(expr: Expression): Column = Column(expr) } -class MiscBenchmark extends SparkFunSuite { - - lazy val sparkSession = SparkSession.builder - .master("local[1]") - .appName("microbenchmark") - .config("spark.sql.shuffle.partitions", 1) - .config("spark.sql.codegen.wholeStage", true) - .getOrCreate() +class MiscBenchmark extends BenchmarkBase { val numIters = 10 private def addBenchmarkCase(name: String, df: DataFrame)(implicit benchmark: Benchmark): Unit = { - // TODO: This query below failed in `each_top_k` - // benchmark.addCase(name, numIters) { - // _ => df.queryExecution.executedPlan.execute().foreach(x => {}) - // } - benchmark.addCase(name, numIters) { _ => df.count } + benchmark.addCase(name, numIters) { + _ => df.queryExecution.executedPlan.execute().foreach(x => {}) + } } TestUtils.benchmark("closure/exprs/spark-udf/hive-udf") { @@ -180,7 +171,7 @@ class MiscBenchmark extends SparkFunSuite { ) addBenchmarkCase( "each_top_k (exprs)", - testDf.each_top_k(lit(topK), $"x".as("score"), $"key") + testDf.each_top_k(lit(topK), $"x".as("score"), $"key".as("group")) ) benchmark.run() } @@ -229,7 +220,7 @@ class MiscBenchmark extends SparkFunSuite { addBenchmarkCase( "join + each_top_k", inputDf.join(masterDf, inputDf("group") === masterDf("group")) - .each_top_k(lit(topK), distance, inputDf("group")) + .each_top_k(lit(topK), distance, inputDf("group").as("group")) ) addBenchmarkCase( "top_k_join", @@ -237,4 +228,40 @@ class MiscBenchmark extends SparkFunSuite { ) benchmark.run() } + + TestUtils.benchmark("codegen top-k join") { + /** + * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + * + * top_k_join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * ----------------------------------------------------------------------------------- + * top_k_join wholestage off 3 / 5 2751.9 0.4 1.0X + * top_k_join wholestage on 1 / 1 6494.4 0.2 2.4X + */ + val topK = 3 + val N = 1L << 23 + val M = 1L << 22 + val numGroup = 3 + val inputDf = sparkSession.range(N).selectExpr( + s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y" + ).cache + val masterDf = sparkSession.range(M).selectExpr( + s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y" + ).cache + + // First, cache data + inputDf.count + masterDf.count + + // Define a score column + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ) + runBenchmark("top_k_join", N) { + inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"), + distance.as("score")) + } + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/52f05f43/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala index a4733f5..eef6e63 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala @@ -23,15 +23,15 @@ import scala.reflect.runtime.universe.TypeTag import hivemall.tools.RegressionDatagen -import org.apache.spark.sql.Column +import org.apache.spark.sql.{Column, QueryTest} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SQLTestUtils /** * Base class for tests with Hivemall features. */ -abstract class HivemallFeatureQueryTest extends QueryTest with TestHiveSingleton { +abstract class HivemallFeatureQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._
