Repository: incubator-hivemall
Updated Branches:
  refs/heads/master f6c5a5f2c -> 1801a62c1 (forced update)


Close #27: [HIVEMALL-36] Refactor each_top_k


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1801a62c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1801a62c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1801a62c

Branch: refs/heads/master
Commit: 1801a62c15f4466c331456f6e2d7102b715d4a63
Parents: e86c8a0
Author: Takeshi YAMAMURO <[email protected]>
Authored: Thu Jan 26 19:21:29 2017 +0900
Committer: Takeshi YAMAMURO <[email protected]>
Committed: Thu Jan 26 19:24:52 2017 +0900

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/HivemallOps.scala | 45 ++++++++------------
 .../spark/sql/hive/HivemallOpsSuite.scala       | 32 ++++----------
 2 files changed, 26 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1801a62c/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala 
b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 8fa4831..e3e20ee 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -30,7 +30,6 @@ import 
org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{EachTopK, Expression, 
Literal, NamedExpression, UserDefinedGenerator}
 import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, 
Project}
-import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
 import org.apache.spark.sql.types._
@@ -56,6 +55,9 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 final class HivemallOps(df: DataFrame) extends Logging {
 
+  private[this] val _sparkSession = df.sparkSession
+  private[this] val _analyzer = _sparkSession.sessionState.analyzer
+
   /**
    * @see hivemall.regression.AdaDeltaUDTF
    * @group regression
@@ -788,37 +790,24 @@ final class HivemallOps(df: DataFrame) extends Logging {
   /**
    * Returns `top-k` records for each `group`.
    * @group misc
-   * @since 0.5.0
-   */
-  def each_top_k(k: Int, group: String, score: String, args: String*)
-    : DataFrame = withTypedPlan {
-    val clusterDf = df.repartition(df(group)).sortWithinPartitions(group)
-    val childrenAttributes = clusterDf.logicalPlan.output
-    val generator = Generate(
-      EachTopK(
-        k,
-        clusterDf.resolve(group),
-        clusterDf.resolve(score),
-        childrenAttributes
-      ),
-      join = false, outer = false, None,
-      (Seq("rank") ++ 
childrenAttributes.map(_.name)).map(UnresolvedAttribute(_)),
-      clusterDf.logicalPlan)
-    val attributes = generator.generatedSet
-    val projectList = (Seq("rank") ++ args).map(s => attributes.find(_.name == 
s).get)
-    Project(projectList, generator)
-  }
-
-  @deprecated("use each_top_k(Int, String, String, String*) instead", "0.5.0")
-  def each_top_k(k: Column, group: Column, value: Column, args: Column*): 
DataFrame = {
+   */
+  def each_top_k(k: Column, group: Column, score: Column): DataFrame = 
withTypedPlan {
     val kInt = k.expr match {
       case Literal(v: Any, IntegerType) => v.asInstanceOf[Int]
       case e => throw new AnalysisException("`k` must be integer, however " + 
e)
     }
-    val groupStr = usePrettyExpression(group.expr).sql
-    val valueStr = usePrettyExpression(value.expr).sql
-    val argStrs = args.map(c => usePrettyExpression(c.expr).sql)
-    each_top_k(kInt, groupStr, valueStr, argStrs: _*)
+    val clusterDf = df.repartition(group).sortWithinPartitions(group)
+    val child = clusterDf.logicalPlan
+    val logicalPlan = Project(group.named +: score.named +: child.output, 
child)
+    _analyzer.execute(logicalPlan) match {
+      case Project(group :: score :: origCols, c) =>
+        Generate(
+          EachTopK(kInt, group, score, c.output),
+          join = false, outer = false, None,
+          (Seq("rank") ++ origCols.map(_.name)).map(UnresolvedAttribute(_)),
+          clusterDf.logicalPlan
+        )
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1801a62c/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
 
b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 61af8d1..49773cc 100644
--- 
a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ 
b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -305,40 +305,26 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
 
     // Compute top-1 rows for each group
     checkAnswer(
-      testDf.each_top_k(1, "key", "score", "key", "value"),
-      Row(1, "a", "3") ::
-      Row(1, "b", "4") ::
-      Row(1, "c", "6") ::
-      Nil
-    )
-    checkAnswer(
-      testDf.each_top_k(lit(1), $"key", $"score", $"key", $"value"),
-      Row(1, "a", "3") ::
-      Row(1, "b", "4") ::
-      Row(1, "c", "6") ::
+      testDf.each_top_k(lit(1), $"key", $"score"),
+      Row(1, "a", "3", 0.8, Array(2, 5)) ::
+      Row(1, "b", "4", 0.3, Array(2)) ::
+      Row(1, "c", "6", 0.3, Array(1, 3)) ::
       Nil
     )
 
     // Compute reverse top-1 rows for each group
     checkAnswer(
-      testDf.each_top_k(-1, "key", "score", "key", "value"),
-      Row(1, "a", "1") ::
-      Row(1, "b", "5") ::
-      Row(1, "c", "6") ::
-      Nil
-    )
-    checkAnswer(
-      testDf.each_top_k(lit(-1), $"key", $"score", $"key", $"value"),
-      Row(1, "a", "1") ::
-      Row(1, "b", "5") ::
-      Row(1, "c", "6") ::
+      testDf.each_top_k(lit(-1), $"key", $"score"),
+      Row(1, "a", "1", 0.5, Array(0, 1, 2)) ::
+      Row(1, "b", "5", 0.1, Array(3)) ::
+      Row(1, "c", "6", 0.3, Array(1, 3)) ::
       Nil
     )
 
     // Check if some exceptions thrown in case of some conditions
     assert(intercept[AnalysisException] { testDf.each_top_k(lit(0.1), $"key", 
$"score") }
       .getMessage contains "`k` must be integer, however")
-    assert(intercept[AnalysisException] { testDf.each_top_k(1, "key", "data") }
+    assert(intercept[AnalysisException] { testDf.each_top_k(lit(1), $"key", 
$"data") }
       .getMessage contains "must have a comparable type")
   }
 

Reply via email to