Repository: incubator-hivemall Updated Branches: refs/heads/master 688daa5f8 -> 8bf6dd9e7
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
new file mode 100644
index 0000000..6bb644a
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.hive.benchmark

import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.benchmark.BenchmarkBase
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.internal.HivemallOpsImpl._
import org.apache.spark.sql.types._
import org.apache.spark.test.TestUtils
import org.apache.spark.util.Benchmark

/**
 * Wraps a [[DataFrame]] so that benchmarks can invoke the Hive-UDTF flavor of
 * hivemall's `each_top_k`, to be compared against the native expression flavor.
 */
class TestFuncWrapper(df: DataFrame) {

  /**
   * Runs hivemall's `EachTopKUDTF` over `df`, repartitioned and sorted by `group`
   * (the UDTF requires rows of the same group to be adjacent).
   *
   * @param k     number of rows to keep per group
   * @param group grouping column
   * @param value score column used for ranking
   * @param args  extra columns carried through to the output
   */
  def hive_each_top_k(k: Column, group: Column, value: Column, args: Column*)
    : DataFrame = withTypedPlan {
    planHiveGenericUDTF(
      df.repartition(group).sortWithinPartitions(group),
      "hivemall.tools.EachTopKUDTF",
      "each_top_k",
      Seq(k, group, value) ++ args,
      // Output column names for the extra args. The original match was partial and
      // threw a MatchError for any arg that was not a plain attribute reference;
      // fall back to the expression's pretty name to keep the match total.
      Seq("rank", "key") ++ args.map { arg =>
        arg.expr match {
          case ua: UnresolvedAttribute => ua.name
          case expr => expr.prettyName
        }
      }
    )
  }

  /**
   * A convenient function to wrap a logical plan and produce a DataFrame.
   */
  @inline private[this] def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = {
    val queryExecution = df.sparkSession.sessionState.executePlan(logicalPlan)
    val outputSchema = queryExecution.sparkPlan.schema
    new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema))
  }
}

object TestFuncWrapper {

  /**
   * Implicitly inject the [[TestFuncWrapper]] into [[DataFrame]].
   */
  implicit def dataFrameToTestFuncWrapper(df: DataFrame): TestFuncWrapper =
    new TestFuncWrapper(df)

  /** Plans hivemall's Hive-UDF sigmoid over the given expressions. */
  def sigmoid(exprs: Column*): Column = withExpr {
    planHiveGenericUDF(
      "hivemall.tools.math.SigmoidGenericUDF",
      "sigmoid",
      exprs
    )
  }

  /**
   * A convenient function to wrap an expression and produce a Column.
   */
  @inline private def withExpr(expr: Expression): Column = Column(expr)
}

/**
 * Micro-benchmarks comparing several implementation strategies (closures, catalyst
 * expressions, Spark UDFs, Hive UDFs) for sigmoid and top-k queries.
 */
class MiscBenchmark extends BenchmarkBase {

  // Number of measured iterations per benchmark case.
  val numIters = 10

  /** Registers a case that fully evaluates `df`'s physical plan without collecting. */
  private def addBenchmarkCase(name: String, df: DataFrame)(implicit benchmark: Benchmark): Unit = {
    benchmark.addCase(name, numIters) { _ =>
      df.queryExecution.executedPlan.execute().foreach(_ => ())
    }
  }

  TestUtils.benchmark("closure/exprs/spark-udf/hive-udf") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * sigmoid functions:       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * --------------------------------------------------------------------------------
     * exprs                         7708 / 8173          3.4         294.0       1.0X
     * closure                       7722 / 8342          3.4         294.6       1.0X
     * spark-udf                     7963 / 8350          3.3         303.8       1.0X
     * hive-udf                    13977 / 14050          1.9         533.2       0.6X
     */
    import sparkSession.sqlContext.implicits._
    val N = 1L << 18
    val testDf = sparkSession.range(N).selectExpr("rand() AS value").cache

    // First, cache data
    testDf.count

