Repository: incubator-hivemall Updated Branches: refs/heads/master b342dabfc -> f15379878
Close #113: [HIVEMALL-136][SPARK] Support train_classifier and train_regressor for Spark Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f1537987 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f1537987 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f1537987 Branch: refs/heads/master Commit: f15379878d25f587e06e0053344be024ec9b1049 Parents: b342dab Author: Takeshi Yamamuro <[email protected]> Authored: Wed Sep 13 21:22:49 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Sep 13 21:22:49 2017 +0900 ---------------------------------------------------------------------- .../org/apache/spark/sql/hive/HivemallOps.scala | 30 ++++++++++++++++++++ .../spark/sql/hive/HivemallOpsSuite.scala | 4 +++ 2 files changed, 34 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f1537987/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 9350a81..20850a7 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -64,6 +64,21 @@ final class HivemallOps(df: DataFrame) extends Logging { private[this] lazy val _strategy = new UserProvidedPlanner(_sparkSession.sqlContext.conf) /** + * @see [[hivemall.regression.GeneralRegressorUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_regressor(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.GeneralRegressorUDTF", + "train_regressor", + toHivemallFeatures(exprs), + Seq("feature", "weight") + ) + } + + /** * @see [[hivemall.regression.AdaDeltaUDTF]] * @group regression */ @@ -229,6 +244,21 @@ final class HivemallOps(df: DataFrame) extends Logging { } /** + * @see [[hivemall.classifier.GeneralClassifierUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.GeneralClassifierUDTF", + "train_classifier", + toHivemallFeatures(exprs), + Seq("feature", "weight") + ) + } + + /** * @see [[hivemall.classifier.PerceptronUDTF]] * @group classifier */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f1537987/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index f634f9b..835438d 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -566,6 +566,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { test("invoke regression functions") { import hiveContext.implicits._ Seq( + "train_regressor", "train_adadelta", "train_adagrad", "train_arow_regr", @@ -585,6 +586,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { test("invoke classifier functions") { import hiveContext.implicits._ Seq( + "train_classifier", "train_perceptron", "train_pa", "train_pa1", @@ -729,6 +731,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { ignore("check regression precision") { Seq( + "train_regressor", "train_adadelta", "train_adagrad", "train_arow_regr", @@ -746,6 +749,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { ignore("check classifier precision") { Seq( + "train_classifier", "train_perceptron", "train_pa", "train_pa1",
