Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 688daa5f8 -> 8bf6dd9e7


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
new file mode 100644
index 0000000..6bb644a
--- /dev/null
+++ 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive.benchmark
+
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.benchmark.BenchmarkBase
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.internal.HivemallOpsImpl._
+import org.apache.spark.sql.types._
+import org.apache.spark.test.TestUtils
+import org.apache.spark.util.Benchmark
+
class TestFuncWrapper(df: DataFrame) {

  /**
   * Runs Hivemall's `each_top_k` Hive UDTF against `df`.
   *
   * The input is repartitioned and sorted by `group` within each partition because
   * `EachTopKUDTF` assumes that all rows sharing a key arrive consecutively.
   *
   * @param k literal column holding the number of top rows to keep per group
   * @param group the grouping column
   * @param value the scoring column used for ranking within a group
   * @param args extra attribute columns forwarded into the output schema
   */
  def hive_each_top_k(k: Column, group: Column, value: Column, args: Column*)
    : DataFrame = withTypedPlan {
    planHiveGenericUDTF(
      df.repartition(group).sortWithinPartitions(group),
      "hivemall.tools.EachTopKUDTF",
      "each_top_k",
      Seq(k, group, value) ++ args,
      Seq("rank", "key") ++ args.map { arg => arg.expr match {
        case ua: UnresolvedAttribute => ua.name
        // BUGFIX: the match was non-exhaustive — passing any non-attribute column
        // (e.g. an aliased expression) raised an opaque scala.MatchError.
        // Fail fast with a descriptive error instead.
        case e => throw new IllegalArgumentException(
          s"each_top_k output columns must be plain attributes, but got: $e")
      }}
    )
  }

  /**
   * A convenient function to wrap a logical plan and produce a DataFrame.
   */
  @inline private[this] def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = {
    val queryExecution = df.sparkSession.sessionState.executePlan(logicalPlan)
    val outputSchema = queryExecution.sparkPlan.schema
    new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema))
  }
}
+
object TestFuncWrapper {

  /**
   * Enables calling [[TestFuncWrapper]] methods directly on a [[DataFrame]].
   */
  implicit def dataFrameToTestFuncWrapper(df: DataFrame): TestFuncWrapper =
    new TestFuncWrapper(df)

  /**
   * Builds a sigmoid column backed by Hivemall's Hive UDF implementation,
   * used as the "hive-udf" baseline in the benchmarks.
   */
  def sigmoid(exprs: Column*): Column = {
    val sigmoidExpr = planHiveGenericUDF(
      "hivemall.tools.math.SigmoidGenericUDF",
      "sigmoid",
      exprs
    )
    Column(sigmoidExpr)
  }
}
+
class MiscBenchmark extends BenchmarkBase {

  val numIters = 10

  /** Registers `df` as a benchmark case that fully consumes the executed plan's RDD. */
  private def addBenchmarkCase(name: String, df: DataFrame)(
      implicit benchmark: Benchmark): Unit = {
    benchmark.addCase(name, numIters) { _ =>
      df.queryExecution.executedPlan.execute().foreach(x => {})
    }
  }

  TestUtils.benchmark("closure/exprs/spark-udf/hive-udf") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * sigmoid functions:       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * --------------------------------------------------------------------------------
     * exprs                         7708 / 8173          3.4         294.0       1.0X
     * closure                       7722 / 8342          3.4         294.6       1.0X
     * spark-udf                     7963 / 8350          3.3         303.8       1.0X
     * hive-udf                    13977 / 14050          1.9         533.2       0.6X
     */
    import sparkSession.sqlContext.implicits._
    val numRows = 1L << 18
    val df = sparkSession.range(numRows).selectExpr("rand() AS value").cache

    // Materialize the cache before any timing starts
    df.count

    implicit val benchmark = new Benchmark("sigmoid", numRows)

    // sigmoid(x) = 1 / (1 + e^-x), assembled from raw Catalyst expressions
    def sigmoidExprs(expr: Column): Column = {
      val one = Column(Literal.create(1.0, DoubleType))
      one / (one + exp(-expr))
    }
    addBenchmarkCase("exprs", df.select(sigmoidExprs($"value")))

    implicit val encoder = RowEncoder(StructType(StructField("value", DoubleType) :: Nil))
    addBenchmarkCase("closure", df.map { d =>
      Row(1.0 / (1.0 + Math.exp(-d.getDouble(0))))
    })

    val sigmoidUdf = udf { (d: Double) => 1.0 / (1.0 + Math.exp(-d)) }
    addBenchmarkCase("spark-udf", df.select(sigmoidUdf($"value")))

    addBenchmarkCase("hive-udf", df.select(TestFuncWrapper.sigmoid($"value")))