    implicit val benchmark = new Benchmark("sigmoid", N)
    // sigmoid(x) = 1 / (1 + e^-x), expressed with raw catalyst expressions.
    def sigmoidExprs(expr: Column): Column = {
      val one: () => Literal = () => Literal.create(1.0, DoubleType)
      Column(one()) / (Column(one()) + exp(-expr))
    }
    addBenchmarkCase(
      "exprs",
      testDf.select(sigmoidExprs($"value"))
    )
    implicit val encoder = RowEncoder(StructType(StructField("value", DoubleType) :: Nil))
    addBenchmarkCase(
      "closure",
      testDf.map { d =>
        Row(1.0 / (1.0 + Math.exp(-d.getDouble(0))))
      }
    )
    val sigmoidUdf = udf { (d: Double) => 1.0 / (1.0 + Math.exp(-d)) }
    addBenchmarkCase(
      "spark-udf",
      testDf.select(sigmoidUdf($"value"))
    )
    addBenchmarkCase(
      "hive-udf",
      testDf.select(TestFuncWrapper.sigmoid($"value"))
    )
    benchmark.run()
  }

  TestUtils.benchmark("top-k query") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * top-k (k=100):           Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * -------------------------------------------------------------------------------
     * rank                        62748 / 62862          0.4        2393.6       1.0X
     * each_top_k (hive-udf)       41421 / 41736          0.6        1580.1       1.5X
     * each_top_k (exprs)          15793 / 16394          1.7         602.5       4.0X
     */
    import sparkSession.sqlContext.implicits._
    import TestFuncWrapper._
    val topK = 100
    val N = 1L << 20
    val numGroup = 3
    val testDf = sparkSession.range(N).selectExpr(
      s"id % $numGroup AS key", "rand() AS x", "CAST(id AS STRING) AS value"
    ).cache

    // First, cache data
    testDf.count

    implicit val benchmark = new Benchmark(s"top-k (k=$topK)", N)
    addBenchmarkCase(
      "rank",
      testDf.withColumn("rank", rank().over(Window.partitionBy($"key").orderBy($"x".desc)))
        .where($"rank" <= topK)
    )
    addBenchmarkCase(
      "each_top_k (hive-udf)",
      testDf.hive_each_top_k(lit(topK), $"key", $"x", $"key", $"value")
    )
    addBenchmarkCase(
      "each_top_k (exprs)",
      testDf.each_top_k(lit(topK), $"x".as("score"), $"key".as("group"))
    )
    benchmark.run()
  }

  TestUtils.benchmark("top-k join query") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * top-k join (k=3):        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * -------------------------------------------------------------------------------
     * join + rank                 65959 / 71324          0.0      503223.9       1.0X
     * join + each_top_k           66093 / 78864          0.0      504247.3       1.0X
     * top_k_join                    5013 / 5431          0.0       38249.3      13.2X
     */
    import sparkSession.sqlContext.implicits._
    val topK = 3
    val N = 1L << 10
    val M = 1L << 10
    val numGroup = 3
    val inputDf = sparkSession.range(N).selectExpr(
      s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y"
    ).cache
    val masterDf = sparkSession.range(M).selectExpr(
      s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y"
    ).cache

    // First, cache data
    inputDf.count
    masterDf.count

    implicit val benchmark = new Benchmark(s"top-k join (k=$topK)", N)
    // Define a score column (Euclidean distance between input and master points)
    val distance = sqrt(
      pow(inputDf("x") - masterDf("x"), lit(2.0)) +
        pow(inputDf("y") - masterDf("y"), lit(2.0))
    ).as("score")
    addBenchmarkCase(
      "join + rank",
      inputDf.join(masterDf, inputDf("group") === masterDf("group"))
        .select(inputDf("group"), $"userId", $"posId", distance)
        .withColumn(
          "rank", rank().over(Window.partitionBy($"group", $"userId").orderBy($"score".desc)))
        .where($"rank" <= topK)
    )
    addBenchmarkCase(
      "join + each_top_k",
      inputDf.join(masterDf, inputDf("group") === masterDf("group"))
        .each_top_k(lit(topK), distance, inputDf("group").as("group"))
    )
    addBenchmarkCase(
      "top_k_join",
      inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"), distance)
    )
    benchmark.run()
  }

  TestUtils.benchmark("codegen top-k join") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * top_k_join:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * -----------------------------------------------------------------------------------
     * top_k_join wholestage off          3 /    5       2751.9           0.4       1.0X
     * top_k_join wholestage on           1 /    1       6494.4           0.2       2.4X
     */
    val topK = 3
    val N = 1L << 23
    val M = 1L << 22
    val numGroup = 3
    val inputDf = sparkSession.range(N).selectExpr(
      s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y"
    ).cache
    val masterDf = sparkSession.range(M).selectExpr(
      s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y"
    ).cache

    // First, cache data
    inputDf.count
    masterDf.count

    // Define a score column
    val distance = sqrt(
      pow(inputDf("x") - masterDf("x"), lit(2.0)) +
        pow(inputDf("y") - masterDf("y"), lit(2.0))
    )
    runBenchmark("top_k_join", N) {
      inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"),
        distance.as("score"))
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
new file mode 100644
index 0000000..3ca9bbf
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.hive.test

