Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 18ce75f78 -> b0d1ad029


Close #54: [HIVEMALL-76][SPARK] Fix wrong ranks in top-K funcs


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b0d1ad02
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b0d1ad02
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b0d1ad02

Branch: refs/heads/master
Commit: b0d1ad0295584463ef94c58a62fdc88f7f373a51
Parents: 18ce75f
Author: Takeshi Yamamuro <[email protected]>
Authored: Thu Mar 2 23:48:04 2017 +0900
Committer: myui <[email protected]>
Committed: Thu Mar 2 23:48:04 2017 +0900

----------------------------------------------------------------------
 .../sql/catalyst/expressions/EachTopK.scala     | 27 +++++----
 .../spark/sql/hive/HivemallOpsSuite.scala       | 21 +++++++
 .../sql/catalyst/expressions/EachTopK.scala     | 24 +++++---
 .../joins/ShuffledHashJoinTopKExec.scala        | 22 +++++---
 .../spark/sql/hive/HivemallOpsSuite.scala       | 58 +++++++++++++++++++-
 5 files changed, 122 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
 
b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
index 491363d..f1312ba 100644
--- 
a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
+++ 
b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
@@ -77,16 +77,26 @@ case class EachTopK(
         children.map(d => StructField(d.prettyName, d.dataType))
     )
 
+  private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) {
+    val outputRows = queue.iterator.toSeq.reverse
+    val (headScore, _) = outputRows.head
+    val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, 
prevScore), (score, _)) =>
+      if (prevScore == score) (rank, score) else (rank + 1, score)
+    }
+    outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) =>
+      new JoinedRow(InternalRow(index), row)
+    }
+  } else {
+    Seq.empty
+  }
+
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
     val groupingKey = groupingProjection(input)
     val ret = if (currentGroupingKey != groupingKey) {
-      val part = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
-          .zipWithIndex.map { case ((_, row), index) =>
-        new JoinedRow(InternalRow(1 + index), row)
-      }
+      val topKRows = topKRowsForGroup()
       currentGroupingKey = groupingKey.copy()
       queue.clear()
-      part
+      topKRows
     } else {
       Iterator.empty
     }