    benchmark.run()
  }

  TestUtils.benchmark("top-k query") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * top-k (k=100):          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * -------------------------------------------------------------------------------
     * rank                       62748 / 62862          0.4        2393.6      1.0X
     * each_top_k (hive-udf)      41421 / 41736          0.6        1580.1      1.5X
     * each_top_k (exprs)         15793 / 16394          1.7         602.5      4.0X
     */
    import sparkSession.sqlContext.implicits._
    import TestFuncWrapper._
    val topK = 100
    val numRows = 1L << 20
    val numGroup = 3
    val df = sparkSession.range(numRows).selectExpr(
      s"id % $numGroup AS key", "rand() AS x", "CAST(id AS STRING) AS value"
    ).cache

    // Materialize the cache before any timing starts
    df.count

    implicit val benchmark = new Benchmark(s"top-k (k=$topK)", numRows)

    // Baseline: window-function ranking followed by a filter
    addBenchmarkCase(
      "rank",
      df.withColumn("rank", rank().over(Window.partitionBy($"key").orderBy($"x".desc)))
        .where($"rank" <= topK))

    // Hivemall's each_top_k via the Hive UDTF path
    addBenchmarkCase(
      "each_top_k (hive-udf)",
      df.hive_each_top_k(lit(topK), $"key", $"x", $"key", $"value"))

    // Hivemall's each_top_k via the native expression path
    addBenchmarkCase(
      "each_top_k (exprs)",
      df.each_top_k(lit(topK), $"x".as("score"), $"key".as("group")))

    benchmark.run()
  }

  TestUtils.benchmark("top-k join query") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * top-k join (k=3):       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * -------------------------------------------------------------------------------
     * join + rank                65959 / 71324          0.0      503223.9      1.0X
     * join + each_top_k          66093 / 78864          0.0      504247.3      1.0X
     * top_k_join                   5013 / 5431          0.0       38249.3     13.2X
     */
    import sparkSession.sqlContext.implicits._
    val topK = 3
    val numInputRows = 1L << 10
    val numMasterRows = 1L << 10
    val numGroup = 3
    val inputDf = sparkSession.range(numInputRows).selectExpr(
      s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y"
    ).cache
    val masterDf = sparkSession.range(numMasterRows).selectExpr(
      s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y"
    ).cache

    // Materialize both caches before any timing starts
    inputDf.count
    masterDf.count

    implicit val benchmark = new Benchmark(s"top-k join (k=$topK)", numInputRows)

    // Euclidean distance between the two (x, y) points, used as the ranking score
    val distance = sqrt(
      pow(inputDf("x") - masterDf("x"), lit(2.0)) +
      pow(inputDf("y") - masterDf("y"), lit(2.0))
    ).as("score")

    // Baseline: plain join, then window-function ranking
    addBenchmarkCase(
      "join + rank",
      inputDf.join(masterDf, inputDf("group") === masterDf("group"))
        .select(inputDf("group"), $"userId", $"posId", distance)
        .withColumn(
          "rank", rank().over(Window.partitionBy($"group", $"userId").orderBy($"score".desc)))
        .where($"rank" <= topK))

    // Plain join followed by each_top_k
    addBenchmarkCase(
      "join + each_top_k",
      inputDf.join(masterDf, inputDf("group") === masterDf("group"))
        .each_top_k(lit(topK), distance, inputDf("group").as("group")))

    // Fused top-k join operator
    addBenchmarkCase(
      "top_k_join",
      inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"), distance))

    benchmark.run()
  }

  TestUtils.benchmark("codegen top-k join") {
    /**
     * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2
     * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz
     *
     * top_k_join:                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     * -----------------------------------------------------------------------------------
     * top_k_join wholestage off           3 /    5       2751.9           0.4        1.0X
     * top_k_join wholestage on            1 /    1       6494.4           0.2        2.4X
     */
    val topK = 3
    val numInputRows = 1L << 23
    val numMasterRows = 1L << 22
    val numGroup = 3
    val inputDf = sparkSession.range(numInputRows).selectExpr(
      s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y"
    ).cache
    val masterDf = sparkSession.range(numMasterRows).selectExpr(
      s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y"
    ).cache

    // Materialize both caches before any timing starts
    inputDf.count
    masterDf.count

    // Euclidean distance between the two (x, y) points, used as the ranking score
    val distance = sqrt(
      pow(inputDf("x") - masterDf("x"), lit(2.0)) +
      pow(inputDf("y") - masterDf("y"), lit(2.0))
    )

    // runBenchmark compares whole-stage codegen on vs. off for the same query
    runBenchmark("top_k_join", numInputRows) {
      inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"),
        distance.as("score"))
    }
  }
}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
new file mode 100644
index 0000000..3ca9bbf
--- /dev/null
+++ 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive.test
+
+import scala.collection.mutable.Seq
+import scala.reflect.runtime.universe.TypeTag
+
+import hivemall.tools.RegressionDatagen
+
+import org.apache.spark.sql.{Column, QueryTest}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Base class for tests with Hivemall features.
+ */
/**
 * Base class for tests with Hivemall features.
 */
abstract class HivemallFeatureQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton {

  import hiveContext.implicits._

  /**
   * TODO: spark-2.0 does not support literals for some types (e.g., Seq[_] and Array[_]),
   * so this helper provides that functionality here.
   * It will be removed in future releases.
   */
  protected def lit2[T : TypeTag](v: T): Column = {
    val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
    Column(Literal(toCatalyst(v), dataType))
  }

