Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 1801a62c1 -> 70f42038a


Close #31: [HIVEMALL-40][SPARK] Load xgboost-formatted data via Java ServiceLoader


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/70f42038
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/70f42038
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/70f42038

Branch: refs/heads/master
Commit: 70f42038a7b7f4c1d358d46ef30300a29dfbcef6
Parents: 1801a62
Author: Takeshi YAMAMURO <[email protected]>
Authored: Fri Jan 27 09:03:53 2017 +0900
Committer: Takeshi YAMAMURO <[email protected]>
Committed: Fri Jan 27 09:03:53 2017 +0900

----------------------------------------------------------------------
 .../main/java/hivemall/xgboost/package.scala    | 32 --------------------
 ....apache.spark.sql.sources.DataSourceRegister |  1 +
 .../apache/spark/sql/hive/XGBoostSuite.scala    | 21 +++++++++----
 3 files changed, 16 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/70f42038/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala b/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala
deleted file mode 100644
index 2624412..0000000
--- a/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall
-
-import org.apache.spark.sql.hive.source.XGBoostFileFormat
-
-package object xgboost {
-
-  /**
-   * Model files for libxgboost are loaded as follows;
-   *
-   * import HivemallOps._
-   * val modelDf = sparkSession.read.format(xgboostFormat).load(modelDir.getCanonicalPath)
-   */
-  val xgboost = classOf[XGBoostFileFormat].getName
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/70f42038/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000..b49e20a
--- /dev/null
+++ b/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1 @@
+org.apache.spark.sql.hive.source.XGBoostFileFormat

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/70f42038/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
index 7c78678..8c9c0c3 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
@@ -23,6 +23,7 @@ import java.io.File
 import hivemall.xgboost._
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.HivemallGroupedDataset._
 import org.apache.spark.sql.hive.HivemallOps._
@@ -41,6 +42,14 @@ final class XGBoostSuite extends VectorQueryTest {
   private def countModels(dirPath: String): Int = {
     new File(dirPath).listFiles().toSeq.count(_.getName.startsWith("xgbmodel-"))
   }
+
+  test("resolve libxgboost") {
+    def getProvidingClass(name: String): Class[_] =
+      DataSource(sparkSession = null, className = name).providingClass
+    assert(getProvidingClass("libxgboost") ===
+      classOf[org.apache.spark.sql.hive.source.XGBoostFileFormat])
+  }
+
   test("check XGBoost options") {
     assert(s"$defaultOptions" == "-max_depth 4 -num_round 10")
     val errMsg = intercept[IllegalArgumentException] {
@@ -56,13 +65,13 @@ final class XGBoostSuite extends VectorQueryTest {
       // Save built models in persistent storage
       mllibTrainDf.repartition(numModles)
         .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}"))
-        .write.format(xgboost).save(tempDir)
+        .write.format("libxgboost").save(tempDir)
 
       // Check #models generated by XGBoost
       assert(countModels(tempDir) == numModles)
 
       // Load the saved models
-      val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir)
+      val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
       val predict = model.join(mllibTestDf)
         .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
         .groupBy("rowid").avg()
@@ -82,12 +91,12 @@ final class XGBoostSuite extends VectorQueryTest {
 
       mllibTrainDf.repartition(numModles)
         .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}"))
-        .write.format(xgboost).save(tempDir)
+        .write.format("libxgboost").save(tempDir)
 
       // Check #models generated by XGBoost
       assert(countModels(tempDir) == numModles)
 
-      val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir)
+      val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
       val predict = model.join(mllibTestDf)
         .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
         .groupBy("rowid").avg()
@@ -110,12 +119,12 @@ final class XGBoostSuite extends VectorQueryTest {
       mllibTrainDf.repartition(numModles)
         .train_xgboost_multiclass_classifier(
           $"features", $"label", lit(s"${defaultOptions.set("num_class", "2")}"))
-        .write.format(xgboost).save(tempDir)
+        .write.format("libxgboost").save(tempDir)
 
       // Check #models generated by XGBoost
       assert(countModels(tempDir) == numModles)
 
-      val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir)
+      val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
       val predict = model.join(mllibTestDf)
         .xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model")
         .groupBy("rowid").max_label("probability", "label")

Reply via email to