// NOTE(review): `mutable.Seq` shadows the default Seq here; the data below is never
// mutated, so this import looks unnecessary — kept to preserve behavior, confirm upstream.
import scala.collection.mutable.Seq
import scala.reflect.runtime.universe.TypeTag

import hivemall.tools.RegressionDatagen

import org.apache.spark.sql.{Column, QueryTest}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SQLTestUtils

/**
 * Base class for tests with Hivemall features.
 *
 * All shared datasets are `lazy val`s so that a suite only pays for the (potentially
 * expensive, 100k-row) `RegressionDatagen` jobs it actually uses, instead of running
 * them eagerly at construction time for every subclass.
 */
abstract class HivemallFeatureQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton {

  import hiveContext.implicits._

  /**
   * TODO: spark-2.0 does not support literals for some types (e.g., Seq[_] and Array[_]).
   * So, it provides that functionality here.
   * This helper function will be removed in future releases.
   */
  protected def lit2[T : TypeTag](v: T): Column = {
    val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
    val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
    Column(Literal(convert(v), dataType))
  }

  // Minimal two-column frame for UDFs that ignore their input rows.
  protected lazy val DummyInputData = Seq((0, 0)).toDF("c0", "c1")

  protected lazy val IntList2Data =
    Seq(
      (8 :: 5 :: Nil, 6 :: 4 :: Nil),
      (3 :: 1 :: Nil, 3 :: 2 :: Nil),
      (2 :: Nil, 3 :: Nil)
    ).toDF("target", "predict")

  protected lazy val Float2Data =
    Seq(
      (0.8f, 0.3f), (0.3f, 0.9f), (0.2f, 0.4f)
    ).toDF("target", "predict")

  // Tiny libsvm-style ("index:weight") feature vectors for smoke tests.
  protected lazy val TinyTrainData =
    Seq(
      (0.0, "1:0.8" :: "2:0.2" :: Nil),
      (1.0, "2:0.7" :: Nil),
      (0.0, "1:0.9" :: Nil)
    ).toDF("label", "features")

  protected lazy val TinyTestData =
    Seq(
      (0.0, "1:0.6" :: "2:0.1" :: Nil),
      (1.0, "2:0.9" :: Nil),
      (0.0, "1:0.2" :: Nil),
      (0.0, "2:0.1" :: Nil),
      (0.0, "0:0.6" :: "2:0.4" :: Nil)
    ).toDF("label", "features")

  // Larger generated datasets; cached so repeated accesses don't regenerate.
  protected lazy val LargeRegrTrainData = RegressionDatagen.exec(
    hiveContext,
    n_partitions = 2,
    min_examples = 100000,
    seed = 3,
    prob_one = 0.8f
  ).cache

  protected lazy val LargeRegrTestData = RegressionDatagen.exec(
    hiveContext,
    n_partitions = 2,
    min_examples = 100,
    seed = 3,
    prob_one = 0.5f
  ).cache

  protected lazy val LargeClassifierTrainData = RegressionDatagen.exec(
    hiveContext,
    n_partitions = 2,
    min_examples = 100000,
    seed = 5,
    prob_one = 0.8f,
    cl = true
  ).cache

  protected lazy val LargeClassifierTestData = RegressionDatagen.exec(
    hiveContext,
    n_partitions = 2,
    min_examples = 100,
    seed = 5,
    prob_one = 0.5f,
    cl = true
  ).cache
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
new file mode 100644
index 0000000..ccb21cf
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.hive.test

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SparkSession

/**
 * Mixes a shared TestHive-backed session into a suite and resets the Hive
 * metastore state once after all tests in the suite have run.
 *
 * `TestHive` / `TestHiveContext` are resolved from this file's own package,
 * `org.apache.spark.sql.hive.test` (provided by Spark's hive test artifact).
 */
trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
  protected val spark: SparkSession = TestHive.sparkSession
  protected val hiveContext: TestHiveContext = TestHive