  // Minimal two-column frame for UDF smoke tests
  protected val DummyInputData = Seq((0, 0)).toDF("c0", "c1")

  // Paired integer lists (target vs. predict) for ranking-metric tests
  protected val IntList2Data =
    Seq(
      (List(8, 5), List(6, 4)),
      (List(3, 1), List(3, 2)),
      (List(2), List(3))
    ).toDF("target", "predict")

  // Paired floats (target vs. predict) for regression-metric tests
  protected val Float2Data =
    Seq(
      (0.8f, 0.3f), (0.9f, 0.3f), (0.4f, 0.2f)
    ).toDF("predict", "target")

  // Tiny labelled training set in Hivemall's "index:weight" feature format
  protected val TinyTrainData =
    Seq(
      (0.0, List("1:0.8", "2:0.2")),
      (1.0, List("2:0.7")),
      (0.0, List("1:0.9"))
    ).toDF("label", "features")

  // Tiny labelled test set in Hivemall's "index:weight" feature format
  protected val TinyTestData =
    Seq(
      (0.0, List("1:0.6", "2:0.1")),
      (1.0, List("2:0.9")),
      (0.0, List("1:0.2")),
      (0.0, List("2:0.1")),
      (0.0, List("0:0.6", "2:0.4"))
    ).toDF("label", "features")

  // Synthetic regression data (large train / small test); cached eagerly
  protected val LargeRegrTrainData = RegressionDatagen.exec(
      hiveContext,
      n_partitions = 2,
      min_examples = 100000,
      seed = 3,
      prob_one = 0.8f
    ).cache

  protected val LargeRegrTestData = RegressionDatagen.exec(
      hiveContext,
      n_partitions = 2,
      min_examples = 100,
      seed = 3,
      prob_one = 0.5f
    ).cache

  // Synthetic classification data (cl = true); cached eagerly
  protected val LargeClassifierTrainData = RegressionDatagen.exec(
      hiveContext,
      n_partitions = 2,
      min_examples = 100000,
      seed = 5,
      prob_one = 0.8f,
      cl = true
    ).cache

  protected val LargeClassifierTestData = RegressionDatagen.exec(
      hiveContext,
      n_partitions = 2,
      min_examples = 100,
      seed = 5,
      prob_one = 0.5f,
      cl = true
    ).cache
}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
new file mode 100644
index 0000000..ccb21cf
--- /dev/null
+++ 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive.test
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.SparkFunSuite
+
trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
  // Process-wide, Hive-enabled test session shared by all suites
  protected val spark: SparkSession = TestHive.sparkSession
  protected val hiveContext: TestHiveContext = TestHive

  /** Resets the shared Hive state after each suite so later suites start clean. */
  protected override def afterAll(): Unit = {
    try hiveContext.reset()
    finally super.afterAll()
  }
}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
new file mode 100644
index 0000000..50b80fa
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.nio.charset.StandardCharsets
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
+
+/**
+ * A collection of sample data used in SQL tests.
+ */
/**
 * A collection of sample data used in SQL tests. Each member lazily builds a
 * DataFrame/RDD and registers it as a temp view of the same name.
 */
