Repository: incubator-hivemall Updated Branches: refs/heads/master 18ce75f78 -> b0d1ad029
Close #54: [HIVEMALL-76][SPARK] Fix wrong ranks in top-K funcs Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b0d1ad02 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b0d1ad02 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b0d1ad02 Branch: refs/heads/master Commit: b0d1ad0295584463ef94c58a62fdc88f7f373a51 Parents: 18ce75f Author: Takeshi Yamamuro <[email protected]> Authored: Thu Mar 2 23:48:04 2017 +0900 Committer: myui <[email protected]> Committed: Thu Mar 2 23:48:04 2017 +0900 ---------------------------------------------------------------------- .../sql/catalyst/expressions/EachTopK.scala | 27 +++++---- .../spark/sql/hive/HivemallOpsSuite.scala | 21 +++++++ .../sql/catalyst/expressions/EachTopK.scala | 24 +++++--- .../joins/ShuffledHashJoinTopKExec.scala | 22 +++++--- .../spark/sql/hive/HivemallOpsSuite.scala | 58 +++++++++++++++++++- 5 files changed, 122 insertions(+), 30 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala index 491363d..f1312ba 100644 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala +++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala @@ -77,16 +77,26 @@ case class EachTopK( children.map(d => StructField(d.prettyName, d.dataType)) ) + private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) 
= outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) => + new JoinedRow(InternalRow(index), row) + } + } else { + Seq.empty + } + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val groupingKey = groupingProjection(input) val ret = if (currentGroupingKey != groupingKey) { - val part = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) - .zipWithIndex.map { case ((_, row), index) => - new JoinedRow(InternalRow(1 + index), row) - } + val topKRows = topKRowsForGroup() currentGroupingKey = groupingKey.copy() queue.clear() - part + topKRows } else { Iterator.empty } @@ -96,12 +106,9 @@ case class EachTopK( override def terminate(): TraversableOnce[InternalRow] = { if (queue.size > 0) { - val part = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) - .zipWithIndex.map { case ((_, row), index) => - new JoinedRow(InternalRow(1 + index), row) - } + val topKRows = topKRowsForGroup() queue.clear() - part + topKRows } else { Iterator.empty } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index 1430f09..e9ccac8 100644 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -336,6 +336,27 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { .getMessage contains "must have a comparable type") } + test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having 
the same scores") { + import hiveContext.implicits._ + val testDf = Seq( + ("a", "1", 0.1), + ("b", "5", 0.1), + ("a", "3", 0.1), + ("b", "4", 0.1), + ("a", "2", 0.0) + ).toDF("key", "value", "score") + + // Compute top-1 rows for each group + checkAnswer( + testDf.each_top_k(lit(2), $"key", $"score"), + Row(1, "a", "3", 0.1) :: + Row(1, "a", "1", 0.1) :: + Row(1, "b", "4", 0.1) :: + Row(1, "b", "5", 0.1) :: + Nil + ) + } + /** * This test fails because; * http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala index 7acb107..6e53e66 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala @@ -83,18 +83,24 @@ case class EachTopK( } } - private def topKRowsForGroup(): Seq[InternalRow] = { + private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } val topKRow = new UnsafeRow(1) val bufferHolder = new BufferHolder(topKRow) val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) - queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) - .zipWithIndex.map { case ((_, row), index) => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, 1 + index) - topKRow.setTotalSize(bufferHolder.totalSize()) - new JoinedRow(topKRow, row) - } + outputRows.zip(rankNum.map(_._1)).map { case ((_, 
row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, index) + topKRow.setTotalSize(bufferHolder.totalSize()) + new JoinedRow(topKRow, row) + } + } else { + Seq.empty } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala index caad646..a799a07 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala @@ -76,19 +76,23 @@ case class ShuffledHashJoinTopKExec( } override def get(): Iterator[InternalRow] = { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } val topKRow = new UnsafeRow(2) val bufferHolder = new BufferHolder(topKRow) val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2) val scoreWriter = ScoreWriter(unsafeRowWriter, 1) - q.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering).zipWithIndex.map { - case ((score, row), index) => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, 1 + index) - scoreWriter.write(score) - topKRow.setTotalSize(bufferHolder.totalSize()) - joinedRow.apply(topKRow, row) - }.iterator + outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + 
unsafeRowWriter.write(0, index) + scoreWriter.write(score) + topKRow.setTotalSize(bufferHolder.totalSize()) + joinedRow.apply(topKRow, row) + }.iterator } override def clear(): Unit = q.clear() http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index 15879e0..ed56bc3 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -315,7 +315,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { // Compute top-1 rows for each group val distance = sqrt(inputDf("x") * inputDf("x") + inputDf("y") * inputDf("y")).as("score") - val top1Df = inputDf.each_top_k(lit(1), distance, $"key") + val top1Df = inputDf.each_top_k(lit(1), distance, $"key".as("group")) assert(top1Df.schema.toSet === Set( StructField("rank", IntegerType, nullable = true), StructField("score", DoubleType, nullable = true), @@ -334,7 +334,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { ) // Compute reverse top-1 rows for each group - val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key") + val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key".as("group")) checkAnswer( bottom1Df.select($"rank", $"key", $"value", $"data"), Row(1, "a", "1", Array(0, 1, 2)) :: @@ -407,6 +407,60 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { } } + test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having the same scores") { + import hiveContext.implicits._ + val inputDf = Seq( + ("a", "1", 0.1), + ("b", "5", 0.1), + ("a", "3", 0.1), + ("b", "4", 
0.1), + ("a", "2", 0.0) + ).toDF("key", "value", "x") + + // Compute top-2 rows for each group + val top2Df = inputDf.each_top_k(lit(2), $"x".as("score"), $"key".as("group")) + checkAnswer( + top2Df.select($"rank", $"score", $"key", $"value"), + Row(1, 0.1, "a", "3") :: + Row(1, 0.1, "a", "1") :: + Row(1, 0.1, "b", "4") :: + Row(1, 0.1, "b", "5") :: + Nil + ) + Seq("true", "false").map { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + val inputDf = Seq( + ("user1", 1, 0.3, 0.5), + ("user2", 2, 0.1, 0.1) + ).toDF("userId", "group", "x", "y") + + val masterDf = Seq( + (1, "pos1-1", 0.5, 0.1), + (1, "pos1-2", 0.5, 0.1), + (1, "pos1-3", 0.3, 0.4), + (2, "pos2-1", 0.8, 0.2), + (2, "pos2-2", 0.8, 0.2) + ).toDF("group", "position", "x", "y") + + // Compute top-2 rows for each group + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + val top2Df = inputDf.top_k_join( + lit(2), masterDf, inputDf("group") === masterDf("group"), distance) + checkAnswer( + top2Df.select($"rank", inputDf("group"), $"userId", $"position"), + Row(1, 1, "user1", "pos1-1") :: + Row(1, 1, "user1", "pos1-2") :: + Row(1, 2, "user2", "pos2-1") :: + Row(1, 2, "user2", "pos2-2") :: + Nil + ) + } + } + } + /** * This test fails because; *