  protected override def afterAll(): Unit = {
    // Always reset shared Hive state, but never swallow the superclass teardown.
    try hiveContext.reset()
    finally super.afterAll()
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
new file mode 100644
index 0000000..50b80fa
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
+ */ + +package org.apache.spark.sql.test + +import java.nio.charset.StandardCharsets + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def spark: SparkSession + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. + + protected lazy val emptyTestData: DataFrame = { + val df = spark.sparkContext.parallelize( + Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("emptyTestData") + df + } + + protected lazy val testData: DataFrame = { + val df = spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.createOrReplaceTempView("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.createOrReplaceTempView("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.createOrReplaceTempView("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = spark.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + 
LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.createOrReplaceTempView("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = spark.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.createOrReplaceTempView("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = spark.sparkContext.parallelize( + BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: + BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: + BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: + BinaryData("121".getBytes(StandardCharsets.UTF_8), 2) :: + BinaryData("123".getBytes(StandardCharsets.UTF_8), 4) :: Nil).toDF() + df.createOrReplaceTempView("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = spark.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.createOrReplaceTempView("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = spark.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().createOrReplaceTempView("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = spark.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + 
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().createOrReplaceTempView("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().createOrReplaceTempView("repeatedData") + rdd + } + + protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = spark.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + rdd.toDF().createOrReplaceTempView("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = spark.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.createOrReplaceTempView("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = spark.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.createOrReplaceTempView("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = spark.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.createOrReplaceTempView("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.createOrReplaceTempView("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + spark.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().createOrReplaceTempView("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = 
spark.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.createOrReplaceTempView("person") + df + } + + protected lazy val salary: DataFrame = { + val df = spark.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.createOrReplaceTempView("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = spark.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.createOrReplaceTempView("complexData") + df + } + + protected lazy val courseSales: DataFrame = { + val df = spark.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.createOrReplaceTempView("courseSales") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(spark != null, "attempted to initialize test data before SparkSession.") + emptyTestData + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + courseSales + } +} + +/** + * Case classes used in test data. 
+ */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 0000000..1e48e71 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File +import java.util.UUID + +import scala.language.implicitConversions +import scala.util.Try +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.util.{UninterruptibleThread, Utils} + +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. 
+ */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def sparkContext = spark.sparkContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = spark.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. 
+ * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + try tableNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops table `tableName` after calling `f`. 
+ */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + spark.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + spark.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally { + if (spark.catalog.currentDatabase == dbName) { + spark.sql(s"USE ${DEFAULT_DATABASE}") + } + spark.sql(s"DROP DATABASE $dbName CASCADE") + } + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. + */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") + } + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.sparkPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. 
+ */
+ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
+ Dataset.ofRows(spark, plan)
+ }
+
+ /**
+ * Disable stdout and stderr when running the test. To not output the logs to the console,
+ * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of
+ * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if
+ * we change System.out and System.err.
+ */
+ protected def testQuietly(name: String)(f: => Unit): Unit = {
+ test(name) {
+ quietly {
+ f
+ }
+ }
+ }
+
+ /** Run a test on a separate [[UninterruptibleThread]]. */
+ protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false)
+ (body: => Unit): Unit = {
+ val timeoutMillis = 10000
+ @transient var ex: Throwable = null
+
+ def runOnThread(): Unit = {
+ val thread = new UninterruptibleThread(s"Testing thread for test $name") {
+ override def run(): Unit = {
+ try {
+ body
+ } catch {
+ case NonFatal(e) =>
+ ex = e
+ }
+ }
+ }
+ thread.setDaemon(true)
+ thread.start()
+ thread.join(timeoutMillis)
+ if (thread.isAlive) {
+ thread.interrupt()
+ // If this interrupt does not work, then this thread is most likely running something that
+ // is not interruptible. There is not much point to wait for the thread to terminate, and
+ // we rather let the JVM terminate the thread on exit.
+ fail(
+ s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" +
+ s" $timeoutMillis ms")
+ } else if (ex != null) {
+ throw ex
+ }
+ }
+
+ if (quietly) {
+ testQuietly(name) { runOnThread() }
+ } else {
+ test(name) { runOnThread() }
+ }
+ }
+}
+
+private[sql] object SQLTestUtils {
+
+ def compareAnswers(
+ sparkAnswer: Seq[Row],
+ expectedAnswer: Seq[Row],
+ sort: Boolean): Option[String] = {
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections. 
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala new file mode 100644 index 0000000..4e2a0c1 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import com.google.common.io.Files
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.util.Utils
+
+/**
+ * Base class for tests with SparkSQL VectorUDT data.
+ */
+abstract class VectorQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+ private var trainDir: File = _
+ private var testDir: File = _
+
+ // A `libsvm` schema is (Double, ml.linalg.Vector)
+ protected var mllibTrainDf: DataFrame = _
+ protected var mllibTestDf: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val trainLines =
+ """
+ |1 1:1.0 3:2.0 5:3.0
+ |0 2:4.0 4:5.0 6:6.0
+ |1 1:1.1 4:1.0 5:2.3 7:1.0
+ |1 1:1.0 4:1.5 5:2.1 7:1.2
+ """.stripMargin
+ trainDir = Utils.createTempDir()
+ Files.write(trainLines, new File(trainDir, "train-00000"), StandardCharsets.UTF_8)
+ val testLines =
+ """
+ |1 1:1.3 3:2.1 5:2.8
+ |0 2:3.9 4:5.3 6:8.0
+ """.stripMargin
+ testDir = Utils.createTempDir()
+ Files.write(testLines, new File(testDir, "test-00000"), StandardCharsets.UTF_8)
+
+ mllibTrainDf = spark.read.format("libsvm").load(trainDir.getAbsolutePath)
+ // Must be cached because rowid() is non-deterministic
+ mllibTestDf = spark.read.format("libsvm").load(testDir.getAbsolutePath)
+ .withColumn("rowid", rowid()).cache
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ Utils.deleteRecursively(trainDir)
+ 
Utils.deleteRecursively(testDir) + } finally { + super.afterAll() + } + } + + protected def withTempModelDir(f: String => Unit): Unit = { + var tempDir: File = null + try { + tempDir = Utils.createTempDir() + f(tempDir.getAbsolutePath + "/xgboost_models") + } catch { + case e: Throwable => fail(s"Unexpected exception detected: ${e}") + } finally { + Utils.deleteRecursively(tempDir) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala new file mode 100644 index 0000000..0e1372d --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.spark.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.ml.feature.HivemallLabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
+import org.apache.spark.streaming.HivemallStreamingOps._
+import org.apache.spark.streaming.dstream.InputDStream
+import org.apache.spark.streaming.scheduler.StreamInputInfo
+
+/**
+ * This is an input stream just for tests.
+ */
+private[this] class TestInputStream[T: ClassTag](
+ ssc: StreamingContext,
+ input: Seq[Seq[T]],
+ numPartitions: Int) extends InputDStream[T](ssc) {
+
+ override def start() {}
+
+ override def stop() {}
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ logInfo("Computing RDD for time " + validTime)
+ val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+ val selectedInput = if (index < input.size) input(index) else Seq[T]()
+
+ // lets us test cases where RDDs are not created
+ if (selectedInput == null) {
+ return None
+ }
+
+ // Report the input data's information to InputInfoTracker for testing
+ val inputInfo = StreamInputInfo(id, selectedInput.length.toLong)
+ ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
+ val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
+ logInfo("Created RDD " + rdd.id + " with " + selectedInput)
+ Some(rdd)
+ }
+}
+
+final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
+
+ // This implicit value is used in `HivemallStreamingOps`
+ implicit val sqlCtx = hiveContext
+
+ /**
+ * Run a block of code with the given StreamingContext.
+ * This method does not stop a given SparkContext because other tests share the context. 
+ */ + private def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): Unit = { + try { + block(ssc) + ssc.start() + ssc.awaitTerminationOrTimeout(10 * 1000) // 10s wait + } finally { + try { + ssc.stop(stopSparkContext = false) + } catch { + case e: Exception => logError("Error stopping StreamingContext", e) + } + } + } + + // scalastyle:off line.size.limit + + /** + * This test below fails sometimes (too flaky), so we temporarily ignore it. + * The stacktrace of this failure is: + * + * HivemallOpsWithFeatureSuite: + * Exception in thread "broadcast-exchange-60" java.lang.OutOfMemoryError: Java heap space + * at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57) + * at java.nio.ByteBuffer.allocate(ByteBuffer.java:331) + * at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231) + * at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231) + * at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:78) + * at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:65) + * at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:205) + * at net.jpountz.lz4.LZ4BlockOutputStream.finish(LZ4BlockOutputStream.java:235) + * at net.jpountz.lz4.LZ4BlockOutputStream.close(LZ4BlockOutputStream.java:175) + * at java.io.ObjectOutputStream$BlockDataOutputStream.close(ObjectOutputStream.java:1827) + * at java.io.ObjectOutputStream.close(ObjectOutputStream.java:741) + * at org.apache.spark.serializer.JavaSerializationStream.close(JavaSerializer.scala:57) + * at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$blockifyObject$1.apply$mcV$sp(TorrentBroadcast.scala:238) + * at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1296) + * at org.apache.spark.broadcast.TorrentBroadcast$.blockifyObject(TorrentBroadcast.scala:237) + * at 
org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:107) + * at org.apache.spark.broadcast.TorrentBroadcast.<init>(TorrentBroadcast.scala:86) + * at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34) + * ... + */ + + // scalastyle:on line.size.limit + + ignore("streaming") { + import sqlCtx.implicits._ + + // We assume we build a model in advance + val testModel = Seq( + ("0", 0.3f), ("1", 0.1f), ("2", 0.6f), ("3", 0.2f) + ).toDF("feature", "weight") + + withStreamingContext(new StreamingContext(sqlCtx.sparkContext, Milliseconds(100))) { ssc => + val inputData = Seq( + Seq(HivemallLabeledPoint(features = "1:0.6" :: "2:0.1" :: Nil)), + Seq(HivemallLabeledPoint(features = "2:0.9" :: Nil)), + Seq(HivemallLabeledPoint(features = "1:0.2" :: Nil)), + Seq(HivemallLabeledPoint(features = "2:0.1" :: Nil)), + Seq(HivemallLabeledPoint(features = "0:0.6" :: "2:0.4" :: Nil)) + ) + + val inputStream = new TestInputStream[HivemallLabeledPoint](ssc, inputData, 1) + + // Apply predictions on input streams + val prediction = inputStream.predict { streamDf => + val df = streamDf.select(rowid(), $"features").explode_array($"features") + val testDf = df.select( + // TODO: `$"feature"` throws AnalysisException, why? 
+ $"rowid", extract_feature(df("feature")), extract_weight(df("feature")) + ) + testDf.join(testModel, testDf("feature") === testModel("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "value") + .select($"rowid", sigmoid($"value")) + } + + // Dummy output stream + prediction.foreachRDD(_ => {}) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala new file mode 100644 index 0000000..fa7b6e5 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.spark.test
+
+import scala.reflect.runtime.{universe => ru}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.DataFrame
+
+object TestUtils extends Logging {
+
+ // Do benchmark if DEBUG-log enabled
+ def benchmark(benchName: String)(testFunc: => Unit): Unit = {
+ if (log.isDebugEnabled) {
+ testFunc
+ }
+ }
+
+ def expectResult(res: Boolean, errMsg: String): Unit = if (res) {
+ logWarning(errMsg)
+ }
+
+ def invokeFunc(cls: Any, func: String, args: Any*): DataFrame = try {
+ // Invoke a function with the given name via reflection
+ val im = scala.reflect.runtime.currentMirror.reflect(cls)
+ val mSym = im.symbol.typeSignature.member(ru.newTermName(func)).asMethod
+ im.reflectMethod(mSym).apply(args: _*)
+ .asInstanceOf[DataFrame]
+ } catch {
+ case e: Exception =>
+ assert(false, s"Invoking ${func} failed because: ${e.getMessage}")
+ null // Not executed
+ }
+}
+
+// TODO: Any same function in o.a.spark.*?
+class TestFPWrapper(d: Double) {
+
+ // Check an equality between Double/Float values
+ def ~==(d: Double): Boolean = Math.abs(this.d - d) < 0.001
+}
+
+object TestFPWrapper {
+
+ @inline implicit def toTestFPWrapper(d: Double): TestFPWrapper = {
+ new TestFPWrapper(d)
+ }
+}
