Repository: incubator-hivemall
Updated Branches:
  refs/heads/master d3afb11ba -> c837e51ad (forced update)


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
new file mode 100644
index 0000000..9c23687
--- /dev/null
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.test
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import com.google.common.io.Files
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.util.Utils
+
+/**
+ * Base class for tests that use Spark SQL `VectorUDT` data.
+ */
+abstract class VectorQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+  private var trainDir: File = _
+  private var testDir: File = _
+
+  // The `libsvm` data source schema is (label: Double, features: ml.linalg.Vector)
+  protected var mllibTrainDf: DataFrame = _
+  protected var mllibTestDf: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val trainLines =
+      """
+        |1 1:1.0 3:2.0 5:3.0
+        |0 2:4.0 4:5.0 6:6.0
+        |1 1:1.1 4:1.0 5:2.3 7:1.0
+        |1 1:1.0 4:1.5 5:2.1 7:1.2
+      """.stripMargin
+    trainDir = Utils.createTempDir()
+    Files.write(trainLines, new File(trainDir, "train-00000"), StandardCharsets.UTF_8)
+    val testLines =
+      """
+        |1 1:1.3 3:2.1 5:2.8
+        |0 2:3.9 4:5.3 6:8.0
+      """.stripMargin
+    testDir = Utils.createTempDir()
+    Files.write(testLines, new File(testDir, "test-00000"), StandardCharsets.UTF_8)
+
+    mllibTrainDf = spark.read.format("libsvm").load(trainDir.getAbsolutePath)
+    // Must be cached because rowid() is non-deterministic; caching keeps the ids stable
+    mllibTestDf = spark.read.format("libsvm").load(testDir.getAbsolutePath)
+      .withColumn("rowid", rowid()).cache
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      Utils.deleteRecursively(trainDir)
+      Utils.deleteRecursively(testDir)
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  protected def withTempModelDir(f: String => Unit): Unit = {
+    var tempDir: File = null
+    try {
+      tempDir = Utils.createTempDir()
+      f(tempDir.getAbsolutePath + "/xgboost_models")
+    } catch {
+      case e: Throwable => fail(s"Unexpected exception detected: ${e}")
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+}
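
For reference, a concrete suite built on this base class might look like the following sketch; the suite name, test name, and assertions are illustrative assumptions (not part of the patch), and it relies only on the mllibTrainDf, mllibTestDf, and withTempModelDir members defined above, assuming it lives in the same org.apache.spark.sql.test package:

    class ExampleVectorSuite extends VectorQueryTest {

      test("libsvm input is loaded with the expected schema") {
        // The libsvm data source yields (label: Double, features: Vector);
        // the test DataFrame additionally carries the rowid column added in beforeAll().
        assert(mllibTrainDf.columns.toSeq === Seq("label", "features"))
        assert(mllibTestDf.columns.toSeq === Seq("label", "features", "rowid"))

        withTempModelDir { modelPath =>
          // modelPath points at a temporary "<tmpdir>/xgboost_models" location
          assert(modelPath.endsWith("/xgboost_models"))
        }
      }
    }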

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
new file mode 100644
index 0000000..b15c77c
--- /dev/null
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.ml.feature.HivemallLabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
+import org.apache.spark.streaming.HivemallStreamingOps._
+import org.apache.spark.streaming.dstream.InputDStream
+import org.apache.spark.streaming.scheduler.StreamInputInfo
+
+/**
+ * This is an input stream just for tests.
+ */
+private[this] class TestInputStream[T: ClassTag](
+    ssc: StreamingContext,
+    input: Seq[Seq[T]],
+    numPartitions: Int) extends InputDStream[T](ssc) {
+
+  override def start() {}
+
+  override def stop() {}
+
+  override def compute(validTime: Time): Option[RDD[T]] = {
+    logInfo("Computing RDD for time " + validTime)
+    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+    val selectedInput = if (index < input.size) input(index) else Seq[T]()
+
+    // Lets us test cases in which no RDD is created (a null entry in `input`)
+    if (selectedInput == null) {
+      return None
+    }
+
+    // Report the input data's information to InputInfoTracker for testing
+    val inputInfo = StreamInputInfo(id, selectedInput.length.toLong)
+    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
+    val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
+    logInfo("Created RDD " + rdd.id + " with " + selectedInput)
+    Some(rdd)
+  }
+}
+
+final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
+
+  // This implicit value is used in `HivemallStreamingOps`
+  implicit val sqlCtx = hiveContext
+
+  /**
+   * Runs a block of code with the given StreamingContext.
+   * This method does not stop the underlying SparkContext because other tests share it.
+   */
+  private def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): Unit = {
+    try {
+      block(ssc)
+      ssc.start()
+      ssc.awaitTerminationOrTimeout(10 * 1000) // 10s wait
+    } finally {
+      try {
+        ssc.stop(stopSparkContext = false)
+      } catch {
+        case e: Exception => logError("Error stopping StreamingContext", e)
+      }
+    }
+  }
+
+  // scalastyle:off line.size.limit
+
+  /**
+   * The test below sometimes fails (it is too flaky), so we temporarily ignore it.
+   * The stack trace of the failure is:
+   *
+   * HivemallOpsWithFeatureSuite:
+   *  Exception in thread "broadcast-exchange-60" java.lang.OutOfMemoryError: Java heap space
+   *   at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57)
+   *   at java.nio.ByteBuffer.allocate(ByteBuffer.java:331)
+   *   at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231)
+   *   at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231)
+   *   at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:78)
+   *   at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:65)
+   *   at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:205)
+   *   at net.jpountz.lz4.LZ4BlockOutputStream.finish(LZ4BlockOutputStream.java:235)
+   *   at net.jpountz.lz4.LZ4BlockOutputStream.close(LZ4BlockOutputStream.java:175)
+   *   at java.io.ObjectOutputStream$BlockDataOutputStream.close(ObjectOutputStream.java:1827)
+   *   at java.io.ObjectOutputStream.close(ObjectOutputStream.java:741)
+   *   at org.apache.spark.serializer.JavaSerializationStream.close(JavaSerializer.scala:57)
+   *   at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$blockifyObject$1.apply$mcV$sp(TorrentBroadcast.scala:238)
+   *   at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1296)
+   *   at org.apache.spark.broadcast.TorrentBroadcast$.blockifyObject(TorrentBroadcast.scala:237)
+   *   at org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:107)
+   *   at org.apache.spark.broadcast.TorrentBroadcast.<init>(TorrentBroadcast.scala:86)
+   *   at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34)
+   *   ...
+   */
+
+  // scalastyle:on line.size.limit
+
+  ignore("streaming") {
+    import sqlCtx.implicits._
+
+    // We assume a model has been built in advance
+    val testModel = Seq(
+      ("0", 0.3f), ("1", 0.1f), ("2", 0.6f), ("3", 0.2f)
+    ).toDF("feature", "weight")
+
+    withStreamingContext(new StreamingContext(sqlCtx.sparkContext, Milliseconds(100))) { ssc =>
+      val inputData = Seq(
+        Seq(HivemallLabeledPoint(features = "1:0.6" :: "2:0.1" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "2:0.9" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "1:0.2" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "2:0.1" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "0:0.6" :: "2:0.4" :: Nil))
+      )
+
+      val inputStream = new TestInputStream[HivemallLabeledPoint](ssc, inputData, 1)
+
+      // Apply predictions to the input stream
+      val prediction = inputStream.predict { streamDf =>
+          val df = streamDf.select(rowid(), $"features").explode_array($"features")
+          val testDf = df.select(
+            // TODO: `$"feature"` throws AnalysisException, why?
+            $"rowid", extract_feature(df("feature")), extract_weight(df("feature"))
+          )
+          testDf.join(testModel, testDf("feature") === testModel("feature"), "LEFT_OUTER")
+            .select($"rowid", ($"weight" * $"value").as("value"))
+            .groupBy("rowid").sum("value")
+            .toDF("rowid", "value")
+            .select($"rowid", sigmoid($"value"))
+        }
+
+      // Dummy output stream
+      prediction.foreachRDD(_ => {})
+    }
+  }
+}
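
As a quick aside on the batch indexing in TestInputStream.compute above: with the Milliseconds(100) batch interval used by this test, the arithmetic works out as follows (a worked restatement of the existing code, nothing new):

    validTime = zeroTime + 100 ms  ->  index = 100 / 100 - 1 = 0  ->  input(0)
    validTime = zeroTime + 200 ms  ->  index = 200 / 100 - 1 = 1  ->  input(1)
    validTime = zeroTime + 600 ms  ->  index = 5, which is >= input.size (5)  ->  empty RDD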

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/test/TestUtils.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/test/TestUtils.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/test/TestUtils.scala
new file mode 100644
index 0000000..8a2a385
--- /dev/null
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/test/TestUtils.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.test
+
+import scala.reflect.runtime.{universe => ru}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.DataFrame
+
+object TestUtils extends Logging {
+
+  // Runs the benchmark body only if DEBUG logging is enabled
+  def benchmark(benchName: String)(testFunc: => Unit): Unit = {
+    if (log.isDebugEnabled) {
+      testFunc
+    }
+  }
+
+  // Logs a warning with `errMsg` when `res` is true
+  def expectResult(res: Boolean, errMsg: String): Unit = if (res) {
+    logWarning(errMsg)
+  }
+
+  def invokeFunc(cls: Any, func: String, args: Any*): DataFrame = try {
+    // Invokes the method named `func` on `cls` via runtime reflection
+    val im = scala.reflect.runtime.currentMirror.reflect(cls)
+    val mSym = im.symbol.typeSignature.member(ru.newTermName(func)).asMethod
+    im.reflectMethod(mSym).apply(args: _*)
+      .asInstanceOf[DataFrame]
+  } catch {
+    case e: Exception =>
+      assert(false, s"Invoking ${func} failed because: ${e.getMessage}")
+      null // Not executed
+  }
+}
+
+// TODO: Is there an equivalent helper somewhere in o.a.spark.*?
+class TestFPWrapper(d: Double) {
+
+  // Checks approximate equality between floating-point values (absolute tolerance 0.001)
+  def ~==(d: Double): Boolean = Math.abs(this.d - d) < 0.001
+}
+
+object TestFPWrapper {
+
+  @inline implicit def toTestFPWrapper(d: Double): TestFPWrapper = {
+    new TestFPWrapper(d)
+  }
+}
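
For reference, a hypothetical usage of these helpers inside a test body might look like the sketch below; the values and the df variable are illustrative assumptions, not part of the patch:

    // TestFPWrapper gives an approximate-equality operator for doubles
    // (absolute tolerance 0.001); the implicit conversion is picked up
    // from the TestFPWrapper companion object.
    assert((0.1 + 0.2) ~== 0.3)   // true: the difference is ~5.6e-17
    assert(!(1.0 ~== 1.01))       // false: the difference exceeds 0.001

    // TestUtils.invokeFunc calls a DataFrame-returning method by name via
    // reflection, e.g. TestUtils.invokeFunc(df, "limit", 1), assuming df is
    // a DataFrame in scope.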