@@ -96,12 +106,9 @@ case class EachTopK(
 
   override def terminate(): TraversableOnce[InternalRow] = {
     if (queue.size > 0) {
-      val part = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
-          .zipWithIndex.map { case ((_, row), index) =>
-        new JoinedRow(InternalRow(1 + index), row)
-      }
+      val topKRows = topKRowsForGroup()
       queue.clear()
-      part
+      topKRows
     } else {
       Iterator.empty
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
 
b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 1430f09..e9ccac8 100644
--- 
a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ 
b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -336,6 +336,27 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
       .getMessage contains "must have a comparable type")
   }
 
+  test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having 
the same scores") {
+    import hiveContext.implicits._
+    val testDf = Seq(
+      ("a", "1", 0.1),
+      ("b", "5", 0.1),
+      ("a", "3", 0.1),
+      ("b", "4", 0.1),
+      ("a", "2", 0.0)
+    ).toDF("key", "value", "score")
+
+    // Compute top-2 rows for each group
+    checkAnswer(
+      testDf.each_top_k(lit(2), $"key", $"score"),
+      Row(1, "a", "3", 0.1) ::
+      Row(1, "a", "1", 0.1) ::
+      Row(1, "b", "4", 0.1) ::
+      Row(1, "b", "5", 0.1) ::
+      Nil
+    )
+  }
+
   /**
    * This test fails because;
    *

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
index 7acb107..6e53e66 100644
--- 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
+++ 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
@@ -83,18 +83,24 @@ case class EachTopK(
     }
   }
 
-  private def topKRowsForGroup(): Seq[InternalRow] = {
+  private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) {
+    val outputRows = queue.iterator.toSeq.reverse
+    val (headScore, _) = outputRows.head
+    val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, 
prevScore), (score, _)) =>
+      if (prevScore == score) (rank, score) else (rank + 1, score)
+    }
     val topKRow = new UnsafeRow(1)
     val bufferHolder = new BufferHolder(topKRow)
     val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
-    queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
-      .zipWithIndex.map { case ((_, row), index) =>
-        // Writes to an UnsafeRow directly
-        bufferHolder.reset()
-        unsafeRowWriter.write(0, 1 + index)
-        topKRow.setTotalSize(bufferHolder.totalSize())
-        new JoinedRow(topKRow, row)
-      }
+    outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) =>
+      // Writes to an UnsafeRow directly
+      bufferHolder.reset()
+      unsafeRowWriter.write(0, index)
+      topKRow.setTotalSize(bufferHolder.totalSize())
+      new JoinedRow(topKRow, row)
+    }
+  } else {
+    Seq.empty
   }
 
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
index caad646..a799a07 100644
--- 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
+++ 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
@@ -76,19 +76,23 @@ case class ShuffledHashJoinTopKExec(
     }
 
     override def get(): Iterator[InternalRow] = {
+      val outputRows = queue.iterator.toSeq.reverse
+      val (headScore, _) = outputRows.head
+      val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, 
prevScore), (score, _)) =>
+        if (prevScore == score) (rank, score) else (rank + 1, score)
+      }
       val topKRow = new UnsafeRow(2)
       val bufferHolder = new BufferHolder(topKRow)
       val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2)
       val scoreWriter = ScoreWriter(unsafeRowWriter, 1)
-      q.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering).zipWithIndex.map {
-        case ((score, row), index) =>
-          // Writes to an UnsafeRow directly
-          bufferHolder.reset()
-          unsafeRowWriter.write(0, 1 + index)
-          scoreWriter.write(score)
-          topKRow.setTotalSize(bufferHolder.totalSize())
-          joinedRow.apply(topKRow, row)
-        }.iterator
+      outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) =>
+        // Writes to an UnsafeRow directly
+        bufferHolder.reset()
+        unsafeRowWriter.write(0, index)
+        scoreWriter.write(score)
+        topKRow.setTotalSize(bufferHolder.totalSize())
+        joinedRow.apply(topKRow, row)
+      }.iterator
     }
 
     override def clear(): Unit = q.clear()

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0d1ad02/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 15879e0..ed56bc3 100644
--- 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -315,7 +315,7 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
 
     // Compute top-1 rows for each group
     val distance = sqrt(inputDf("x") * inputDf("x") + inputDf("y") * 
inputDf("y")).as("score")
-    val top1Df = inputDf.each_top_k(lit(1), distance, $"key")
+    val top1Df = inputDf.each_top_k(lit(1), distance, $"key".as("group"))
     assert(top1Df.schema.toSet === Set(
       StructField("rank", IntegerType, nullable = true),
       StructField("score", DoubleType, nullable = true),
@@ -334,7 +334,7 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
     )
 
     // Compute reverse top-1 rows for each group
-    val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key")
+    val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key".as("group"))
     checkAnswer(
       bottom1Df.select($"rank", $"key", $"value", $"data"),
       Row(1, "a", "1", Array(0, 1, 2)) ::
@@ -407,6 +407,60 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
     }
   }
 
+  test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having 
the same scores") {
+    import hiveContext.implicits._
+    val inputDf = Seq(
+      ("a", "1", 0.1),
+      ("b", "5", 0.1),
+      ("a", "3", 0.1),
+      ("b", "4", 0.1),
+      ("a", "2", 0.0)
+    ).toDF("key", "value", "x")
+
+    // Compute top-2 rows for each group
+    val top2Df = inputDf.each_top_k(lit(2), $"x".as("score"), 
$"key".as("group"))
+    checkAnswer(
+      top2Df.select($"rank", $"score", $"key", $"value"),
+      Row(1, 0.1, "a", "3") ::
+      Row(1, 0.1, "a", "1") ::
+      Row(1, 0.1, "b", "4") ::
+      Row(1, 0.1, "b", "5") ::
+      Nil
+    )
+    Seq("true", "false").map { flag =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) {
+        val inputDf = Seq(
+          ("user1", 1, 0.3, 0.5),
+          ("user2", 2, 0.1, 0.1)
+        ).toDF("userId", "group", "x", "y")
+
+        val masterDf = Seq(
+          (1, "pos1-1", 0.5, 0.1),
+          (1, "pos1-2", 0.5, 0.1),
+          (1, "pos1-3", 0.3, 0.4),
+          (2, "pos2-1", 0.8, 0.2),
+          (2, "pos2-2", 0.8, 0.2)
+        ).toDF("group", "position", "x", "y")
+
+        // Compute top-2 rows for each group
+        val distance = sqrt(
+          pow(inputDf("x") - masterDf("x"), lit(2.0)) +
+            pow(inputDf("y") - masterDf("y"), lit(2.0))
+        ).as("score")
+        val top2Df = inputDf.top_k_join(
+          lit(2), masterDf, inputDf("group") === masterDf("group"), distance)
+        checkAnswer(
+          top2Df.select($"rank", inputDf("group"), $"userId", $"position"),
+          Row(1, 1, "user1", "pos1-1") ::
+          Row(1, 1, "user1", "pos1-2") ::
+          Row(1, 2, "user2", "pos2-1") ::
+          Row(1, 2, "user2", "pos2-2") ::
+          Nil
+        )
+      }
+    }
+  }
+
   /**
    * This test fails because;
    *

Reply via email to