http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala new file mode 100644 index 0000000..81390e0 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.hive + +import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.test.VectorQueryTest + +final class HiveUdfWithFeatureSuite extends HivemallFeatureQueryTest { + import hiveContext.implicits._ + import hiveContext._ + + test("hivemall_version") { + sql(s""" + | CREATE TEMPORARY FUNCTION hivemall_version + | AS '${classOf[hivemall.HivemallVersionUDF].getName}' + """.stripMargin) + + checkAnswer( + sql(s"SELECT DISTINCT hivemall_version()"), + Row("0.4.2-rc.2") + ) + + // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version") + // reset() + } + + test("train_logregr") { + TinyTrainData.createOrReplaceTempView("TinyTrainData") + sql(s""" + | CREATE TEMPORARY FUNCTION train_logregr + | AS '${classOf[hivemall.regression.LogressUDTF].getName}' + """.stripMargin) + sql(s""" + | CREATE TEMPORARY FUNCTION add_bias + | AS '${classOf[hivemall.ftvec.AddBiasUDFWrapper].getName}' + """.stripMargin) + + val model = sql( + s""" + | SELECT feature, AVG(weight) AS weight + | FROM ( + | SELECT train_logregr(add_bias(features), label) AS (feature, weight) + | FROM TinyTrainData + | ) t + | GROUP BY feature + """.stripMargin) + + checkAnswer( + model.select($"feature"), + Seq(Row("0"), Row("1"), Row("2")) + ) + + // TODO: Why 'train_logregr' is not registered in HiveMetaStore? 
+ // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException + // (message:Function default.train_logregr does not exist)) + // + // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr") + // hiveContext.reset() + } + + test("each_top_k") { + val testDf = Seq( + ("a", "1", 0.5, Array(0, 1, 2)), + ("b", "5", 0.1, Array(3)), + ("a", "3", 0.8, Array(2, 5)), + ("c", "6", 0.3, Array(1, 3)), + ("b", "4", 0.3, Array(2)), + ("a", "2", 0.6, Array(1)) + ).toDF("key", "value", "score", "data") + + import testDf.sqlContext.implicits._ + testDf.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("TestData") + sql(s""" + | CREATE TEMPORARY FUNCTION each_top_k + | AS '${classOf[hivemall.tools.EachTopKUDTF].getName}' + """.stripMargin) + + // Compute top-1 rows for each group + checkAnswer( + sql("SELECT each_top_k(1, key, score, key, value) FROM TestData"), + Row(1, 0.8, "a", "3") :: + Row(1, 0.3, "b", "4") :: + Row(1, 0.3, "c", "6") :: + Nil + ) + + // Compute reverse top-1 rows for each group + checkAnswer( + sql("SELECT each_top_k(-1, key, score, key, value) FROM TestData"), + Row(1, 0.5, "a", "1") :: + Row(1, 0.1, "b", "5") :: + Row(1, 0.3, "c", "6") :: + Nil + ) + } +} + +final class HiveUdfWithVectorSuite extends VectorQueryTest { + import hiveContext._ + + test("to_hivemall_features") { + mllibTrainDf.createOrReplaceTempView("mllibTrainDf") + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) + checkAnswer( + sql( + s""" + | SELECT to_hivemall_features(features) + | FROM mllibTrainDf + """.stripMargin), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + test("append_bias") { + mllibTrainDf.createOrReplaceTempView("mllibTrainDf") + hiveContext.udf.register("append_bias", append_bias_func) + hiveContext.udf.register("to_hivemall_features", 
to_hivemall_features_func) + checkAnswer( + sql( + s""" + | SELECT to_hivemall_features(append_bias(features)) + | FROM mllibTrainDF + """.stripMargin), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0")) + ) + ) + } + + ignore("explode_vector") { + // TODO: Spark-2.0 does not support user-defined generator function in + // `org.apache.spark.sql.UDFRegistration`. + } +}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala new file mode 100644 index 0000000..f65b451 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -0,0 +1,784 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.test.VectorQueryTest +import org.apache.spark.sql.types._ +import org.apache.spark.test.TestFPWrapper._ +import org.apache.spark.test.TestUtils + +final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { + + test("anomaly") { + import hiveContext.implicits._ + val df = spark.range(1000).selectExpr("id AS time", "rand() AS x") + // TODO: Test results more exactly + assert(df.sort($"time".asc).select(sst($"x", lit("-th 0.005"))).count === 1000) + } + + test("knn.similarity") { + val df1 = DummyInputData.select(cosine_sim(lit2(Seq(1, 2, 3, 4)), lit2(Seq(3, 4, 5, 6)))) + assert(df1.collect.apply(0).getFloat(0) ~== 0.500f) + + val df2 = DummyInputData.select(jaccard(lit(5), lit(6))) + assert(df2.collect.apply(0).getFloat(0) ~== 0.96875f) + + val df3 = DummyInputData.select(angular_similarity(lit2(Seq(1, 2, 3)), lit2(Seq(4, 5, 6)))) + assert(df3.collect.apply(0).getFloat(0) ~== 0.500f) + + val df4 = DummyInputData.select(euclid_similarity(lit2(Seq(5, 3, 1)), lit2(Seq(2, 8, 3)))) + assert(df4.collect.apply(0).getFloat(0) ~== 0.33333334f) + + val df5 = DummyInputData.select(distance2similarity(lit(1.0))) + assert(df5.collect.apply(0).getFloat(0) ~== 0.5f) + } + + test("knn.distance") { + val df1 = DummyInputData.select(hamming_distance(lit(1), lit(3))) + checkAnswer(df1, Row(1) :: Nil) + + val df2 = DummyInputData.select(popcnt(lit(1))) + checkAnswer(df2, Row(1) :: Nil) + + val df3 = DummyInputData.select(kld(lit(0.1), lit(0.5), lit(0.2), lit(0.5))) + assert(df3.collect.apply(0).getDouble(0) ~== 0.01) + + val df4 = DummyInputData.select( + euclid_distance(lit2(Seq("0.1", 
"0.5")), lit2(Seq("0.2", "0.5")))) + assert(df4.collect.apply(0).getFloat(0) ~== 1.4142135f) + + val df5 = DummyInputData.select( + cosine_distance(lit2(Seq("0.8", "0.3")), lit2(Seq("0.4", "0.6")))) + assert(df5.collect.apply(0).getFloat(0) ~== 1.0f) + + val df6 = DummyInputData.select( + angular_distance(lit2(Seq("0.1", "0.1")), lit2(Seq("0.3", "0.8")))) + assert(df6.collect.apply(0).getFloat(0) ~== 0.50f) + + val df7 = DummyInputData.select( + manhattan_distance(lit2(Seq("0.7", "0.8")), lit2(Seq("0.5", "0.6")))) + assert(df7.collect.apply(0).getFloat(0) ~== 4.0f) + + val df8 = DummyInputData.select( + minkowski_distance(lit2(Seq("0.1", "0.2")), lit2(Seq("0.2", "0.2")), lit2(1.0))) + assert(df8.collect.apply(0).getFloat(0) ~== 2.0f) + } + + test("knn.lsh") { + import hiveContext.implicits._ + assert(IntList2Data.minhash(lit(1), $"target").count() > 0) + + assert(DummyInputData.select(bbit_minhash(lit2(Seq("1:0.1", "2:0.5")), lit(false))).count + == DummyInputData.count) + assert(DummyInputData.select(minhashes(lit2(Seq("1:0.1", "2:0.5")), lit(false))).count + == DummyInputData.count) + } + + test("ftvec - add_bias") { + import hiveContext.implicits._ + checkAnswer(TinyTrainData.select(add_bias($"features")), + Row(Seq("1:0.8", "2:0.2", "0:1.0")) :: + Row(Seq("2:0.7", "0:1.0")) :: + Row(Seq("1:0.9", "0:1.0")) :: + Nil + ) + } + + test("ftvec - extract_feature") { + val df = DummyInputData.select(extract_feature(lit("1:0.8"))) + checkAnswer(df, Row("1") :: Nil) + } + + test("ftvec - extract_weight") { + val df = DummyInputData.select(extract_weight(lit("3:0.1"))) + assert(df.collect.apply(0).getDouble(0) ~== 0.1) + } + + test("ftvec - explode_array") { + import hiveContext.implicits._ + val df = TinyTrainData.explode_array($"features").select($"feature") + checkAnswer(df, Row("1:0.8") :: Row("2:0.2") :: Row("2:0.7") :: Row("1:0.9") :: Nil) + } + + test("ftvec - add_feature_index") { + import hiveContext.implicits._ + val doubleListData = Seq(Array(0.8, 0.5), 
Array(0.3, 0.1), Array(0.2)).toDF("data") + checkAnswer( + doubleListData.select(add_feature_index($"data")), + Row(Seq("1:0.8", "2:0.5")) :: + Row(Seq("1:0.3", "2:0.1")) :: + Row(Seq("1:0.2")) :: + Nil + ) + } + + test("ftvec - sort_by_feature") { + // import hiveContext.implicits._ + val intFloatMapData = { + // TODO: Use `toDF` + val rowRdd = hiveContext.sparkContext.parallelize( + Row(Map(1 -> 0.3f, 2 -> 0.1f, 3 -> 0.5f)) :: + Row(Map(2 -> 0.4f, 1 -> 0.2f)) :: + Row(Map(2 -> 0.4f, 3 -> 0.2f, 1 -> 0.1f, 4 -> 0.6f)) :: + Nil + ) + hiveContext.createDataFrame( + rowRdd, + StructType( + StructField("data", MapType(IntegerType, FloatType), true) :: + Nil) + ) + } + val sortedKeys = intFloatMapData.select(sort_by_feature(intFloatMapData.col("data"))) + .collect.map { + case Row(m: Map[Int, Float]) => m.keysIterator.toSeq + } + assert(sortedKeys.toSet === Set(Seq(1, 2, 3), Seq(1, 2), Seq(1, 2, 3, 4))) + } + + test("ftvec.hash") { + assert(DummyInputData.select(mhash(lit("test"))).count == DummyInputData.count) + assert(DummyInputData.select(org.apache.spark.sql.hive.HivemallOps.sha1(lit("test"))).count == + DummyInputData.count) + // TODO: The tests below failed because: + // org.apache.spark.sql.AnalysisException: List type in java is unsupported because JVM type + // erasure makes spark fail to catch a component type in List<>; + // + // assert(DummyInputData.select(array_hash_values(lit2(Seq("aaa", "bbb")))).count + // == DummyInputData.count) + // assert(DummyInputData.select( + // prefixed_hash_values(lit2(Seq("ccc", "ddd")), lit("prefix"))).count + // == DummyInputData.count) + } + + test("ftvec.scaling") { + val df1 = TinyTrainData.select(rescale(lit(2.0f), lit(1.0), lit(5.0))) + assert(df1.collect.apply(0).getFloat(0) === 0.25f) + val df2 = TinyTrainData.select(zscore(lit(1.0f), lit(0.5), lit(0.5))) + assert(df2.collect.apply(0).getFloat(0) === 1.0f) + val df3 = TinyTrainData.select(normalize(TinyTrainData.col("features"))) + checkAnswer( + df3, + 
Row(Seq("1:0.9701425", "2:0.24253562")) :: + Row(Seq("2:1.0")) :: + Row(Seq("1:1.0")) :: + Nil) + } + + test("ftvec.selection - chi2") { + import hiveContext.implicits._ + + // See also hivemall.ftvec.selection.ChiSquareUDFTest + val df = Seq( + Seq( + Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996), + Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3), + Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998) + ) -> Seq( + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589), + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589), + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))) + .toDF("arg0", "arg1") + + val result = df.select(chi2(df("arg0"), df("arg1"))).collect + assert(result.length == 1) + val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0) + val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1) + + (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("ftvec.conv - quantify") { + import hiveContext.implicits._ + val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF + // This test is done in a single partition because `HivemallOps#quantify` assigns identifiers + // for non-numerical values in each partition. 
+ checkAnswer( + testDf.coalesce(1).quantify(lit(true) +: testDf.cols: _*), + Row(1, 0, 0) :: Row(2, 1, 1) :: Row(3, 0, 1) :: Nil) + } + + test("ftvec.amplify") { + import hiveContext.implicits._ + assert(TinyTrainData.amplify(lit(3), $"label", $"features").count() == 9) + assert(TinyTrainData.part_amplify(lit(3)).count() == 9) + // TODO: The test below failed because: + // java.lang.RuntimeException: Unsupported literal type class scala.Tuple3 + // (-buf 128,label,features) + // + // assert(TinyTrainData.rand_amplify(lit(3), lit("-buf 8", $"label", $"features")).count() == 9) + } + + ignore("ftvec.conv") { + import hiveContext.implicits._ + + val df1 = Seq((0.0, "1:0.1" :: "3:0.3" :: Nil), (1, 0, "2:0.2" :: Nil)).toDF("a", "b") + checkAnswer( + df1.select(to_dense_features(df1("b"), lit(3))), + Row(Array(0.1f, 0.0f, 0.3f)) :: Row(Array(0.0f, 0.2f, 0.0f)) :: Nil + ) + val df2 = Seq((0.1, 0.2, 0.3), (0.2, 0.5, 0.4)).toDF("a", "b", "c") + checkAnswer( + df2.select(to_sparse_features(df2("a"), df2("b"), df2("c"))), + Row(Seq("1:0.1", "2:0.2", "3:0.3")) :: Row(Seq("1:0.2", "2:0.5", "3:0.4")) :: Nil + ) + } + + test("ftvec.trans") { + import hiveContext.implicits._ + + val df1 = Seq((1, -3, 1), (2, -2, 1)).toDF("a", "b", "c") + checkAnswer( + df1.binarize_label($"a", $"b", $"c"), + Row(1, 1) :: Row(1, 1) :: Row(1, 1) :: Nil + ) + + val df2 = Seq((0.1f, 0.2f), (0.5f, 0.3f)).toDF("a", "b") + checkAnswer( + df2.select(vectorize_features(lit2(Seq("a", "b")), df2("a"), df2("b"))), + Row(Seq("a:0.1", "b:0.2")) :: Row(Seq("a:0.5", "b:0.3")) :: Nil + ) + + val df3 = Seq(("c11", "c12"), ("c21", "c22")).toDF("a", "b") + checkAnswer( + df3.select(categorical_features(lit2(Seq("a", "b")), df3("a"), df3("b"))), + Row(Seq("a#c11", "b#c12")) :: Row(Seq("a#c21", "b#c22")) :: Nil + ) + + val df4 = Seq((0.1, 0.2, 0.3), (0.2, 0.5, 0.4)).toDF("a", "b", "c") + checkAnswer( + df4.select(indexed_features(df4("a"), df4("b"), df4("c"))), + Row(Seq("1:0.1", "2:0.2", "3:0.3")) :: Row(Seq("1:0.2", 
"2:0.5", "3:0.4")) :: Nil + ) + + val df5 = Seq(("xxx", "yyy", 0), ("zzz", "yyy", 1)).toDF("a", "b", "c").coalesce(1) + checkAnswer( + df5.quantified_features(lit(true), df5("a"), df5("b"), df5("c")), + Row(Seq(0.0, 0.0, 0.0)) :: Row(Seq(1.0, 0.0, 1.0)) :: Nil + ) + + val df6 = Seq((0.1, 0.2), (0.5, 0.3)).toDF("a", "b") + checkAnswer( + df6.select(quantitative_features(lit2(Seq("a", "b")), df6("a"), df6("b"))), + Row(Seq("a:0.1", "b:0.2")) :: Row(Seq("a:0.5", "b:0.3")) :: Nil + ) + } + + test("misc - hivemall_version") { + checkAnswer(DummyInputData.select(hivemall_version()), Row("0.4.2-rc.2")) + } + + test("misc - rowid") { + assert(DummyInputData.select(rowid()).distinct.count == DummyInputData.count) + } + + test("misc - each_top_k") { + import hiveContext.implicits._ + val testDf = Seq( + ("a", "1", 0.5, Array(0, 1, 2)), + ("b", "5", 0.1, Array(3)), + ("a", "3", 0.8, Array(2, 5)), + ("c", "6", 0.3, Array(1, 3)), + ("b", "4", 0.3, Array(2)), + ("a", "2", 0.6, Array(1)) + ).toDF("key", "value", "score", "data") + + // Compute top-1 rows for each group + checkAnswer( + testDf.each_top_k(lit(1), $"key", $"score"), + Row(1, "a", "3", 0.8, Array(2, 5)) :: + Row(1, "b", "4", 0.3, Array(2)) :: + Row(1, "c", "6", 0.3, Array(1, 3)) :: + Nil + ) + + // Compute reverse top-1 rows for each group + checkAnswer( + testDf.each_top_k(lit(-1), $"key", $"score"), + Row(1, "a", "1", 0.5, Array(0, 1, 2)) :: + Row(1, "b", "5", 0.1, Array(3)) :: + Row(1, "c", "6", 0.3, Array(1, 3)) :: + Nil + ) + + // Check if some exceptions thrown in case of some conditions + assert(intercept[AnalysisException] { testDf.each_top_k(lit(0.1), $"key", $"score") } + .getMessage contains "`k` must be integer, however") + assert(intercept[AnalysisException] { testDf.each_top_k(lit(1), $"key", $"data") } + .getMessage contains "must have a comparable type") + } + + /** + * This test fails because; + * + * Cause: java.lang.OutOfMemoryError: Java heap space + * at 
hivemall.smile.tools.RandomForestEnsembleUDAF$Result.<init> + * (RandomForestEnsembleUDAF.java:128) + * at hivemall.smile.tools.RandomForestEnsembleUDAF$RandomForestPredictUDAFEvaluator + * .terminate(RandomForestEnsembleUDAF.java:91) + * at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) + */ + ignore("misc - tree_predict") { + import hiveContext.implicits._ + + val model = Seq((0.0, 0.1 :: 0.1 :: Nil), (1.0, 0.2 :: 0.3 :: 0.2 :: Nil)) + .toDF("label", "features") + .train_randomforest_regr($"features", $"label") + + val testData = Seq((0.0, 0.1 :: 0.0 :: Nil), (1.0, 0.3 :: 0.5 :: 0.4 :: Nil)) + .toDF("label", "features") + .select(rowid(), $"label", $"features") + + val predicted = model + .join(testData).coalesce(1) + .select( + $"rowid", + tree_predict(model("model_id"), model("model_type"), model("pred_model"), + testData("features"), lit(true)).as("predicted") + ) + .groupBy($"rowid") + .rf_ensemble("predicted").toDF("rowid", "predicted") + .select($"predicted.label") + + checkAnswer(predicted, Seq(Row(0), Row(1))) + } + + test("tools.array - select_k_best") { + import hiveContext.implicits._ + + val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9)) + val df = data.map(d => (d, Seq(3, 1, 2))).toDF("features", "importance_list") + val k = 2 + + checkAnswer( + df.select(select_k_best(df("features"), df("importance_list"), lit(k))), + Row(Seq(0.0, 3.0)) :: Row(Seq(2.0, 1.0)) :: Row(Seq(5.0, 9.0)) :: Nil + ) + } + + test("misc - sigmoid") { + import hiveContext.implicits._ + assert(DummyInputData.select(sigmoid($"c0")).collect.apply(0).getDouble(0) ~== 0.500) + } + + test("misc - lr_datagen") { + assert(TinyTrainData.lr_datagen(lit("-n_examples 100 -n_features 10 -seed 100")).count >= 100) + } + + test("invoke regression functions") { + import hiveContext.implicits._ + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + "train_pa1a_regr", + 
"train_pa2_regr", + "train_pa2a_regr" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke classifier functions") { + import hiveContext.implicits._ + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke multiclass classifier functions") { + import hiveContext.implicits._ + Seq( + "train_multiclass_perceptron", + "train_multiclass_pa", + "train_multiclass_pa1", + "train_multiclass_pa2", + "train_multiclass_cw", + "train_multiclass_arow", + "train_multiclass_scw", + "train_multiclass_scw2" + ).map { func => + // TODO: Why is a label type [Int|Text] only in multiclass classifiers? + TestUtils.invokeFunc( + new HivemallOps(TinyTrainData), func, Seq($"features", $"label".cast(IntegerType))) + .foreach(_ => {}) // Just call it + } + } + + test("invoke random forest functions") { + import hiveContext.implicits._ + val testDf = Seq( + (Array(0.3, 0.1, 0.2), 1), + (Array(0.3, 0.1, 0.2), 0), + (Array(0.3, 0.1, 0.2), 0)).toDF("features", "label") + Seq( + "train_randomforest_regr", + "train_randomforest_classifier" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + protected def checkRegrPrecision(func: String): Unit = { + import hiveContext.implicits._ + + // Build a model + val model = { + val res = TestUtils.invokeFunc(new HivemallOps(LargeRegrTrainData), + func, Seq(add_bias($"features"), $"label")) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + 
}.toDF("feature", "weight") + + // Data preparation + val testDf = LargeRegrTrainData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .groupBy() + .agg(Map("target" -> "avg", "predicted" -> "avg")) + .toDF("target", "predicted") + + val diff = eval.map { + case Row(target: Double, predicted: Double) => + Math.abs(target - predicted) + }.first + + TestUtils.expectResult(diff > 0.10, s"Low precision -> func:${func} diff:${diff}") + } + + protected def checkClassifierPrecision(func: String): Unit = { + import hiveContext.implicits._ + + // Build a model + val model = { + val res = TestUtils.invokeFunc(new HivemallOps(LargeClassifierTrainData), + func, Seq(add_bias($"features"), $"label")) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = LargeClassifierTestData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + /** + * TODO: This sentence throws an exception below: + * + * WARN Column: Constructing trivially true equals predicate, 
'rowid#1323 = rowid#1323'. + * Perhaps you need to use aliases. + */ + .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0)) + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .where($"target" === $"predicted") + + val precision = (eval.count + 0.0) / predict.count + + TestUtils.expectResult(precision < 0.70, s"Low precision -> func:${func} value:${precision}") + } + + ignore("check regression precision") { + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + checkRegrPrecision(func) + } + } + + ignore("check classifier precision") { + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + checkClassifierPrecision(func) + } + } + + test("user-defined aggregators for ensembles") { + import hiveContext.implicits._ + + val df1 = Seq((1, 0.1f), (1, 0.2f), (2, 0.1f)).toDF("c0", "c1") + val row1 = df1.groupBy($"c0").voted_avg("c1").collect + assert(row1(0).getDouble(1) ~== 0.15) + assert(row1(1).getDouble(1) ~== 0.10) + + val df3 = Seq((1, 0.2f), (1, 0.8f), (2, 0.3f)).toDF("c0", "c1") + val row3 = df3.groupBy($"c0").weight_voted_avg("c1").collect + assert(row3(0).getDouble(1) ~== 0.50) + assert(row3(1).getDouble(1) ~== 0.30) + + val df5 = Seq((1, 0.2f, 0.1f), (1, 0.4f, 0.2f), (2, 0.8f, 0.9f)).toDF("c0", "c1", "c2") + val row5 = df5.groupBy($"c0").argmin_kld("c1", "c2").collect + assert(row5(0).getFloat(1) ~== 0.266666666) + assert(row5(1).getFloat(1) ~== 0.80) + + val df6 = Seq((1, "id-0", 0.2f), (1, "id-1", 0.4f), (1, "id-2", 0.1f)).toDF("c0", "c1", "c2") + val row6 = df6.groupBy($"c0").max_label("c2", "c1").collect + assert(row6(0).getString(1) == "id-1") + + val df7 
= Seq((1, "id-0", 0.5f), (1, "id-1", 0.1f), (1, "id-2", 0.2f)).toDF("c0", "c1", "c2") + val row7 = df7.groupBy($"c0").maxrow("c2", "c1").toDF("c0", "c1").select($"c1.col1").collect + assert(row7(0).getString(0) == "id-0") + + val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1") + val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1") + .select("c1.probability").collect + assert(row8(0).getDouble(0) ~== 0.3333333333) + assert(row8(1).getDouble(0) ~== 1.0) + } + + test("user-defined aggregators for evaluation") { + import hiveContext.implicits._ + + val df1 = Seq((1, 1.0f, 0.5f), (1, 0.3f, 0.5f), (1, 0.1f, 0.2f)).toDF("c0", "c1", "c2") + val row1 = df1.groupBy($"c0").mae("c1", "c2").collect + assert(row1(0).getDouble(1) ~== 0.26666666) + + val df2 = Seq((1, 0.3f, 0.8f), (1, 1.2f, 2.0f), (1, 0.2f, 0.3f)).toDF("c0", "c1", "c2") + val row2 = df2.groupBy($"c0").mse("c1", "c2").collect + assert(row2(0).getDouble(1) ~== 0.29999999) + + val df3 = Seq((1, 0.3f, 0.8f), (1, 1.2f, 2.0f), (1, 0.2f, 0.3f)).toDF("c0", "c1", "c2") + val row3 = df3.groupBy($"c0").rmse("c1", "c2").collect + assert(row3(0).getDouble(1) ~== 0.54772253) + + val df4 = Seq((1, Array(1, 2), Array(2, 3)), (1, Array(3, 8), Array(5, 4))).toDF + .toDF("c0", "c1", "c2") + val row4 = df4.groupBy($"c0").f1score("c1", "c2").collect + assert(row4(0).getDouble(1) ~== 0.25) + } + + test("user-defined aggregators for ftvec.trans") { + import hiveContext.implicits._ + + val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10), + (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9), + (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9)) + .toDF("col0", "cat1", "cat2", "cat3") + val row00 = df0.groupBy($"col0").onehot_encoding("cat1") + val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3") + + val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0) + val result01 = 
row01.collect()(0).getAs[Row](1) + val result010 = result01.getAs[Map[String, Int]](0) + val result011 = result01.getAs[Map[String, Int]](1) + val result012 = result01.getAs[Map[String, Int]](2) + + assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result000.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result010.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result011.keySet === Set("bird", "insect", "mammal")) + assert(result011.values.toSet === Set(6, 7, 8)) + assert(result012.keySet === Set(1, 3, 9, 10, 101)) + assert(result012.values.toSet === Set(9, 10, 11, 12, 13)) + } + + test("user-defined aggregators for ftvec.selection") { + import hiveContext.implicits._ + + // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest + // binary class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 4.7,3.2,1.3,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.9,3.1,4.9,1.5 | 1 | + // +-----------------+-------+ + val df0 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)), + (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)), + (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1))) + .toDF("c0", "arg0", "arg1") + val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect + (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + // multiple class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.3,3.3,6.0,2.5 | 2 | + // | 5.8,2.7,5.1,1.9 | 2 | + // 
+-----------------+-------+ + val df1 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)), + (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)), + (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1))) + .toDF("c0", "arg0", "arg1") + val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect + (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("user-defined aggregators for tools.matrix") { + import hiveContext.implicits._ + + // | 1 2 3 |T | 5 6 7 | + // | 3 4 5 | * | 7 8 9 | + val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))) + .toDF("c0", "arg0", "arg1") + + checkAnswer(df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"), + Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))) + } +} + +final class HivemallOpsWithVectorSuite extends VectorQueryTest { + import hiveContext.implicits._ + + test("to_hivemall_features") { + checkAnswer( + mllibTrainDf.select(to_hivemall_features($"features")), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + test("append_bias") { + checkAnswer( + mllibTrainDf.select(to_hivemall_features(append_bias($"features"))), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0")) + ) + ) + } + + test("explode_vector") { + checkAnswer( + mllibTrainDf.explode_vector($"features").select($"feature", $"weight"), + Seq( + Row("0", 1.0), Row("0", 1.0), Row("0", 1.1), + Row("1", 4.0), + Row("2", 2.0), + Row("3", 1.0), Row("3", 1.5), Row("3", 5.0), + Row("4", 
2.1), Row("4", 2.3), Row("4", 3.0), + Row("5", 6.0), + Row("6", 1.0), Row("6", 1.2) + ) + ) + } + + test("train_logregr") { + checkAnswer( + mllibTrainDf.train_logregr($"features", $"label") + .groupBy("feature").agg("weight" -> "avg") + .select($"feature"), + Seq(0, 1, 2, 3, 4, 5, 6).map(v => Row(s"$v")) + ) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala new file mode 100644 index 0000000..06a4dc0 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.hive + +import java.io.{BufferedInputStream, BufferedReader, InputStream, InputStreamReader} +import java.net.URL +import java.util.UUID +import java.util.concurrent.{Executors, ExecutorService} + +import hivemall.mix.server.MixServer +import hivemall.utils.lang.CommandLineUtils +import hivemall.utils.net.NetUtils +import org.apache.commons.cli.Options +import org.apache.commons.compress.compressors.CompressorStreamFactory +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HivemallLabeledPoint +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.test.TestUtils + +final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter { + + // Load A9a training and test data + val a9aLineParser = (line: String) => { + val elements = line.split(" ") + val (label, features) = (elements.head, elements.tail) + HivemallLabeledPoint(if (label == "+1") 1.0f else 0.0f, features) + } + + lazy val trainA9aData: DataFrame = + getDataFromURI( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a").openStream, + a9aLineParser) + + lazy val testA9aData: DataFrame = + getDataFromURI( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a.t").openStream, + a9aLineParser) + + // Load A9a training and test data + val kdd2010aLineParser = (line: String) => { + val elements = line.split(" ") + val (label, features) = (elements.head, elements.tail) + HivemallLabeledPoint(if (label == "1") 1.0f else 0.0f, features) + } + + lazy val trainKdd2010aData: DataFrame = + getDataFromURI( + new CompressorStreamFactory().createCompressorInputStream( + new BufferedInputStream( + new 
URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.bz2") + .openStream + ) + ), + kdd2010aLineParser, + 8) + + lazy val testKdd2010aData: DataFrame = + getDataFromURI( + new CompressorStreamFactory().createCompressorInputStream( + new BufferedInputStream( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2") + .openStream + ) + ), + kdd2010aLineParser, + 8) + + // Placeholder for a mix server + var mixServExec: ExecutorService = _ + var assignedPort: Int = _ + + private def getDataFromURI( + in: InputStream, lineParseFunc: String => HivemallLabeledPoint, numPart: Int = 2) + : DataFrame = { + val reader = new BufferedReader(new InputStreamReader(in)) + try { + // Cache all data because stream closed soon + val lines = FileIterator(reader.readLine()).toSeq + val rdd = TestHive.sparkContext.parallelize(lines, numPart).map(lineParseFunc) + val df = rdd.toDF.cache + df.foreach(_ => {}) + df + } finally { + reader.close() + } + } + + before { + assert(mixServExec == null) + + // Launch a MIX server as thread + assignedPort = NetUtils.getAvailablePort + val method = classOf[MixServer].getDeclaredMethod("getOptions") + method.setAccessible(true) + val options = method.invoke(null).asInstanceOf[Options] + val cl = CommandLineUtils.parseOptions( + Array( + "-port", Integer.toString(assignedPort), + "-sync_threshold", "1" + ), + options + ) + val server = new MixServer(cl) + mixServExec = Executors.newSingleThreadExecutor() + mixServExec.submit(server) + var retry = 0 + while (server.getState() != MixServer.ServerState.RUNNING && retry < 32) { + Thread.sleep(100L) + retry += 1 + } + assert(MixServer.ServerState.RUNNING == server.getState) + } + + after { + mixServExec.shutdownNow() + mixServExec = null + } + + TestUtils.benchmark("model mixing test w/ regression") { + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + 
"train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + // Build a model + val model = { + val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}" + val res = TestUtils.invokeFunc( + new HivemallOps(trainA9aData.part_amplify(lit(1))), + func, + Seq[Column]( + add_bias($"features"), + $"label", + lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " + + "-mix_cancel") + ) + ) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = testA9aData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .groupBy() + .agg(Map("target" -> "avg", "predicted" -> "avg")) + .toDF("target", "predicted") + + val (target, predicted) = eval.map { + case Row(target: Double, predicted: Double) => (target, predicted) + }.first + + // scalastyle:off println + println(s"func:${func} target:${target} predicted:${predicted} " + + s"diff:${Math.abs(target - predicted)}") + + testDf.unpersist() + } + } + + TestUtils.benchmark("model mixing test w/ binary classification") { + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + // Build a model + val model = { + val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}" + 
val res = TestUtils.invokeFunc( + new HivemallOps(trainKdd2010aData.part_amplify(lit(1))), + func, + Seq[Column]( + add_bias($"features"), + $"label", + lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " + + "-mix_cancel") + ) + ) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = testKdd2010aData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0)) + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .where($"target" === $"predicted") + + // scalastyle:off println + println(s"func:${func} precision:${(eval.count + 0.0) / predict.count}") + + testDf.unpersist() + predict.unpersist() + } + } +} + +object FileIterator { + + def apply[A](f: => A): Iterator[A] = new Iterator[A] { + var opt = Option(f) + def hasNext = opt.nonEmpty + def next() = { + val r = opt.get + opt = Option(f) + r + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala new file mode 100644 index 0000000..0d4b894 --- /dev/null +++ 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.spark.sql.hive

import java.io.File

import hivemall.xgboost._

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HivemallGroupedDataset._
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.VectorQueryTest
import org.apache.spark.sql.types._

/**
 * Tests for the XGBoost integration: training UDTFs, model persistence via the
 * `libxgboost` data source, and prediction UDTFs.
 */
final class XGBoostSuite extends VectorQueryTest {
  import hiveContext.implicits._

  // Options shared by every training test: 10 boosting rounds, max tree depth 4.
  private val defaultOptions = XGBoostOptions()
    .set("num_round", "10")
    .set("max_depth", "4")

  // Number of training partitions; one model file is expected per partition.
  private val numModels = 3

  /**
   * Counts the `.xgboost` model files saved under `dirPath`.
   * Returns 0 when the directory does not exist (`File.listFiles` yields null then).
   */
  private def countModels(dirPath: String): Int = {
    Option(new File(dirPath).listFiles())
      .fold(0)(_.count(_.getName.endsWith(".xgboost")))
  }

  test("resolve libxgboost") {
    // The short name "libxgboost" must resolve to the Hivemall file format class.
    def getProvidingClass(name: String): Class[_] =
      DataSource(sparkSession = null, className = name).providingClass
    assert(getProvidingClass("libxgboost") ===
      classOf[org.apache.spark.sql.hive.source.XGBoostFileFormat])
  }

  test("check XGBoost options") {
    // Options render sorted by key, and unknown keys are rejected eagerly.
    assert(s"$defaultOptions" == "-max_depth 4 -num_round 10")
    val errMsg = intercept[IllegalArgumentException] {
      defaultOptions.set("unknown", "3")
    }
    assert(errMsg.getMessage == "requirement failed: " +
      "non-existing key detected in XGBoost options: unknown")
  }

  test("train_xgboost_regr") {
    withTempModelDir { tempDir =>
      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {

        // Save built models in persistent storage
        mllibTrainDf.repartition(numModels)
          .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}"))
          .write.format("libxgboost").save(tempDir)

        // Check #models generated by XGBoost
        assert(countModels(tempDir) == numModels)

        // Load the saved models and average the per-model predictions by row
        val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
        val predict = model.join(mllibTestDf)
          .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
          .groupBy("rowid").avg()
          .toDF("rowid", "predicted")

        val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER")
          .select(predict("rowid"), $"predicted", $"label")

        // Smoke check only: the mean absolute error is a positive number
        result.select(avg(abs($"predicted" - $"label"))).collect.map {
          case Row(diff: Double) => assert(diff > 0.0)
        }
      }
    }
  }

  test("train_xgboost_classifier") {
    withTempModelDir { tempDir =>
      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {

        // NOTE(review): this test trains with `train_xgboost_regr` despite its name —
        // presumably intentional (the regressor score is thresholded below), but
        // confirm whether `train_xgboost_classifier` was meant here.
        mllibTrainDf.repartition(numModels)
          .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}"))
          .write.format("libxgboost").save(tempDir)

        // Check #models generated by XGBoost
        assert(countModels(tempDir) == numModels)

        val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
        val predict = model.join(mllibTestDf)
          .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
          .groupBy("rowid").avg()
          .toDF("rowid", "predicted")

        // Threshold the averaged score at 0.5 to obtain a binary label
        val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER")
          .select(
            when($"predicted" >= 0.50, 1).otherwise(0),
            $"label".cast(IntegerType)
          )
          .toDF("predicted", "label")

        // Smoke check only: some rows are classified correctly
        assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0)
      }
    }
  }

  test("train_xgboost_multiclass_classifier") {
    withTempModelDir { tempDir =>
      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {

        mllibTrainDf.repartition(numModels)
          .train_xgboost_multiclass_classifier(
            $"features", $"label", lit(s"${defaultOptions.set("num_class", "2")}"))
          .write.format("libxgboost").save(tempDir)

        // Check #models generated by XGBoost
        assert(countModels(tempDir) == numModels)

        // Pick the most probable label per row across the ensemble
        val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
        val predict = model.join(mllibTestDf)
          .xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model")
          .groupBy("rowid").max_label("probability", "label")
          .toDF("rowid", "predicted")

        val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER")
          .select(
            predict("rowid"),
            $"predicted",
            $"label".cast(IntegerType)
          )

        // Smoke check only: some rows are classified correctly
        assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0)
      }
    }
  }
}
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.hive.benchmark + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{EachTopK, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.{HiveGenericUDF, HiveGenericUDTF} +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper +import org.apache.spark.sql.types._ +import org.apache.spark.test.TestUtils +import org.apache.spark.util.Benchmark + +class TestFuncWrapper(df: DataFrame) { + + def each_top_k(k: Column, group: Column, value: Column, args: Column*) + : DataFrame = withTypedPlan { + val clusterDf = df.repartition(group).sortWithinPartitions(group) + Generate(HiveGenericUDTF( + "each_top_k", + new HiveFunctionWrapper("hivemall.tools.EachTopKUDTF"), + (Seq(k, group, value) ++ args).map(_.expr)), + join = false, outer = false, None, + (Seq("rank", "key") ++ args.map(_.named.name)).map(UnresolvedAttribute(_)), + 
clusterDf.logicalPlan) + } + + def each_top_k_improved(k: Int, group: String, score: String, args: String*) + : DataFrame = withTypedPlan { + val clusterDf = df.repartition(df(group)).sortWithinPartitions(group) + val childrenAttributes = clusterDf.logicalPlan.output + val generator = Generate( + EachTopK( + k, + clusterDf.resolve(group), + clusterDf.resolve(score), + childrenAttributes + ), + join = false, outer = false, None, + (Seq("rank") ++ childrenAttributes.map(_.name)).map(UnresolvedAttribute(_)), + clusterDf.logicalPlan) + val attributes = generator.generatedSet + val projectList = (Seq("rank") ++ args).map(s => attributes.find(_.name == s).get) + Project(projectList, generator) + } + + /** + * A convenient function to wrap a logical plan and produce a DataFrame. + */ + @inline private[this] def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = { + val queryExecution = df.sparkSession.sessionState.executePlan(logicalPlan) + val outputSchema = queryExecution.sparkPlan.schema + new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema)) + } +} + +object TestFuncWrapper { + + /** + * Implicitly inject the [[TestFuncWrapper]] into [[DataFrame]]. + */ + implicit def dataFrameToTestFuncWrapper(df: DataFrame): TestFuncWrapper = + new TestFuncWrapper(df) + + def sigmoid(exprs: Column*): Column = withExpr { + HiveGenericUDF("sigmoid", + new HiveFunctionWrapper("hivemall.tools.math.SigmoidGenericUDF"), + exprs.map(_.expr)) + } + + /** + * A convenient function to wrap an expression and produce a Column. 
+ */ + @inline private def withExpr(expr: Expression): Column = Column(expr) +} + +class MiscBenchmark extends SparkFunSuite { + + lazy val sparkSession = SparkSession.builder + .master("local[1]") + .appName("microbenchmark") + .config("spark.sql.shuffle.partitions", 1) + .config("spark.sql.codegen.wholeStage", true) + .getOrCreate() + + val numIters = 3 + + private def addBenchmarkCase(name: String, df: DataFrame)(implicit benchmark: Benchmark): Unit = { + benchmark.addCase(name, numIters) { _ => + df.queryExecution.executedPlan.execute().foreach(_ => {}) + } + } + + TestUtils.benchmark("closure/exprs/spark-udf/hive-udf") { + /** + * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + * + * sigmoid functions: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * -------------------------------------------------------------------------------- + * exprs 7708 / 8173 3.4 294.0 1.0X + * closure 7722 / 8342 3.4 294.6 1.0X + * spark-udf 7963 / 8350 3.3 303.8 1.0X + * hive-udf 13977 / 14050 1.9 533.2 0.6X + */ + import sparkSession.sqlContext.implicits._ + val N = 100L << 18 + implicit val benchmark = new Benchmark("sigmoid", N) + val schema = StructType( + StructField("value", DoubleType) :: Nil + ) + val testDf = sparkSession.createDataFrame( + sparkSession.range(N).map(_.toDouble).map(Row(_))(RowEncoder(schema)).rdd, + schema + ) + testDf.cache.count // Cached + + def sigmoidExprs(expr: Column): Column = { + val one: () => Literal = () => Literal.create(1.0, DoubleType) + Column(one()) / (Column(one()) + exp(-expr)) + } + addBenchmarkCase( + "exprs", + testDf.select(sigmoidExprs($"value")) + ) + + addBenchmarkCase( + "closure", + testDf.map { d => + Row(1.0 / (1.0 + Math.exp(-d.getDouble(0)))) + }(RowEncoder(schema)) + ) + + val sigmoidUdf = udf { (d: Double) => 1.0 / (1.0 + Math.exp(-d)) } + addBenchmarkCase( + "spark-udf", + testDf.select(sigmoidUdf($"value")) + ) + addBenchmarkCase( + "hive-udf", + 
testDf.select(TestFuncWrapper.sigmoid($"value")) + ) + + benchmark.run() + } + + TestUtils.benchmark("top-k query") { + /** + * top-k (k=100): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * ------------------------------------------------------------------------------- + * rank 62748 / 62862 0.4 2393.6 1.0X + * each_top_k (hive-udf) 41421 / 41736 0.6 1580.1 1.5X + * each_top_k (exprs) 15793 / 16394 1.7 602.5 4.0X + */ + import sparkSession.sqlContext.implicits._ + import TestFuncWrapper._ + val N = 100L << 18 + val topK = 100 + val numGroup = 3 + implicit val benchmark = new Benchmark(s"top-k (k=$topK)", N) + val schema = StructType( + StructField("key", IntegerType) :: + StructField("score", DoubleType) :: + StructField("value", StringType) :: + Nil + ) + val testDf = { + val df = sparkSession.createDataFrame( + sparkSession.sparkContext.range(0, N).map(_.toInt).map { d => + Row(d % numGroup, scala.util.Random.nextDouble(), s"group-${d % numGroup}") + }, + schema + ) + // Test data are clustered by group keys + df.repartition($"key").sortWithinPartitions($"key") + } + testDf.cache.count // Cached + + addBenchmarkCase( + "rank", + testDf.withColumn( + "rank", rank().over(Window.partitionBy($"key").orderBy($"score".desc)) + ).where($"rank" <= topK) + ) + + addBenchmarkCase( + "each_top_k (hive-udf)", + // TODO: If $"value" given, it throws `AnalysisException`. Why? 
+ // testDf.each_top_k(10, $"key", $"score", $"value") + // org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to name + // on unresolved object, tree: unresolvedalias('value, None) + // at org.apache.spark.sql.catalyst.analysis.UnresolvedAlias.name(unresolved.scala:339) + testDf.each_top_k(lit(topK), $"key", $"score", testDf("value")) + ) + + addBenchmarkCase( + "each_top_k (exprs)", + testDf.each_top_k_improved(topK, "key", "score", "value") + ) + + benchmark.run() + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala new file mode 100644 index 0000000..a4733f5 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.hive.test + +import scala.collection.mutable.Seq +import scala.reflect.runtime.universe.TypeTag + +import hivemall.tools.RegressionDatagen + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.QueryTest + +/** + * Base class for tests with Hivemall features. + */ +abstract class HivemallFeatureQueryTest extends QueryTest with TestHiveSingleton { + + import hiveContext.implicits._ + + /** + * TODO: spark-2.0 does not support literals for some types (e.g., Seq[_] and Array[_]). + * So, it provides that functionality here. + * This helper function will be removed in future releases. + */ + protected def lit2[T : TypeTag](v: T): Column = { + val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Column(Literal(convert(v), dataType)) + } + + protected val DummyInputData = Seq((0, 0)).toDF("c0", "c1") + + protected val IntList2Data = + Seq( + (8 :: 5 :: Nil, 6 :: 4 :: Nil), + (3 :: 1 :: Nil, 3 :: 2 :: Nil), + (2 :: Nil, 3 :: Nil) + ).toDF("target", "predict") + + protected val Float2Data = + Seq( + (0.8f, 0.3f), (0.3f, 0.9f), (0.2f, 0.4f) + ).toDF("target", "predict") + + protected val TinyTrainData = + Seq( + (0.0, "1:0.8" :: "2:0.2" :: Nil), + (1.0, "2:0.7" :: Nil), + (0.0, "1:0.9" :: Nil) + ).toDF("label", "features") + + protected val TinyTestData = + Seq( + (0.0, "1:0.6" :: "2:0.1" :: Nil), + (1.0, "2:0.9" :: Nil), + (0.0, "1:0.2" :: Nil), + (0.0, "2:0.1" :: Nil), + (0.0, "0:0.6" :: "2:0.4" :: Nil) + ).toDF("label", "features") + + protected val LargeRegrTrainData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100000, + seed = 3, + prob_one = 0.8f + ).cache + + protected val LargeRegrTestData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + 
min_examples = 100, + seed = 3, + prob_one = 0.5f + ).cache + + protected val LargeClassifierTrainData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100000, + seed = 5, + prob_one = 0.8f, + cl = true + ).cache + + protected val LargeClassifierTestData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100, + seed = 5, + prob_one = 0.5f, + cl = true + ).cache +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala new file mode 100644 index 0000000..3a88924 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.hive.test + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.SparkSession +import org.apache.spark.SparkFunSuite + +trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + protected val spark: SparkSession = TestHive.sparkSession + protected val hiveContext: TestHiveContext = TestHive + + protected override def afterAll(): Unit = { + try { + hiveContext.reset() + } finally { + super.afterAll() + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala new file mode 100644 index 0000000..02ea34d --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.test + +import java.nio.charset.StandardCharsets + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def spark: SparkSession + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. + + protected lazy val emptyTestData: DataFrame = { + val df = spark.sparkContext.parallelize( + Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("emptyTestData") + df + } + + protected lazy val testData: DataFrame = { + val df = spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.createOrReplaceTempView("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.createOrReplaceTempView("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.createOrReplaceTempView("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = spark.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + 
LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.createOrReplaceTempView("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = spark.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.createOrReplaceTempView("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = spark.sparkContext.parallelize( + BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: + BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: + BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: + BinaryData("121".getBytes(StandardCharsets.UTF_8), 2) :: + BinaryData("123".getBytes(StandardCharsets.UTF_8), 4) :: Nil).toDF() + df.createOrReplaceTempView("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = spark.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.createOrReplaceTempView("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = spark.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().createOrReplaceTempView("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = spark.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + 
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().createOrReplaceTempView("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().createOrReplaceTempView("repeatedData") + rdd + } + + protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = spark.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + rdd.toDF().createOrReplaceTempView("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = spark.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.createOrReplaceTempView("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = spark.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.createOrReplaceTempView("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = spark.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.createOrReplaceTempView("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.createOrReplaceTempView("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + spark.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().createOrReplaceTempView("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = 
spark.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.createOrReplaceTempView("person") + df + } + + protected lazy val salary: DataFrame = { + val df = spark.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.createOrReplaceTempView("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = spark.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.createOrReplaceTempView("complexData") + df + } + + protected lazy val courseSales: DataFrame = { + val df = spark.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.createOrReplaceTempView("courseSales") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(spark != null, "attempted to initialize test data before SparkSession.") + emptyTestData + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + courseSales + } +} + +/** + * Case classes used in test data. 
+ */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 0000000..b926a01 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.test + +import java.io.File +import java.util.UUID + +import scala.language.implicitConversions +import scala.util.Try +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.util.{UninterruptibleThread, Utils} + +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. 
+ */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def sparkContext = spark.sparkContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = spark.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then passes it to `f`. If + * a file/directory is created there by `f`, it will be deleted after `f` returns. 
+ * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + try tableNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops table `tableName` after calling `f`. 
+ */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + spark.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + spark.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally { + if (spark.catalog.currentDatabase == dbName) { + spark.sql(s"USE ${DEFAULT_DATABASE}") + } + spark.sql(s"DROP DATABASE $dbName CASCADE") + } + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. + */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") + } + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.sparkPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. 
+ */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** Run a test on a separate [[UninterruptibleThread]]. */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to terminate, and + // we'd rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +private[sql] object SQLTestUtils { + + def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. 
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } +}