private[sql] trait SQLTestData { self =>
  protected def spark: SparkSession

  // Helper object to import SQL implicits without a concrete SQLContext
  private object internalImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = self.spark.sqlContext
  }

  import internalImplicits._
  import SQLTestData._

  // All members below are lazy: the SparkSession is not yet set up when the
  // trait is mixed in.

  protected lazy val emptyTestData: DataFrame = {
    val df = spark.sparkContext.parallelize(
      Seq.empty[Int].map(i => TestData(i, i.toString))).toDF()
    df.createOrReplaceTempView("emptyTestData")
    df
  }

  protected lazy val testData: DataFrame = {
    val df = spark.sparkContext.parallelize(
      (1 to 100).map(i => TestData(i, i.toString))).toDF()
    df.createOrReplaceTempView("testData")
    df
  }

  protected lazy val testData2: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      TestData2(1, 1), TestData2(1, 2),
      TestData2(2, 1), TestData2(2, 2),
      TestData2(3, 1), TestData2(3, 2)), 2).toDF()
    df.createOrReplaceTempView("testData2")
    df
  }

  protected lazy val testData3: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      TestData3(1, None),
      TestData3(2, Some(2)))).toDF()
    df.createOrReplaceTempView("testData3")
    df
  }

  protected lazy val negativeData: DataFrame = {
    val df = spark.sparkContext.parallelize(
      (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
    df.createOrReplaceTempView("negativeData")
    df
  }

  protected lazy val largeAndSmallInts: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      LargeAndSmallInts(2147483644, 1),
      LargeAndSmallInts(1, 2),
      LargeAndSmallInts(2147483645, 1),
      LargeAndSmallInts(2, 2),
      LargeAndSmallInts(2147483646, 1),
      LargeAndSmallInts(3, 2))).toDF()
    df.createOrReplaceTempView("largeAndSmallInts")
    df
  }

  protected lazy val decimalData: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      DecimalData(1, 1), DecimalData(1, 2),
      DecimalData(2, 1), DecimalData(2, 2),
      DecimalData(3, 1), DecimalData(3, 2))).toDF()
    df.createOrReplaceTempView("decimalData")
    df
  }

  protected lazy val binaryData: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      BinaryData("12".getBytes(StandardCharsets.UTF_8), 1),
      BinaryData("22".getBytes(StandardCharsets.UTF_8), 5),
      BinaryData("122".getBytes(StandardCharsets.UTF_8), 3),
      BinaryData("121".getBytes(StandardCharsets.UTF_8), 2),
      BinaryData("123".getBytes(StandardCharsets.UTF_8), 4))).toDF()
    df.createOrReplaceTempView("binaryData")
    df
  }

  protected lazy val upperCaseData: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      UpperCaseData(1, "A"),
      UpperCaseData(2, "B"),
      UpperCaseData(3, "C"),
      UpperCaseData(4, "D"),
      UpperCaseData(5, "E"),
      UpperCaseData(6, "F"))).toDF()
    df.createOrReplaceTempView("upperCaseData")
    df
  }

  protected lazy val lowerCaseData: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      LowerCaseData(1, "a"),
      LowerCaseData(2, "b"),
      LowerCaseData(3, "c"),
      LowerCaseData(4, "d"))).toDF()
    df.createOrReplaceTempView("lowerCaseData")
    df
  }

  protected lazy val arrayData: RDD[ArrayData] = {
    val rdd = spark.sparkContext.parallelize(Seq(
      ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))),
      ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4)))))
    rdd.toDF().createOrReplaceTempView("arrayData")
    rdd
  }

  protected lazy val mapData: RDD[MapData] = {
    val rdd = spark.sparkContext.parallelize(Seq(
      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")),
      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")),
      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")),
      MapData(Map(1 -> "a4", 2 -> "b4")),
      MapData(Map(1 -> "a5"))))
    rdd.toDF().createOrReplaceTempView("mapData")
    rdd
  }

  protected lazy val repeatedData: RDD[StringData] = {
    val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test")))
    rdd.toDF().createOrReplaceTempView("repeatedData")
    rdd
  }

  protected lazy val nullableRepeatedData: RDD[StringData] = {
    val rdd = spark.sparkContext.parallelize(
      List.fill(2)(StringData(null)) ++
      List.fill(2)(StringData("test")))
    rdd.toDF().createOrReplaceTempView("nullableRepeatedData")
    rdd
  }

  protected lazy val nullInts: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      NullInts(1),
      NullInts(2),
      NullInts(3),
      NullInts(null))).toDF()
    df.createOrReplaceTempView("nullInts")
    df
  }

  protected lazy val allNulls: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      NullInts(null),
      NullInts(null),
      NullInts(null),
      NullInts(null))).toDF()
    df.createOrReplaceTempView("allNulls")
    df
  }

  protected lazy val nullStrings: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      NullStrings(1, "abc"),
      NullStrings(2, "ABC"),
      NullStrings(3, null))).toDF()
    df.createOrReplaceTempView("nullStrings")
    df
  }

  protected lazy val tableName: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(TableName("test"))).toDF()
    df.createOrReplaceTempView("tableName")
    df
  }

  protected lazy val unparsedStrings: RDD[String] = {
    spark.sparkContext.parallelize(Seq(
      "1, A1, true, null",
      "2, B2, false, null",
      "3, C3, true, null",
      "4, D4, true, 2147483644"))
  }

  // An RDD with 4 elements and 8 partitions (i.e. some partitions are empty)
  protected lazy val withEmptyParts: RDD[IntField] = {
    val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8)
    rdd.toDF().createOrReplaceTempView("withEmptyParts")
    rdd
  }

  protected lazy val person: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      Person(0, "mike", 30),
      Person(1, "jim", 20))).toDF()
    df.createOrReplaceTempView("person")
    df
  }

  protected lazy val salary: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      Salary(0, 2000.0),
      Salary(1, 1000.0))).toDF()
    df.createOrReplaceTempView("salary")
    df
  }

  protected lazy val complexData: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true),
      ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false))).toDF()
    df.createOrReplaceTempView("complexData")
    df
  }

  protected lazy val courseSales: DataFrame = {
    val df = spark.sparkContext.parallelize(Seq(
      CourseSales("dotNET", 2012, 10000),
      CourseSales("Java", 2012, 20000),
      CourseSales("dotNET", 2012, 5000),
      CourseSales("dotNET", 2013, 48000),
      CourseSales("Java", 2013, 30000))).toDF()
    df.createOrReplaceTempView("courseSales")
    df
  }

  /**
   * Initialize all test data such that all temp tables are properly registered.
   */
  def loadTestData(): Unit = {
    assert(spark != null, "attempted to initialize test data before SparkSession.")
    emptyTestData
    testData
    testData2
    testData3
    negativeData
    largeAndSmallInts
    decimalData
    binaryData
    upperCaseData
    lowerCaseData
    arrayData
    mapData
    repeatedData
    nullableRepeatedData
    nullInts
    allNulls
    nullStrings
    tableName
    unparsedStrings
    withEmptyParts
    person
    salary
    complexData
    courseSales
  }
}
+
+/**
+ * Case classes used in test data.
+ */
/**
 * Case classes used in test data. Names and fields define the schemas of the
 * temp views registered by the [[SQLTestData]] trait; do not rename fields.
 */
