Repository: incubator-hivemall Updated Branches: refs/heads/master f6c5a5f2c -> 1801a62c1 (forced update)
Close #27: [HIVEMALL-36] Refactor each_top_k Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1801a62c Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1801a62c Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1801a62c Branch: refs/heads/master Commit: 1801a62c15f4466c331456f6e2d7102b715d4a63 Parents: e86c8a0 Author: Takeshi YAMAMURO <[email protected]> Authored: Thu Jan 26 19:21:29 2017 +0900 Committer: Takeshi YAMAMURO <[email protected]> Committed: Thu Jan 26 19:24:52 2017 +0900 ---------------------------------------------------------------------- .../org/apache/spark/sql/hive/HivemallOps.scala | 45 ++++++++------------ .../spark/sql/hive/HivemallOpsSuite.scala | 32 ++++---------- 2 files changed, 26 insertions(+), 51 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1801a62c/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 8fa4831..e3e20ee 100644 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{EachTopK, Expression, Literal, NamedExpression, UserDefinedGenerator} import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.types._ @@ -56,6 +55,9 @@ import org.apache.spark.unsafe.types.UTF8String */ final class HivemallOps(df: DataFrame) extends Logging { + private[this] val _sparkSession = df.sparkSession + private[this] val _analyzer = _sparkSession.sessionState.analyzer + /** * @see hivemall.regression.AdaDeltaUDTF * @group regression @@ -788,37 +790,24 @@ final class HivemallOps(df: DataFrame) extends Logging { /** * Returns `top-k` records for each `group`. * @group misc - * @since 0.5.0 - */ - def each_top_k(k: Int, group: String, score: String, args: String*) - : DataFrame = withTypedPlan { - val clusterDf = df.repartition(df(group)).sortWithinPartitions(group) - val childrenAttributes = clusterDf.logicalPlan.output - val generator = Generate( - EachTopK( - k, - clusterDf.resolve(group), - clusterDf.resolve(score), - childrenAttributes - ), - join = false, outer = false, None, - (Seq("rank") ++ childrenAttributes.map(_.name)).map(UnresolvedAttribute(_)), - clusterDf.logicalPlan) - val attributes = generator.generatedSet - val projectList = (Seq("rank") ++ args).map(s => attributes.find(_.name == s).get) - Project(projectList, generator) - } - - @deprecated("use each_top_k(Int, String, String, String*) instead", "0.5.0") - def each_top_k(k: Column, group: Column, value: Column, args: Column*): DataFrame = { + */ + def each_top_k(k: Column, group: Column, score: Column): DataFrame = withTypedPlan { val kInt = k.expr match { case Literal(v: Any, IntegerType) => v.asInstanceOf[Int] case e => throw new AnalysisException("`k` must be integer, however " + e) } - val groupStr = usePrettyExpression(group.expr).sql - val valueStr = usePrettyExpression(value.expr).sql - val argStrs = args.map(c => usePrettyExpression(c.expr).sql) - each_top_k(kInt, groupStr, valueStr, argStrs: _*) + val clusterDf = df.repartition(group).sortWithinPartitions(group) + val child = clusterDf.logicalPlan + val logicalPlan = Project(group.named +: score.named +: child.output, child) + _analyzer.execute(logicalPlan) match { + case Project(group :: score :: origCols, c) => + Generate( + EachTopK(kInt, group, score, c.output), + join = false, outer = false, None, + (Seq("rank") ++ origCols.map(_.name)).map(UnresolvedAttribute(_)), + clusterDf.logicalPlan + ) + } } /** http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1801a62c/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index 61af8d1..49773cc 100644 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -305,40 +305,26 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { // Compute top-1 rows for each group checkAnswer( - testDf.each_top_k(1, "key", "score", "key", "value"), - Row(1, "a", "3") :: - Row(1, "b", "4") :: - Row(1, "c", "6") :: - Nil - ) - checkAnswer( - testDf.each_top_k(lit(1), $"key", $"score", $"key", $"value"), - Row(1, "a", "3") :: - Row(1, "b", "4") :: - Row(1, "c", "6") :: + testDf.each_top_k(lit(1), $"key", $"score"), + Row(1, "a", "3", 0.8, Array(2, 5)) :: + Row(1, "b", "4", 0.3, Array(2)) :: + Row(1, "c", "6", 0.3, Array(1, 3)) :: Nil ) // Compute reverse top-1 rows for each group checkAnswer( - testDf.each_top_k(-1, "key", "score", "key", "value"), - Row(1, "a", "1") :: - Row(1, "b", "5") :: - Row(1, "c", "6") :: - Nil - ) - checkAnswer( - testDf.each_top_k(lit(-1), $"key", $"score", $"key", $"value"), - Row(1, "a", "1") :: - Row(1, "b", "5") :: - Row(1, "c", "6") :: + testDf.each_top_k(lit(-1), $"key", $"score"), + Row(1, "a", "1", 0.5, Array(0, 1, 2)) :: + Row(1, "b", "5", 0.1, Array(3)) :: + Row(1, "c", "6", 0.3, Array(1, 3)) :: Nil ) // Check if some exceptions thrown in case of some conditions assert(intercept[AnalysisException] { testDf.each_top_k(lit(0.1), $"key", $"score") } .getMessage contains "`k` must be integer, however") - assert(intercept[AnalysisException] { testDf.each_top_k(1, "key", "data") } + assert(intercept[AnalysisException] { testDf.each_top_k(lit(1), $"key", $"data") } .getMessage contains "must have a comparable type") }
