Repository: incubator-hivemall Updated Branches: refs/heads/master 1801a62c1 -> 70f42038a
Close #31: [HIVEMALL-40][SPARK] Load xgboost-formatted data via Java ServiceLoader Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/70f42038 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/70f42038 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/70f42038 Branch: refs/heads/master Commit: 70f42038a7b7f4c1d358d46ef30300a29dfbcef6 Parents: 1801a62 Author: Takeshi YAMAMURO <[email protected]> Authored: Fri Jan 27 09:03:53 2017 +0900 Committer: Takeshi YAMAMURO <[email protected]> Committed: Fri Jan 27 09:03:53 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/xgboost/package.scala | 32 -------------------- ....apache.spark.sql.sources.DataSourceRegister | 1 + .../apache/spark/sql/hive/XGBoostSuite.scala | 21 +++++++++---- 3 files changed, 16 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/70f42038/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala b/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala deleted file mode 100644 index 2624412..0000000 --- a/spark/spark-2.0/src/main/java/hivemall/xgboost/package.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall - -import org.apache.spark.sql.hive.source.XGBoostFileFormat - -package object xgboost { - - /** - * Model files for libxgboost are loaded as follows; - * - * import HivemallOps._ - * val modelDf = sparkSession.read.format(xgboostFormat).load(modelDir.getCanonicalPath) - */ - val xgboost = classOf[XGBoostFileFormat].getName -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/70f42038/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..b49e20a --- /dev/null +++ b/spark/spark-2.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.source.XGBoostFileFormat http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/70f42038/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala index 7c78678..8c9c0c3 100644 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala +++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala @@ -23,6 +23,7 @@ import java.io.File import hivemall.xgboost._ import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HivemallGroupedDataset._ import org.apache.spark.sql.hive.HivemallOps._ @@ -41,6 +42,14 @@ final class XGBoostSuite extends VectorQueryTest { private def countModels(dirPath: String): Int = { new File(dirPath).listFiles().toSeq.count(_.getName.startsWith("xgbmodel-")) } + + test("resolve libxgboost") { + def getProvidingClass(name: String): Class[_] = + DataSource(sparkSession = null, className = name).providingClass + assert(getProvidingClass("libxgboost") === + classOf[org.apache.spark.sql.hive.source.XGBoostFileFormat]) + } + test("check XGBoost options") { assert(s"$defaultOptions" == "-max_depth 4 -num_round 10") val errMsg = intercept[IllegalArgumentException] { @@ -56,13 +65,13 @@ final class XGBoostSuite extends VectorQueryTest { // Save built models in persistent storage mllibTrainDf.repartition(numModles) .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}")) - .write.format(xgboost).save(tempDir) + .write.format("libxgboost").save(tempDir) // Check #models generated by XGBoost assert(countModels(tempDir) == numModles) // Load the saved models - val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir) + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) val predict = model.join(mllibTestDf) .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model") .groupBy("rowid").avg() @@ -82,12 +91,12 @@ final class XGBoostSuite extends VectorQueryTest { mllibTrainDf.repartition(numModles) .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}")) - .write.format(xgboost).save(tempDir) + .write.format("libxgboost").save(tempDir) // Check #models generated by XGBoost assert(countModels(tempDir) == numModles) - val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir) + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) val predict = model.join(mllibTestDf) .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model") .groupBy("rowid").avg() @@ -110,12 +119,12 @@ final class XGBoostSuite extends VectorQueryTest { mllibTrainDf.repartition(numModles) .train_xgboost_multiclass_classifier( $"features", $"label", lit(s"${defaultOptions.set("num_class", "2")}")) - .write.format(xgboost).save(tempDir) + .write.format("libxgboost").save(tempDir) // Check #models generated by XGBoost assert(countModels(tempDir) == numModles) - val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir) + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) val predict = model.join(mllibTestDf) .xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model") .groupBy("rowid").max_label("probability", "label")