private[sql] object SQLTestData {
  case class TestData(key: Int, value: String)
  case class TestData2(a: Int, b: Int)
  case class TestData3(a: Int, b: Option[Int])
  case class LargeAndSmallInts(a: Int, b: Int)
  case class DecimalData(a: BigDecimal, b: BigDecimal)
  case class BinaryData(a: Array[Byte], b: Int)
  // Deliberately uses upper-case field names to exercise case sensitivity
  case class UpperCaseData(N: Int, L: String)
  case class LowerCaseData(n: Int, l: String)
  case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
  case class MapData(data: scala.collection.Map[Int, String])
  case class StringData(s: String)
  case class IntField(i: Int)
  // java.lang.Integer (not Int) so the field can hold null
  case class NullInts(a: Integer)
  case class NullStrings(n: Int, s: String)
  case class TableName(tableName: String)
  case class Person(id: Int, name: String, age: Int)
  case class Salary(personId: Int, salary: Double)
  case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
  case class CourseSales(course: String, year: Int, earnings: Double)
}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
new file mode 100644
index 0000000..1e48e71
--- /dev/null
+++ 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.io.File
+import java.util.UUID
+
+import scala.language.implicitConversions
+import scala.util.Try
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.FilterExec
+import org.apache.spark.util.{UninterruptibleThread, Utils}
+
+/**
+ * Helper trait that should be extended by all SQL test suites.
+ *
+ * This allows subclasses to plug in a custom [[SQLContext]]. It comes with 
test data
+ * prepared in advance as well as all implicit conversions used extensively by 
dataframes.
+ * To use implicit methods, import `testImplicits._` instead of through the 
[[SQLContext]].
+ *
+ * Subclasses should *not* create [[SQLContext]]s in the test suite 
constructor, which is
+ * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in 
the same JVM.
+ */
+private[sql] trait SQLTestUtils
+  extends SparkFunSuite
+  with BeforeAndAfterAll
+  with SQLTestData { self =>
+
+  protected def sparkContext = spark.sparkContext
+
+  // Whether to materialize all test data before the first test is run
+  private var loadTestDataBeforeTests = false
+
+  // Shorthand for running a query using our SQLContext
+  protected lazy val sql = spark.sql _
+
+  /**
+   * A helper object for importing SQL implicits.
+   *
+   * Note that the alternative of importing `spark.implicits._` is not 
possible here.
+   * This is because we create the [[SQLContext]] immediately before the first 
test is run,
+   * but the implicits import is needed in the constructor.
+   */
+  protected object testImplicits extends SQLImplicits {
+    protected override def _sqlContext: SQLContext = self.spark.sqlContext
+  }
+
+  /**
+   * Materialize the test data immediately after the [[SQLContext]] is set up.
+   * This is necessary if the data is accessed by name but not through direct 
reference.
+   */
+  protected def setupTestData(): Unit = {
+    loadTestDataBeforeTests = true
+  }
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    if (loadTestDataBeforeTests) {
+      loadTestData()
+    }
+  }
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then 
restore all SQL
+   * configurations.
+   *
+   * @todo Probably this method should be moved to a more general place
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
+    (keys, values).zipped.foreach(spark.conf.set)
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => spark.conf.set(key, value)
+        case (key, None) => spark.conf.unset(key)
+      }
+    }
+  }
+
+  /**
+   * Generates a temporary path without creating the actual file/directory, 
then passes it to `f`. If
+   * a file/directory is created there by `f`, it will be deleted after `f` 
returns.
+   *
+   * @todo Probably this method should be moved to a more general place
+   */
+  protected def withTempPath(f: File => Unit): Unit = {
+    val path = Utils.createTempDir()
+    path.delete()
+    try f(path) finally Utils.deleteRecursively(path)
+  }
+
+  /**
+   * Creates a temporary directory, which is then passed to `f` and will be 
deleted after `f`
+   * returns.
+   *
+   * @todo Probably this method should be moved to a more general place
+   */
+  protected def withTempDir(f: File => Unit): Unit = {
+    val dir = Utils.createTempDir().getCanonicalFile
+    try f(dir) finally Utils.deleteRecursively(dir)
+  }
+
+  /**
+   * Drops functions after calling `f`. A function is represented by 
(functionName, isTemporary).
+   */
+  protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => 
Unit): Unit = {
+    try {
+      f
+    } catch {
+      case cause: Throwable => throw cause
+    } finally {
+      // If the test failed part way, we don't want to mask the failure by 
failing to remove
+      // functions that never got created.
+      functions.foreach { case (functionName, isTemporary) =>
+        val withTemporary = if (isTemporary) "TEMPORARY" else ""
+        spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName")
+        assert(
+          
!spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
+          s"Function $functionName should have been dropped. But, it still 
exists.")
+      }
+    }
+  }
+
+  /**
+   * Drops temporary table `tableName` after calling `f`.
+   */
+  protected def withTempView(tableNames: String*)(f: => Unit): Unit = {
+    try f finally {
+      // If the test failed part way, we don't want to mask the failure by 
failing to remove
+      // temp tables that never got created.
+      try tableNames.foreach(spark.catalog.dropTempView) catch {
+        case _: NoSuchTableException =>
+      }
+    }
+  }
+
+  /**
+   * Drops table `tableName` after calling `f`.
+   */
+  protected def withTable(tableNames: String*)(f: => Unit): Unit = {
+    try f finally {
+      tableNames.foreach { name =>
+        spark.sql(s"DROP TABLE IF EXISTS $name")
+      }
+    }
+  }
+
+  /**
+   * Drops view `viewName` after calling `f`.
+   */
+  protected def withView(viewNames: String*)(f: => Unit): Unit = {
+    try f finally {
+      viewNames.foreach { name =>
+        spark.sql(s"DROP VIEW IF EXISTS $name")
+      }
+    }
+  }
+
+  /**
+   * Creates a temporary database whose name is passed to `f` (the current 
database is not switched).  This
+   * database is dropped after `f` returns.
+   *
+   * Note that this method doesn't switch current database before executing 
`f`.
+   */
+  protected def withTempDatabase(f: String => Unit): Unit = {
+    val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
+
+    try {
+      spark.sql(s"CREATE DATABASE $dbName")
+    } catch { case cause: Throwable =>
+      fail("Failed to create temporary database", cause)
+    }
+
+    try f(dbName) finally {
+      if (spark.catalog.currentDatabase == dbName) {
+        spark.sql(s"USE ${DEFAULT_DATABASE}")
+      }
+      spark.sql(s"DROP DATABASE $dbName CASCADE")
+    }
+  }
+
+  /**
+   * Activates database `db` before executing `f`, then switches back to 
`default` database after
+   * `f` returns.
+   */
+  protected def activateDatabase(db: String)(f: => Unit): Unit = {
+    spark.sessionState.catalog.setCurrentDatabase(db)
+    try f finally spark.sessionState.catalog.setCurrentDatabase("default")
+  }
+
+  /**
+   * Strip Spark-side filtering in order to check if a datasource filters rows 
correctly.
+   */
+  protected def stripSparkFilter(df: DataFrame): DataFrame = {
+    val schema = df.schema
+    val withoutFilters = df.queryExecution.sparkPlan.transform {
+      case FilterExec(_, child) => child
+    }
+
+    spark.internalCreateDataFrame(withoutFilters.execute(), schema)
+  }
+
+  /**
+   * Turn a logical plan into a [[DataFrame]]. This should be removed once we 
have an easier
+   * way to construct [[DataFrame]] directly out of local data without relying 
on implicits.
+   */
+  protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame 
= {
+    Dataset.ofRows(spark, plan)
+  }
+
+  /**
+   * Disable stdout and stderr when running the test. To not output the logs 
to the console,
+ * ConsoleAppender's `follow` should be set to `true` so that it will honor 
reassignments of
+   * System.out or System.err. Otherwise, ConsoleAppender will still output to 
the console even if
+   * we change System.out and System.err.
+   */
+  protected def testQuietly(name: String)(f: => Unit): Unit = {
+    test(name) {
+      quietly {
+        f
+      }
+    }
+  }
+
+  /** Run a test on a separate [[UninterruptibleThread]]. */
+  protected def testWithUninterruptibleThread(name: String, quietly: Boolean = 
false)
+    (body: => Unit): Unit = {
+    val timeoutMillis = 10000
+    @transient var ex: Throwable = null
+
+    def runOnThread(): Unit = {
+      val thread = new UninterruptibleThread(s"Testing thread for test $name") 
{
+        override def run(): Unit = {
+          try {
+            body
+          } catch {
+            case NonFatal(e) =>
+              ex = e
+          }
+        }
+      }
+      thread.setDaemon(true)
+      thread.start()
+      thread.join(timeoutMillis)
+      if (thread.isAlive) {
+        thread.interrupt()
+        // If this interrupt does not work, then this thread is most likely 
running something that
+        // is not interruptible. There is not much point to wait for the 
thread to terminate, and
+        // we rather let the JVM terminate the thread on exit.
+        fail(
+          s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out 
after" +
+            s" $timeoutMillis ms")
+      } else if (ex != null) {
+        throw ex
+      }
+    }
+
+    if (quietly) {
+      testQuietly(name) { runOnThread() }
+    } else {
+      test(name) { runOnThread() }
+    }
+  }
+}
+
+private[sql] object SQLTestUtils {
+
+  def compareAnswers(
+      sparkAnswer: Seq[Row],
+      expectedAnswer: Seq[Row],
+      sort: Boolean): Option[String] = {
+    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+      // Converts data to types that we can do equality comparison using Scala 
collections.
+      // For BigDecimal type, the Scala type has a better definition of 
equality test (similar to
+      // Java's java.math.BigDecimal.compareTo).
+      // For binary arrays, we convert it to Seq to avoid of calling 
java.util.Arrays.equals for
+      // equality test.
+      // This function is copied from Catalyst's QueryTest
+      val converted: Seq[Row] = answer.map { s =>
+        Row.fromSeq(s.toSeq.map {
+          case d: java.math.BigDecimal => BigDecimal(d)
+          case b: Array[Byte] => b.toSeq
+          case o => o
+        })
+      }
+      if (sort) {
+        converted.sortBy(_.toString())
+      } else {
+        converted
+      }
+    }
+    if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
+      val errorMessage =
+        s"""
+           | == Results ==
+           | ${sideBySide(
+          s"== Expected Answer - ${expectedAnswer.size} ==" +:
+            prepareAnswer(expectedAnswer).map(_.toString()),
+          s"== Actual Answer - ${sparkAnswer.size} ==" +:
+            prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+      """.stripMargin
+      Some(errorMessage)
+    } else {
+      None
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
new file mode 100644
index 0000000..4e2a0c1
--- /dev/null
+++ 
b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import com.google.common.io.Files
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.util.Utils
+
+/**
+ * Base class for tests with SparkSQL VectorUDT data.
+ */
+abstract class VectorQueryTest extends QueryTest with SQLTestUtils with 
TestHiveSingleton {
+
+  private var trainDir: File = _
+  private var testDir: File = _
+
+  // A `libsvm` schema is (Double, ml.linalg.Vector)
+  protected var mllibTrainDf: DataFrame = _
+  protected var mllibTestDf: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val trainLines =
+      """
+        |1 1:1.0 3:2.0 5:3.0
+        |0 2:4.0 4:5.0 6:6.0
+        |1 1:1.1 4:1.0 5:2.3 7:1.0
+        |1 1:1.0 4:1.5 5:2.1 7:1.2
+      """.stripMargin
+    trainDir = Utils.createTempDir()
+    Files.write(trainLines, new File(trainDir, "train-00000"), 
StandardCharsets.UTF_8)
+    val testLines =
+      """
+        |1 1:1.3 3:2.1 5:2.8
+        |0 2:3.9 4:5.3 6:8.0
+      """.stripMargin
+    testDir = Utils.createTempDir()
+    Files.write(testLines, new File(testDir, "test-00000"), 
StandardCharsets.UTF_8)
+
+    mllibTrainDf = spark.read.format("libsvm").load(trainDir.getAbsolutePath)
+    // Must be cached because rowid() is non-deterministic
+    mllibTestDf = spark.read.format("libsvm").load(testDir.getAbsolutePath)
+      .withColumn("rowid", rowid()).cache
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      Utils.deleteRecursively(trainDir)
+      Utils.deleteRecursively(testDir)
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  protected def withTempModelDir(f: String => Unit): Unit = {
+    var tempDir: File = null
+    try {
+      tempDir = Utils.createTempDir()
+      f(tempDir.getAbsolutePath + "/xgboost_models")
+    } catch {
+      case e: Throwable => fail(s"Unexpected exception detected: ${e}")
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
 
b/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
new file mode 100644
index 0000000..0e1372d
--- /dev/null
+++ 
b/spark/spark-2.2/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.ml.feature.HivemallLabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
+import org.apache.spark.streaming.HivemallStreamingOps._
+import org.apache.spark.streaming.dstream.InputDStream
+import org.apache.spark.streaming.scheduler.StreamInputInfo
+
+/**
+ * This is an input stream just for tests.
+ */
+private[this] class TestInputStream[T: ClassTag](
+    ssc: StreamingContext,
+    input: Seq[Seq[T]],
+    numPartitions: Int) extends InputDStream[T](ssc) {
+
+  override def start() {}
+
+  override def stop() {}
+
+  override def compute(validTime: Time): Option[RDD[T]] = {
+    logInfo("Computing RDD for time " + validTime)
+    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+    val selectedInput = if (index < input.size) input(index) else Seq[T]()
+
+    // lets us test cases where RDDs are not created
+    if (selectedInput == null) {
+      return None
+    }
+
+    // Report the input data's information to InputInfoTracker for testing
+    val inputInfo = StreamInputInfo(id, selectedInput.length.toLong)
+    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
+    val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
+    logInfo("Created RDD " + rdd.id + " with " + selectedInput)
+    Some(rdd)
+  }
+}
+
+final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
+
+  // This implicit value is used in `HivemallStreamingOps`
+  implicit val sqlCtx = hiveContext
+
+  /**
+   * Run a block of code with the given StreamingContext.
+   * This method does not stop a given SparkContext because other tests share 
the context.
+   */
+  private def withStreamingContext[R](ssc: StreamingContext)(block: 
StreamingContext => R): Unit = {
+    try {
+      block(ssc)
+      ssc.start()
+      ssc.awaitTerminationOrTimeout(10 * 1000) // 10s wait
+    } finally {
+      try {
+        ssc.stop(stopSparkContext = false)
+      } catch {
+        case e: Exception => logError("Error stopping StreamingContext", e)
+      }
+    }
+  }
+
+  // scalastyle:off line.size.limit
+
+  /**
+   * This test below fails sometimes (too flaky), so we temporarily ignore it.
+   * The stacktrace of this failure is:
+   *
+   * HivemallOpsWithFeatureSuite:
+   *  Exception in thread "broadcast-exchange-60" java.lang.OutOfMemoryError: 
Java heap space
+   *   at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57)
+   *   at java.nio.ByteBuffer.allocate(ByteBuffer.java:331)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231)
+   *   at 
org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:78)
+   *   at 
org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:65)
+   *   at 
net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:205)
+   *   at 
net.jpountz.lz4.LZ4BlockOutputStream.finish(LZ4BlockOutputStream.java:235)
+   *   at 
net.jpountz.lz4.LZ4BlockOutputStream.close(LZ4BlockOutputStream.java:175)
+   *   at 
java.io.ObjectOutputStream$BlockDataOutputStream.close(ObjectOutputStream.java:1827)
+   *   at java.io.ObjectOutputStream.close(ObjectOutputStream.java:741)
+   *   at 
org.apache.spark.serializer.JavaSerializationStream.close(JavaSerializer.scala:57)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcast$$anonfun$blockifyObject$1.apply$mcV$sp(TorrentBroadcast.scala:238)
+   *   at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1296)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcast$.blockifyObject(TorrentBroadcast.scala:237)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:107)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcast.<init>(TorrentBroadcast.scala:86)
+   *   at 
org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34)
+   *   ...
+   */
+
+  // scalastyle:on line.size.limit
+
+  ignore("streaming") {
+    import sqlCtx.implicits._
+
+    // We assume we build a model in advance
+    val testModel = Seq(
+      ("0", 0.3f), ("1", 0.1f), ("2", 0.6f), ("3", 0.2f)
+    ).toDF("feature", "weight")
+
+    withStreamingContext(new StreamingContext(sqlCtx.sparkContext, 
Milliseconds(100))) { ssc =>
+      val inputData = Seq(
+        Seq(HivemallLabeledPoint(features = "1:0.6" :: "2:0.1" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "2:0.9" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "1:0.2" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "2:0.1" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "0:0.6" :: "2:0.4" :: Nil))
+      )
+
+      val inputStream = new TestInputStream[HivemallLabeledPoint](ssc, 
inputData, 1)
+
+      // Apply predictions on input streams
+      val prediction = inputStream.predict { streamDf =>
+          val df = streamDf.select(rowid(), 
$"features").explode_array($"features")
+          val testDf = df.select(
+            // TODO: `$"feature"` throws AnalysisException, why?
+            $"rowid", extract_feature(df("feature")), 
extract_weight(df("feature"))
+          )
+          testDf.join(testModel, testDf("feature") === testModel("feature"), 
"LEFT_OUTER")
+            .select($"rowid", ($"weight" * $"value").as("value"))
+            .groupBy("rowid").sum("value")
+            .toDF("rowid", "value")
+            .select($"rowid", sigmoid($"value"))
+        }
+
+      // Dummy output stream
+      prediction.foreachRDD(_ => {})
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala 
b/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala
new file mode 100644
index 0000000..fa7b6e5
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/test/TestUtils.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.test
+
+import scala.reflect.runtime.{universe => ru}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.DataFrame
+
+object TestUtils extends Logging {
+
+  // Do benchmark if DEBUG-log enabled
+  def benchmark(benchName: String)(testFunc: => Unit): Unit = {
+    if (log.isDebugEnabled) {
+      testFunc
+    }
+  }
+
+  def expectResult(res: Boolean, errMsg: String): Unit = if (res) {
+    logWarning(errMsg)
+  }
+
+  def invokeFunc(cls: Any, func: String, args: Any*): DataFrame = try {
+    // Invoke a function with the given name via reflection
+    val im = scala.reflect.runtime.currentMirror.reflect(cls)
+    val mSym = im.symbol.typeSignature.member(ru.newTermName(func)).asMethod
+    im.reflectMethod(mSym).apply(args: _*)
+      .asInstanceOf[DataFrame]
+  } catch {
+    case e: Exception =>
+      assert(false, s"Invoking ${func} failed because: ${e.getMessage}")
+      null // Not executed
+  }
+}
+
+// TODO: Any same function in o.a.spark.*?
+class TestFPWrapper(d: Double) {
+
+  // Check an equality between Double/Float values
+  def ~==(d: Double): Boolean = Math.abs(this.d - d) < 0.001
+}
+
+object TestFPWrapper {
+
+  @inline implicit def toTestFPWrapper(d: Double): TestFPWrapper = {
+    new TestFPWrapper(d)
+  }
+}

Reply via email to