http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala new file mode 100644 index 0000000..c9d0ba0 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql + +import java.util.{ArrayDeque, Locale, TimeZone} + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.types.{Metadata, ObjectType} + + +abstract class QueryTest extends PlanTest { + + protected def spark: SparkSession + + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + + /** + * Runs the plan and makes sure the answer contains all of the keywords. + */ + def checkKeywordsExist(df: DataFrame, keywords: String*): Unit = { + val outputs = df.collect().map(_.mkString).mkString + for (key <- keywords) { + assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)") + } + } + + /** + * Runs the plan and makes sure the answer does NOT contain any of the keywords. + */ + def checkKeywordsNotExist(df: DataFrame, keywords: String*): Unit = { + val outputs = df.collect().map(_.mkString).mkString + for (key <- keywords) { + assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)") + } + } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer. 
+ */ + protected def checkDataset[T]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { + val result = getResult(ds) + + if (!compare(result.toSeq, expectedAnswer)) { + fail( + s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + |${ds.exprEnc.deserializer.treeString} + """.stripMargin) + } + } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer, after sort. + */ + protected def checkDatasetUnorderly[T : Ordering]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { + val result = getResult(ds) + + if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail( + s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + |${ds.exprEnc.deserializer.treeString} + """.stripMargin) + } + } + + private def getResult[T](ds: => Dataset[T]): Array[T] = { + val analyzedDS = try ds catch { + case ae: AnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + """.stripMargin) + } else { + throw ae + } + } + assertEmptyMissingInput(analyzedDS) + + try ds.collect() catch { + case e: Exception => + fail( + s""" + |Exception collecting dataset as objects + |${ds.exprEnc} + |${ds.exprEnc.deserializer.treeString} + |${ds.queryExecution} + """.stripMargin, e) + } + } + + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a, b) => a == b + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } + } + + assertEmptyMissingInput(analyzedDF) + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * + * @param dataFrame the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param absTol the absolute tolerance between actual and expected answers. 
+ */ + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Seq[Row], + absTol: Double): Unit = { + // TODO: catch exceptions in data frame execution + val actualAnswer = dataFrame.collect() + require(actualAnswer.length == expectedAnswer.length, + s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") + + actualAnswer.zip(expectedAnswer).foreach { + case (actualRow, expectedRow) => + QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) + } + } + + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Row, + absTol: Double): Unit = { + checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) + } + + /** + * Asserts that a given [[Dataset]] will be executed using the given number of cached results. + */ + def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + + /** + * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. + */ + def assertEmptyMissingInput(query: Dataset[_]): Unit = { + assert(query.queryExecution.analyzed.missingInput.isEmpty, + s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}") + assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, + s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}") + assert(query.queryExecution.executedPlan.missingInput.isEmpty, + s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}") + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. + */ + def checkAnswer( + df: DataFrame, + expectedAnswer: Seq[Row], + checkToRDD: Boolean = true): Option[String] = { + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + if (checkToRDD) { + df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + + val sparkAnswer = try df.collect().toSeq catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${df.queryExecution} + |== Results == + |$results + """.stripMargin + } + } + + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). 
+ // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val errorMessage = + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + None + } + + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * + * @param actualAnswer the actual result in a [[Row]]. + * @param expectedAnswer the expected result in a[[Row]]. + * @param absTol the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { + require(actualAnswer.length == expectedAnswer.length, + s"actual answer length ${actualAnswer.length} != " + + s"expected answer length ${expectedAnswer.length}") + + // TODO: support other numeric types besides Double + // TODO: support struct types? + actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { + case (actual: Double, expected: Double) => + assert(math.abs(actual - expected) < absTol, + s"actual answer $actual not within $absTol of correct answer $expected") + case (actual, expected) => + assert(actual == expected, s"$actual did not equal $expected") + } + } + + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(df, expectedAnswer.asScala) match { + case Some(errorMessage) => errorMessage + case None => null + } + } +}
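[Editorial aside, not part of the commit above] The QueryTest helpers in this file are meant to be mixed into concrete suites that supply a SparkSession. The sketch below shows one plausible way such a suite could wire them up; the suite name, the local-master configuration, and the test data are illustrative assumptions, not code from this patch.

package org.apache.spark.sql

// Hypothetical suite (illustration only) built on the QueryTest class added above.
class QueryTestUsageSketchSuite extends QueryTest {

  // QueryTest only requires a SparkSession; a tiny local session is enough for a sketch.
  protected lazy val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("query-test-usage-sketch")
    .getOrCreate()

  import spark.implicits._

  test("checkAnswer ignores row order when the plan has no Sort") {
    val df = Seq((1, "a"), (2, "b")).toDF("id", "name")
    // Expected rows may be listed in any order; prepareAnswer sorts both sides
    // because the logical plan contains no Sort operator.
    checkAnswer(df.filter($"id" > 0), Row(2, "b") :: Row(1, "a") :: Nil)
  }

  test("checkDataset compares decoded objects") {
    val ds = Seq(1, 2, 3).toDS()
    // checkDataset collects the Dataset and compares the decoded values directly.
    checkDataset(ds.map(_ * 2), 2, 4, 6)
  }
}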
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala new file mode 100644 index 0000000..a4aeaa6 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Provides helper methods for comparing plans. + */ +abstract class PlanTest extends SparkFunSuite with PredicateHelper { + + protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) + + /** + * Since attribute references are given globally unique ids during analysis, + * we must normalize them to check if two different queries are identical. + */ + protected def normalizeExprIds(plan: LogicalPlan) = { + plan transformAllExpressions { + case s: ScalarSubquery => + s.copy(exprId = ExprId(0)) + case e: Exists => + e.copy(exprId = ExprId(0)) + case l: ListQuery => + l.copy(exprId = ExprId(0)) + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) + } + } + + /** + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. 
+ */ + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan transform { + case filter @ Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L)(true) + case join @ Join(left, right, joinType, condition) if condition.isDefined => + val newCondition = + splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And) + Join(left, right, joinType, Some(newCondition)) + } + } + + /** + * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). + */ + private def rewriteEqual(condition: Expression): Expression = condition match { + case eq @ EqualTo(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case eq @ EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case _ => condition // Don't reorder. + } + + /** Fails the test if the two plans do not match */ + protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Fails the test if the two expressions do not match */ + protected def compareExpressions(e1: Expression, e2: Expression): Unit = { + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) + } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. */ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) + case _ => + plan1 == plan2 + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala new file mode 100644 index 0000000..8283503 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.Benchmark + +/** + * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together + * with other test suites). + */ +private[sql] trait BenchmarkBase extends SparkFunSuite { + + lazy val sparkSession = SparkSession.builder + .master("local[1]") + .appName("microbenchmark") + .config("spark.sql.shuffle.partitions", 1) + .config("spark.sql.autoBroadcastJoinThreshold", 1) + .getOrCreate() + + /** Runs function `f` with whole stage codegen on and off. */ + def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, cardinality) + + benchmark.addCase(s"$name wholestage off", numIters = 2) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) + f + } + + benchmark.addCase(s"$name wholestage on", numIters = 5) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) + f + } + + benchmark.run() + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala new file mode 100644 index 0000000..b145b7f --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.test.VectorQueryTest + +final class HiveUdfWithFeatureSuite extends HivemallFeatureQueryTest { + import hiveContext.implicits._ + import hiveContext._ + + test("hivemall_version") { + sql(s""" + | CREATE TEMPORARY FUNCTION hivemall_version + | AS '${classOf[hivemall.HivemallVersionUDF].getName}' + """.stripMargin) + + checkAnswer( + sql(s"SELECT DISTINCT hivemall_version()"), + Row("0.4.2-rc.2") + ) + + // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version") + // reset() + } + + test("train_logregr") { + TinyTrainData.createOrReplaceTempView("TinyTrainData") + sql(s""" + | CREATE TEMPORARY FUNCTION train_logregr + | AS '${classOf[hivemall.regression.LogressUDTF].getName}' + """.stripMargin) + sql(s""" + | CREATE TEMPORARY FUNCTION add_bias + | AS '${classOf[hivemall.ftvec.AddBiasUDFWrapper].getName}' + """.stripMargin) + + val model = sql( + s""" + | SELECT feature, AVG(weight) AS weight + | FROM ( + | SELECT train_logregr(add_bias(features), label) AS (feature, weight) + | FROM TinyTrainData + | ) t + | GROUP BY feature + """.stripMargin) + + checkAnswer( + model.select($"feature"), + Seq(Row("0"), Row("1"), Row("2")) + ) + + // TODO: Why 'train_logregr' is not registered in HiveMetaStore? + // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException + // (message:Function default.train_logregr does not exist)) + // + // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr") + // hiveContext.reset() + } + + test("each_top_k") { + val testDf = Seq( + ("a", "1", 0.5, Array(0, 1, 2)), + ("b", "5", 0.1, Array(3)), + ("a", "3", 0.8, Array(2, 5)), + ("c", "6", 0.3, Array(1, 3)), + ("b", "4", 0.3, Array(2)), + ("a", "2", 0.6, Array(1)) + ).toDF("key", "value", "score", "data") + + import testDf.sqlContext.implicits._ + testDf.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("TestData") + sql(s""" + | CREATE TEMPORARY FUNCTION each_top_k + | AS '${classOf[hivemall.tools.EachTopKUDTF].getName}' + """.stripMargin) + + // Compute top-1 rows for each group + checkAnswer( + sql("SELECT each_top_k(1, key, score, key, value) FROM TestData"), + Row(1, 0.8, "a", "3") :: + Row(1, 0.3, "b", "4") :: + Row(1, 0.3, "c", "6") :: + Nil + ) + + // Compute reverse top-1 rows for each group + checkAnswer( + sql("SELECT each_top_k(-1, key, score, key, value) FROM TestData"), + Row(1, 0.5, "a", "1") :: + Row(1, 0.1, "b", "5") :: + Row(1, 0.3, "c", "6") :: + Nil + ) + } +} + +final class HiveUdfWithVectorSuite extends VectorQueryTest { + import hiveContext._ + + test("to_hivemall_features") { + mllibTrainDf.createOrReplaceTempView("mllibTrainDf") + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) + checkAnswer( + sql( + s""" + | SELECT to_hivemall_features(features) + | FROM mllibTrainDf + """.stripMargin), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + test("append_bias") { + mllibTrainDf.createOrReplaceTempView("mllibTrainDf") + hiveContext.udf.register("append_bias", append_bias_func) + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) + checkAnswer( + sql( + s""" + | SELECT to_hivemall_features(append_bias(features)) + | FROM mllibTrainDF + 
""".stripMargin), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0")) + ) + ) + } + + ignore("explode_vector") { + // TODO: Spark-2.0 does not support use-defined generator function in + // `org.apache.spark.sql.UDFRegistration`. + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala new file mode 100644 index 0000000..6b5d4cd --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -0,0 +1,961 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.VectorQueryTest +import org.apache.spark.sql.types._ +import org.apache.spark.test.TestFPWrapper._ +import org.apache.spark.test.TestUtils + +final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { + + test("anomaly") { + import hiveContext.implicits._ + val df = spark.range(1000).selectExpr("id AS time", "rand() AS x") + // TODO: Test results more strictly + assert(df.sort($"time".asc).select(changefinder($"x")).count === 1000) + assert(df.sort($"time".asc).select(sst($"x", lit("-th 0.005"))).count === 1000) + } + + test("knn.similarity") { + val df1 = DummyInputData.select(cosine_sim(lit2(Seq(1, 2, 3, 4)), lit2(Seq(3, 4, 5, 6)))) + assert(df1.collect.apply(0).getFloat(0) ~== 0.500f) + + val df2 = DummyInputData.select(jaccard(lit(5), lit(6))) + assert(df2.collect.apply(0).getFloat(0) ~== 0.96875f) + + val df3 = DummyInputData.select(angular_similarity(lit2(Seq(1, 2, 3)), lit2(Seq(4, 5, 6)))) + assert(df3.collect.apply(0).getFloat(0) ~== 0.500f) + + val df4 = DummyInputData.select(euclid_similarity(lit2(Seq(5, 3, 1)), lit2(Seq(2, 8, 3)))) + assert(df4.collect.apply(0).getFloat(0) ~== 0.33333334f) + + val df5 = DummyInputData.select(distance2similarity(lit(1.0))) + assert(df5.collect.apply(0).getFloat(0) ~== 0.5f) + } + + test("knn.distance") { + val df1 = 
DummyInputData.select(hamming_distance(lit(1), lit(3))) + checkAnswer(df1, Row(1) :: Nil) + + val df2 = DummyInputData.select(popcnt(lit(1))) + checkAnswer(df2, Row(1) :: Nil) + + val df3 = DummyInputData.select(kld(lit(0.1), lit(0.5), lit(0.2), lit(0.5))) + assert(df3.collect.apply(0).getDouble(0) ~== 0.01) + + val df4 = DummyInputData.select( + euclid_distance(lit2(Seq("0.1", "0.5")), lit2(Seq("0.2", "0.5")))) + assert(df4.collect.apply(0).getFloat(0) ~== 1.4142135f) + + val df5 = DummyInputData.select( + cosine_distance(lit2(Seq("0.8", "0.3")), lit2(Seq("0.4", "0.6")))) + assert(df5.collect.apply(0).getFloat(0) ~== 1.0f) + + val df6 = DummyInputData.select( + angular_distance(lit2(Seq("0.1", "0.1")), lit2(Seq("0.3", "0.8")))) + assert(df6.collect.apply(0).getFloat(0) ~== 0.50f) + + val df7 = DummyInputData.select( + manhattan_distance(lit2(Seq("0.7", "0.8")), lit2(Seq("0.5", "0.6")))) + assert(df7.collect.apply(0).getFloat(0) ~== 4.0f) + + val df8 = DummyInputData.select( + minkowski_distance(lit2(Seq("0.1", "0.2")), lit2(Seq("0.2", "0.2")), lit2(1.0))) + assert(df8.collect.apply(0).getFloat(0) ~== 2.0f) + } + + test("knn.lsh") { + import hiveContext.implicits._ + assert(IntList2Data.minhash(lit(1), $"target").count() > 0) + + assert(DummyInputData.select(bbit_minhash(lit2(Seq("1:0.1", "2:0.5")), lit(false))).count + == DummyInputData.count) + assert(DummyInputData.select(minhashes(lit2(Seq("1:0.1", "2:0.5")), lit(false))).count + == DummyInputData.count) + } + + test("ftvec - add_bias") { + import hiveContext.implicits._ + checkAnswer(TinyTrainData.select(add_bias($"features")), + Row(Seq("1:0.8", "2:0.2", "0:1.0")) :: + Row(Seq("2:0.7", "0:1.0")) :: + Row(Seq("1:0.9", "0:1.0")) :: + Nil + ) + } + + test("ftvec - extract_feature") { + val df = DummyInputData.select(extract_feature(lit("1:0.8"))) + checkAnswer(df, Row("1") :: Nil) + } + + test("ftvec - extract_weight") { + val df = DummyInputData.select(extract_weight(lit("3:0.1"))) + assert(df.collect.apply(0).getDouble(0) ~== 0.1) + } + + test("ftvec - explode_array") { + import hiveContext.implicits._ + val df = TinyTrainData.explode_array($"features").select($"feature") + checkAnswer(df, Row("1:0.8") :: Row("2:0.2") :: Row("2:0.7") :: Row("1:0.9") :: Nil) + } + + test("ftvec - add_feature_index") { + import hiveContext.implicits._ + val doubleListData = Seq(Array(0.8, 0.5), Array(0.3, 0.1), Array(0.2)).toDF("data") + checkAnswer( + doubleListData.select(add_feature_index($"data")), + Row(Seq("1:0.8", "2:0.5")) :: + Row(Seq("1:0.3", "2:0.1")) :: + Row(Seq("1:0.2")) :: + Nil + ) + } + + test("ftvec - sort_by_feature") { + // import hiveContext.implicits._ + val intFloatMapData = { + // TODO: Use `toDF` + val rowRdd = hiveContext.sparkContext.parallelize( + Row(Map(1 -> 0.3f, 2 -> 0.1f, 3 -> 0.5f)) :: + Row(Map(2 -> 0.4f, 1 -> 0.2f)) :: + Row(Map(2 -> 0.4f, 3 -> 0.2f, 1 -> 0.1f, 4 -> 0.6f)) :: + Nil + ) + hiveContext.createDataFrame( + rowRdd, + StructType( + StructField("data", MapType(IntegerType, FloatType), true) :: + Nil) + ) + } + val sortedKeys = intFloatMapData.select(sort_by_feature(intFloatMapData.col("data"))) + .collect.map { + case Row(m: Map[Int, Float]) => m.keysIterator.toSeq + } + assert(sortedKeys.toSet === Set(Seq(1, 2, 3), Seq(1, 2), Seq(1, 2, 3, 4))) + } + + test("ftvec.hash") { + assert(DummyInputData.select(mhash(lit("test"))).count == DummyInputData.count) + assert(DummyInputData.select(org.apache.spark.sql.hive.HivemallOps.sha1(lit("test"))).count == + DummyInputData.count) + // TODO: The tests below failed 
because: + // org.apache.spark.sql.AnalysisException: List type in java is unsupported because JVM type + // erasure makes spark fail to catch a component type in List<>; + // + // assert(DummyInputData.select(array_hash_values(lit2(Seq("aaa", "bbb")))).count + // == DummyInputData.count) + // assert(DummyInputData.select( + // prefixed_hash_values(lit2(Seq("ccc", "ddd")), lit("prefix"))).count + // == DummyInputData.count) + } + + test("ftvec.scaling") { + val df1 = TinyTrainData.select(rescale(lit(2.0f), lit(1.0), lit(5.0))) + assert(df1.collect.apply(0).getFloat(0) === 0.25f) + val df2 = TinyTrainData.select(zscore(lit(1.0f), lit(0.5), lit(0.5))) + assert(df2.collect.apply(0).getFloat(0) === 1.0f) + val df3 = TinyTrainData.select(normalize(TinyTrainData.col("features"))) + checkAnswer( + df3, + Row(Seq("1:0.9701425", "2:0.24253562")) :: + Row(Seq("2:1.0")) :: + Row(Seq("1:1.0")) :: + Nil) + } + + test("ftvec.selection - chi2") { + import hiveContext.implicits._ + + // See also hivemall.ftvec.selection.ChiSquareUDFTest + val df = Seq( + Seq( + Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996), + Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3), + Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998) + ) -> Seq( + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589), + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589), + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))) + .toDF("arg0", "arg1") + + val result = df.select(chi2(df("arg0"), df("arg1"))).collect + assert(result.length == 1) + val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0) + val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1) + + (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("ftvec.conv - quantify") { + import hiveContext.implicits._ + val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF + // This test is done in a single partition because `HivemallOps#quantify` assigns identifiers + // for non-numerical values in each partition. 
+ checkAnswer( + testDf.coalesce(1).quantify(lit(true) +: testDf.cols: _*), + Row(1, 0, 0) :: Row(2, 1, 1) :: Row(3, 0, 1) :: Nil) + } + + test("ftvec.amplify") { + import hiveContext.implicits._ + assert(TinyTrainData.amplify(lit(3), $"label", $"features").count() == 9) + assert(TinyTrainData.part_amplify(lit(3)).count() == 9) + // TODO: The test below failed because: + // java.lang.RuntimeException: Unsupported literal type class scala.Tuple3 + // (-buf 128,label,features) + // + // assert(TinyTrainData.rand_amplify(lit(3), lit("-buf 8", $"label", $"features")).count() == 9) + } + + ignore("ftvec.conv") { + import hiveContext.implicits._ + + val df1 = Seq((0.0, "1:0.1" :: "3:0.3" :: Nil), (1, 0, "2:0.2" :: Nil)).toDF("a", "b") + checkAnswer( + df1.select(to_dense_features(df1("b"), lit(3))), + Row(Array(0.1f, 0.0f, 0.3f)) :: Row(Array(0.0f, 0.2f, 0.0f)) :: Nil + ) + val df2 = Seq((0.1, 0.2, 0.3), (0.2, 0.5, 0.4)).toDF("a", "b", "c") + checkAnswer( + df2.select(to_sparse_features(df2("a"), df2("b"), df2("c"))), + Row(Seq("1:0.1", "2:0.2", "3:0.3")) :: Row(Seq("1:0.2", "2:0.5", "3:0.4")) :: Nil + ) + } + + test("ftvec.trans") { + import hiveContext.implicits._ + + val df1 = Seq((1, -3, 1), (2, -2, 1)).toDF("a", "b", "c") + checkAnswer( + df1.binarize_label($"a", $"b", $"c"), + Row(1, 1) :: Row(1, 1) :: Row(1, 1) :: Nil + ) + + val df2 = Seq((0.1f, 0.2f), (0.5f, 0.3f)).toDF("a", "b") + checkAnswer( + df2.select(vectorize_features(lit2(Seq("a", "b")), df2("a"), df2("b"))), + Row(Seq("a:0.1", "b:0.2")) :: Row(Seq("a:0.5", "b:0.3")) :: Nil + ) + + val df3 = Seq(("c11", "c12"), ("c21", "c22")).toDF("a", "b") + checkAnswer( + df3.select(categorical_features(lit2(Seq("a", "b")), df3("a"), df3("b"))), + Row(Seq("a#c11", "b#c12")) :: Row(Seq("a#c21", "b#c22")) :: Nil + ) + + val df4 = Seq((0.1, 0.2, 0.3), (0.2, 0.5, 0.4)).toDF("a", "b", "c") + checkAnswer( + df4.select(indexed_features(df4("a"), df4("b"), df4("c"))), + Row(Seq("1:0.1", "2:0.2", "3:0.3")) :: Row(Seq("1:0.2", "2:0.5", "3:0.4")) :: Nil + ) + + val df5 = Seq(("xxx", "yyy", 0), ("zzz", "yyy", 1)).toDF("a", "b", "c").coalesce(1) + checkAnswer( + df5.quantified_features(lit(true), df5("a"), df5("b"), df5("c")), + Row(Seq(0.0, 0.0, 0.0)) :: Row(Seq(1.0, 0.0, 1.0)) :: Nil + ) + + val df6 = Seq((0.1, 0.2), (0.5, 0.3)).toDF("a", "b") + checkAnswer( + df6.select(quantitative_features(lit2(Seq("a", "b")), df6("a"), df6("b"))), + Row(Seq("a:0.1", "b:0.2")) :: Row(Seq("a:0.5", "b:0.3")) :: Nil + ) + } + + test("misc - hivemall_version") { + checkAnswer(DummyInputData.select(hivemall_version()), Row("0.4.2-rc.2")) + } + + test("misc - rowid") { + assert(DummyInputData.select(rowid()).distinct.count == DummyInputData.count) + } + + test("misc - each_top_k") { + import hiveContext.implicits._ + val inputDf = Seq( + ("a", "1", 0.5, 0.1, Array(0, 1, 2)), + ("b", "5", 0.1, 0.2, Array(3)), + ("a", "3", 0.8, 0.8, Array(2, 5)), + ("c", "6", 0.3, 0.3, Array(1, 3)), + ("b", "4", 0.3, 0.4, Array(2)), + ("a", "2", 0.6, 0.5, Array(1)) + ).toDF("key", "value", "x", "y", "data") + + // Compute top-1 rows for each group + val distance = sqrt(inputDf("x") * inputDf("x") + inputDf("y") * inputDf("y")).as("score") + val top1Df = inputDf.each_top_k(lit(1), distance, $"key".as("group")) + assert(top1Df.schema.toSet === Set( + StructField("rank", IntegerType, nullable = true), + StructField("score", DoubleType, nullable = true), + StructField("key", StringType, nullable = true), + StructField("value", StringType, nullable = true), + StructField("x", DoubleType, 
nullable = true), + StructField("y", DoubleType, nullable = true), + StructField("data", ArrayType(IntegerType, containsNull = false), nullable = true) + )) + checkAnswer( + top1Df.select($"rank", $"key", $"value", $"data"), + Row(1, "a", "3", Array(2, 5)) :: + Row(1, "b", "4", Array(2)) :: + Row(1, "c", "6", Array(1, 3)) :: + Nil + ) + + // Compute reverse top-1 rows for each group + val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key".as("group")) + checkAnswer( + bottom1Df.select($"rank", $"key", $"value", $"data"), + Row(1, "a", "1", Array(0, 1, 2)) :: + Row(1, "b", "5", Array(3)) :: + Row(1, "c", "6", Array(1, 3)) :: + Nil + ) + + // Check if some exceptions thrown in case of some conditions + assert(intercept[AnalysisException] { inputDf.each_top_k(lit(0.1), $"score", $"key") } + .getMessage contains "`k` must be integer, however") + assert(intercept[AnalysisException] { inputDf.each_top_k(lit(1), $"data", $"key") } + .getMessage contains "must have a comparable type") + } + + test("misc - join_top_k") { + Seq("true", "false").map { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + import hiveContext.implicits._ + val inputDf = Seq( + ("user1", 1, 0.3, 0.5), + ("user2", 2, 0.1, 0.1), + ("user3", 3, 0.8, 0.0), + ("user4", 1, 0.9, 0.9), + ("user5", 3, 0.7, 0.2), + ("user6", 1, 0.5, 0.4), + ("user7", 2, 0.6, 0.8) + ).toDF("userId", "group", "x", "y") + + val masterDf = Seq( + (1, "pos1-1", 0.5, 0.1), + (1, "pos1-2", 0.0, 0.0), + (1, "pos1-3", 0.3, 0.3), + (2, "pos2-3", 0.1, 0.3), + (2, "pos2-3", 0.8, 0.8), + (3, "pos3-1", 0.1, 0.7), + (3, "pos3-1", 0.7, 0.1), + (3, "pos3-1", 0.9, 0.0), + (3, "pos3-1", 0.1, 0.3) + ).toDF("group", "position", "x", "y") + + // Compute top-1 rows for each group + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + val top1Df = inputDf.top_k_join( + lit(1), masterDf, inputDf("group") === masterDf("group"), distance) + assert(top1Df.schema.toSet === Set( + StructField("rank", IntegerType, nullable = true), + StructField("score", DoubleType, nullable = true), + StructField("group", IntegerType, nullable = false), + StructField("userId", StringType, nullable = true), + StructField("position", StringType, nullable = true), + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false) + )) + checkAnswer( + top1Df.select($"rank", inputDf("group"), $"userId", $"position"), + Row(1, 1, "user1", "pos1-2") :: + Row(1, 2, "user2", "pos2-3") :: + Row(1, 3, "user3", "pos3-1") :: + Row(1, 1, "user4", "pos1-2") :: + Row(1, 3, "user5", "pos3-1") :: + Row(1, 1, "user6", "pos1-2") :: + Row(1, 2, "user7", "pos2-3") :: + Nil + ) + } + } + } + + test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having the same scores") { + import hiveContext.implicits._ + val inputDf = Seq( + ("a", "1", 0.1), + ("b", "5", 0.1), + ("a", "3", 0.1), + ("b", "4", 0.1), + ("a", "2", 0.0) + ).toDF("key", "value", "x") + + // Compute top-2 rows for each group + val top2Df = inputDf.each_top_k(lit(2), $"x".as("score"), $"key".as("group")) + checkAnswer( + top2Df.select($"rank", $"score", $"key", $"value"), + Row(1, 0.1, "a", "3") :: + Row(1, 0.1, "a", "1") :: + Row(1, 0.1, "b", "4") :: + Row(1, 0.1, "b", "5") :: + Nil + ) + Seq("true", "false").map { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + val inputDf = Seq( + ("user1", 1, 0.3, 0.5), + ("user2", 2, 0.1, 0.1) + ).toDF("userId", "group", "x", "y") + + val 
masterDf = Seq( + (1, "pos1-1", 0.5, 0.1), + (1, "pos1-2", 0.5, 0.1), + (1, "pos1-3", 0.3, 0.4), + (2, "pos2-1", 0.8, 0.2), + (2, "pos2-2", 0.8, 0.2) + ).toDF("group", "position", "x", "y") + + // Compute top-2 rows for each group + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + val top2Df = inputDf.top_k_join( + lit(2), masterDf, inputDf("group") === masterDf("group"), distance) + checkAnswer( + top2Df.select($"rank", inputDf("group"), $"userId", $"position"), + Row(1, 1, "user1", "pos1-1") :: + Row(1, 1, "user1", "pos1-2") :: + Row(1, 2, "user2", "pos2-1") :: + Row(1, 2, "user2", "pos2-2") :: + Nil + ) + } + } + } + + test("misc - flatten") { + import hiveContext.implicits._ + val df = Seq((0, (1, "a", (3.0, "b")), (5, 0.9, "c", "d"), 9)).toDF() + assert(df.flatten().schema === StructType( + StructField("_1", IntegerType, nullable = false) :: + StructField("_2$_1", IntegerType, nullable = true) :: + StructField("_2$_2", StringType, nullable = true) :: + StructField("_2$_3$_1", DoubleType, nullable = true) :: + StructField("_2$_3$_2", StringType, nullable = true) :: + StructField("_3$_1", IntegerType, nullable = true) :: + StructField("_3$_2", DoubleType, nullable = true) :: + StructField("_3$_3", StringType, nullable = true) :: + StructField("_3$_4", StringType, nullable = true) :: + StructField("_4", IntegerType, nullable = false) :: + Nil + )) + checkAnswer(df.flatten("$").select("_2$_1"), Row(1)) + checkAnswer(df.flatten("_").select("_2__1"), Row(1)) + checkAnswer(df.flatten(".").select("`_2._1`"), Row(1)) + + val errMsg1 = intercept[IllegalArgumentException] { df.flatten("\t") } + assert(errMsg1.getMessage.startsWith("Must use '$', '_', or '.' for separator, but got")) + val errMsg2 = intercept[IllegalArgumentException] { df.flatten("12") } + assert(errMsg2.getMessage.startsWith("Separator cannot be more than one character:")) + } + + test("misc - from_csv") { + import hiveContext.implicits._ + val df = Seq("""1,abc""").toDF() + val schema = new StructType().add("a", IntegerType).add("b", StringType) + checkAnswer( + df.select(from_csv($"value", schema)), + Row(Row(1, "abc")) :: Nil) + } + + test("misc - to_csv") { + import hiveContext.implicits._ + val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF() + checkAnswer( + df.select(to_csv($"_3")), + Row("0,3.9,abc") :: + Row("2,0.4,def") :: + Nil) + } + + /** + * This test fails because; + * + * Cause: java.lang.OutOfMemoryError: Java heap space + * at hivemall.smile.tools.RandomForestEnsembleUDAF$Result.<init> + * (RandomForestEnsembleUDAF.java:128) + * at hivemall.smile.tools.RandomForestEnsembleUDAF$RandomForestPredictUDAFEvaluator + * .terminate(RandomForestEnsembleUDAF.java:91) + * at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) + */ + ignore("misc - tree_predict") { + import hiveContext.implicits._ + + val model = Seq((0.0, 0.1 :: 0.1 :: Nil), (1.0, 0.2 :: 0.3 :: 0.2 :: Nil)) + .toDF("label", "features") + .train_randomforest_regr($"features", $"label") + + val testData = Seq((0.0, 0.1 :: 0.0 :: Nil), (1.0, 0.3 :: 0.5 :: 0.4 :: Nil)) + .toDF("label", "features") + .select(rowid(), $"label", $"features") + + val predicted = model + .join(testData).coalesce(1) + .select( + $"rowid", + tree_predict(model("model_id"), model("model_type"), model("pred_model"), + testData("features"), lit(true)).as("predicted") + ) + .groupBy($"rowid") + .rf_ensemble("predicted").toDF("rowid", "predicted") + 
.select($"predicted.label") + + checkAnswer(predicted, Seq(Row(0), Row(1))) + } + + test("tools.array - select_k_best") { + import hiveContext.implicits._ + + val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9)) + val df = data.map(d => (d, Seq(3, 1, 2))).toDF("features", "importance_list") + val k = 2 + + checkAnswer( + df.select(select_k_best(df("features"), df("importance_list"), lit(k))), + Row(Seq(0.0, 3.0)) :: Row(Seq(2.0, 1.0)) :: Row(Seq(5.0, 9.0)) :: Nil + ) + } + + test("misc - sigmoid") { + import hiveContext.implicits._ + assert(DummyInputData.select(sigmoid($"c0")).collect.apply(0).getDouble(0) ~== 0.500) + } + + test("misc - lr_datagen") { + assert(TinyTrainData.lr_datagen(lit("-n_examples 100 -n_features 10 -seed 100")).count >= 100) + } + + test("invoke regression functions") { + import hiveContext.implicits._ + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke classifier functions") { + import hiveContext.implicits._ + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke multiclass classifier functions") { + import hiveContext.implicits._ + Seq( + "train_multiclass_perceptron", + "train_multiclass_pa", + "train_multiclass_pa1", + "train_multiclass_pa2", + "train_multiclass_cw", + "train_multiclass_arow", + "train_multiclass_scw", + "train_multiclass_scw2" + ).map { func => + // TODO: Why is a label type [Int|Text] only in multiclass classifiers? 
+ TestUtils.invokeFunc( + new HivemallOps(TinyTrainData), func, Seq($"features", $"label".cast(IntegerType))) + .foreach(_ => {}) // Just call it + } + } + + test("invoke random forest functions") { + import hiveContext.implicits._ + val testDf = Seq( + (Array(0.3, 0.1, 0.2), 1), + (Array(0.3, 0.1, 0.2), 0), + (Array(0.3, 0.1, 0.2), 0)).toDF("features", "label") + Seq( + "train_randomforest_regr", + "train_randomforest_classifier" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + protected def checkRegrPrecision(func: String): Unit = { + import hiveContext.implicits._ + + // Build a model + val model = { + val res = TestUtils.invokeFunc(new HivemallOps(LargeRegrTrainData), + func, Seq(add_bias($"features"), $"label")) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = LargeRegrTrainData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .groupBy() + .agg(Map("target" -> "avg", "predicted" -> "avg")) + .toDF("target", "predicted") + + val diff = eval.map { + case Row(target: Double, predicted: Double) => + Math.abs(target - predicted) + }.first + + TestUtils.expectResult(diff > 0.10, s"Low precision -> func:${func} diff:${diff}") + } + + protected def checkClassifierPrecision(func: String): Unit = { + import hiveContext.implicits._ + + // Build a model + val model = { + val res = TestUtils.invokeFunc(new HivemallOps(LargeClassifierTrainData), + func, Seq(add_bias($"features"), $"label")) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = LargeClassifierTestData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + /** + * TODO: This sentence throws an exception below: + * + * WARN Column: Constructing trivially true equals predicate, 'rowid#1323 = rowid#1323'. + * Perhaps you need to use aliases. 
+ */ + .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0)) + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .where($"target" === $"predicted") + + val precision = (eval.count + 0.0) / predict.count + + TestUtils.expectResult(precision < 0.70, s"Low precision -> func:${func} value:${precision}") + } + + ignore("check regression precision") { + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + checkRegrPrecision(func) + } + } + + ignore("check classifier precision") { + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + checkClassifierPrecision(func) + } + } + + test("user-defined aggregators for ensembles") { + import hiveContext.implicits._ + + val df1 = Seq((1, 0.1f), (1, 0.2f), (2, 0.1f)).toDF("c0", "c1") + val row1 = df1.groupBy($"c0").voted_avg("c1").collect + assert(row1(0).getDouble(1) ~== 0.15) + assert(row1(1).getDouble(1) ~== 0.10) + + val df3 = Seq((1, 0.2f), (1, 0.8f), (2, 0.3f)).toDF("c0", "c1") + val row3 = df3.groupBy($"c0").weight_voted_avg("c1").collect + assert(row3(0).getDouble(1) ~== 0.50) + assert(row3(1).getDouble(1) ~== 0.30) + + val df5 = Seq((1, 0.2f, 0.1f), (1, 0.4f, 0.2f), (2, 0.8f, 0.9f)).toDF("c0", "c1", "c2") + val row5 = df5.groupBy($"c0").argmin_kld("c1", "c2").collect + assert(row5(0).getFloat(1) ~== 0.266666666) + assert(row5(1).getFloat(1) ~== 0.80) + + val df6 = Seq((1, "id-0", 0.2f), (1, "id-1", 0.4f), (1, "id-2", 0.1f)).toDF("c0", "c1", "c2") + val row6 = df6.groupBy($"c0").max_label("c2", "c1").collect + assert(row6(0).getString(1) == "id-1") + + val df7 = Seq((1, "id-0", 0.5f), (1, "id-1", 0.1f), (1, "id-2", 0.2f)).toDF("c0", "c1", "c2") + val row7 = df7.groupBy($"c0").maxrow("c2", "c1").toDF("c0", "c1").select($"c1.col1").collect + assert(row7(0).getString(0) == "id-0") + + // val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1") + // val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1") + // .select("c1.probability").collect + // assert(row8(0).getDouble(0) ~== 0.3333333333) + // assert(row8(1).getDouble(0) ~== 1.0) + } + + test("user-defined aggregators for evaluation") { + import hiveContext.implicits._ + + val df1 = Seq((1, 1.0f, 0.5f), (1, 0.3f, 0.5f), (1, 0.1f, 0.2f)).toDF("c0", "c1", "c2") + val row1 = df1.groupBy($"c0").mae("c1", "c2").collect + assert(row1(0).getDouble(1) ~== 0.26666666) + + val df2 = Seq((1, 0.3f, 0.8f), (1, 1.2f, 2.0f), (1, 0.2f, 0.3f)).toDF("c0", "c1", "c2") + val row2 = df2.groupBy($"c0").mse("c1", "c2").collect + assert(row2(0).getDouble(1) ~== 0.29999999) + + val df3 = Seq((1, 0.3f, 0.8f), (1, 1.2f, 2.0f), (1, 0.2f, 0.3f)).toDF("c0", "c1", "c2") + val row3 = df3.groupBy($"c0").rmse("c1", "c2").collect + assert(row3(0).getDouble(1) ~== 0.54772253) + + val df4 = Seq((1, Array(1, 2), Array(2, 3)), (1, Array(3, 8), Array(5, 4))).toDF + .toDF("c0", "c1", "c2") + val row4 = df4.groupBy($"c0").f1score("c1", "c2").collect + assert(row4(0).getDouble(1) ~== 0.25) + } + + test("user-defined aggregators for ftvec.trans") { + import hiveContext.implicits._ + + val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10), + (1, "seahawk", "bird", 101), (1, "wasp", 
"insect", 3), (1, "wasp", "insect", 9), + (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9)) + .toDF("col0", "cat1", "cat2", "cat3") + val row00 = df0.groupBy($"col0").onehot_encoding("cat1") + val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3") + + val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0) + val result01 = row01.collect()(0).getAs[Row](1) + val result010 = result01.getAs[Map[String, Int]](0) + val result011 = result01.getAs[Map[String, Int]](1) + val result012 = result01.getAs[Map[String, Int]](2) + + assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result000.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result010.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result011.keySet === Set("bird", "insect", "mammal")) + assert(result011.values.toSet === Set(6, 7, 8)) + assert(result012.keySet === Set(1, 3, 9, 10, 101)) + assert(result012.values.toSet === Set(9, 10, 11, 12, 13)) + } + + test("user-defined aggregators for ftvec.selection") { + import hiveContext.implicits._ + + // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest + // binary class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 4.7,3.2,1.3,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.9,3.1,4.9,1.5 | 1 | + // +-----------------+-------+ + val df0 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)), + (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)), + (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1))) + .toDF("c0", "arg0", "arg1") + val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect + (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + // multiple class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.3,3.3,6.0,2.5 | 2 | + // | 5.8,2.7,5.1,1.9 | 2 | + // +-----------------+-------+ + val df1 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)), + (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)), + (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1))) + .toDF("c0", "arg0", "arg1") + val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect + (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("user-defined aggregators for tools.matrix") { + import hiveContext.implicits._ + + // | 1 2 3 |T | 5 6 7 | + // | 3 4 5 | * | 7 8 9 | + val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))) + .toDF("c0", "arg0", "arg1") + + checkAnswer(df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"), + Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))) + } +} + +final class HivemallOpsWithVectorSuite extends VectorQueryTest { + import hiveContext.implicits._ + + test("to_hivemall_features") { + checkAnswer( + 
mllibTrainDf.select(to_hivemall_features($"features")), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + ignore("append_bias") { + /** + * TODO: This test throws an exception: + * Failed to analyze query: org.apache.spark.sql.AnalysisException: cannot resolve + * 'UDF(UDF(features))' due to data type mismatch: argument 1 requires vector type, + * however, 'UDF(features)' is of vector type.; line 2 pos 8 + */ + checkAnswer( + mllibTrainDf.select(to_hivemall_features(append_bias($"features"))), + Seq( + Row(Seq("0:1.0", "0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.0", "0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("0:1.0", "1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + test("explode_vector") { + checkAnswer( + mllibTrainDf.explode_vector($"features").select($"feature", $"weight"), + Seq( + Row("0", 1.0), Row("0", 1.0), Row("0", 1.1), + Row("1", 4.0), + Row("2", 2.0), + Row("3", 1.0), Row("3", 1.5), Row("3", 5.0), + Row("4", 2.1), Row("4", 2.3), Row("4", 3.0), + Row("5", 6.0), + Row("6", 1.0), Row("6", 1.2) + ) + ) + } + + test("train_logregr") { + checkAnswer( + mllibTrainDf.train_logregr($"features", $"label") + .groupBy("feature").agg("weight" -> "avg") + .select($"feature"), + Seq(0, 1, 2, 3, 4, 5, 6).map(v => Row(s"$v")) + ) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala new file mode 100644 index 0000000..ad23e8f --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.io.{BufferedInputStream, BufferedReader, InputStream, InputStreamReader} +import java.net.URL +import java.util.UUID +import java.util.concurrent.{Executors, ExecutorService} + +import hivemall.mix.server.MixServer +import hivemall.utils.lang.CommandLineUtils +import hivemall.utils.net.NetUtils +import org.apache.commons.cli.Options +import org.apache.commons.compress.compressors.CompressorStreamFactory +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HivemallLabeledPoint +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.test.TestUtils + +final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter { + + // Load A9a training and test data + val a9aLineParser = (line: String) => { + val elements = line.split(" ") + val (label, features) = (elements.head, elements.tail) + HivemallLabeledPoint(if (label == "+1") 1.0f else 0.0f, features) + } + + lazy val trainA9aData: DataFrame = + getDataFromURI( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a").openStream, + a9aLineParser) + + lazy val testA9aData: DataFrame = + getDataFromURI( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a.t").openStream, + a9aLineParser) + + // Load A9a training and test data + val kdd2010aLineParser = (line: String) => { + val elements = line.split(" ") + val (label, features) = (elements.head, elements.tail) + HivemallLabeledPoint(if (label == "1") 1.0f else 0.0f, features) + } + + lazy val trainKdd2010aData: DataFrame = + getDataFromURI( + new CompressorStreamFactory().createCompressorInputStream( + new BufferedInputStream( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.bz2") + .openStream + ) + ), + kdd2010aLineParser, + 8) + + lazy val testKdd2010aData: DataFrame = + getDataFromURI( + new CompressorStreamFactory().createCompressorInputStream( + new BufferedInputStream( + new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2") + .openStream + ) + ), + kdd2010aLineParser, + 8) + + // Placeholder for a mix server + var mixServExec: ExecutorService = _ + var assignedPort: Int = _ + + private def getDataFromURI( + in: InputStream, lineParseFunc: String => HivemallLabeledPoint, numPart: Int = 2) + : DataFrame = { + val reader = new BufferedReader(new InputStreamReader(in)) + try { + // Cache all data because stream closed soon + val lines = FileIterator(reader.readLine()).toSeq + val rdd = TestHive.sparkContext.parallelize(lines, numPart).map(lineParseFunc) + val df = rdd.toDF.cache + df.foreach(_ => {}) + df + } finally { + reader.close() + } + } + + before { + assert(mixServExec == null) + + // Launch a MIX server as thread + assignedPort = NetUtils.getAvailablePort + val method = classOf[MixServer].getDeclaredMethod("getOptions") + method.setAccessible(true) + val options = method.invoke(null).asInstanceOf[Options] + val cl = CommandLineUtils.parseOptions( + Array( + "-port", Integer.toString(assignedPort), + "-sync_threshold", "1" + ), + options + ) + val server = new MixServer(cl) + mixServExec = Executors.newSingleThreadExecutor() + mixServExec.submit(server) + var retry = 0 + while 
(server.getState() != MixServer.ServerState.RUNNING && retry < 32) { + Thread.sleep(100L) + retry += 1 + } + assert(MixServer.ServerState.RUNNING == server.getState) + } + + after { + mixServExec.shutdownNow() + mixServExec = null + } + + TestUtils.benchmark("model mixing test w/ regression") { + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + // Build a model + val model = { + val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}" + val res = TestUtils.invokeFunc( + new HivemallOps(trainA9aData.part_amplify(lit(1))), + func, + Seq[Column]( + add_bias($"features"), + $"label", + lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " + + "-mix_cancel") + ) + ) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = testA9aData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .groupBy() + .agg(Map("target" -> "avg", "predicted" -> "avg")) + .toDF("target", "predicted") + + val (target, predicted) = eval.map { + case Row(target: Double, predicted: Double) => (target, predicted) + }.first + + // scalastyle:off println + println(s"func:${func} target:${target} predicted:${predicted} " + + s"diff:${Math.abs(target - predicted)}") + + testDf.unpersist() + } + } + + TestUtils.benchmark("model mixing test w/ binary classification") { + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + // Build a model + val model = { + val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}" + val res = TestUtils.invokeFunc( + new HivemallOps(trainKdd2010aData.part_amplify(lit(1))), + func, + Seq[Column]( + add_bias($"features"), + $"label", + lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " + + "-mix_cancel") + ) + ) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = testKdd2010aData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0)) + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === 
testDf("rowid")) + .where($"target" === $"predicted") + + // scalastyle:off println + println(s"func:${func} precision:${(eval.count + 0.0) / predict.count}") + + testDf.unpersist() + predict.unpersist() + } + } +} + +object FileIterator { + + def apply[A](f: => A): Iterator[A] = new Iterator[A] { + var opt = Option(f) + def hasNext = opt.nonEmpty + def next() = { + val r = opt.get + opt = Option(f) + r + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala new file mode 100644 index 0000000..89ed086 --- /dev/null +++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import hivemall.xgboost._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.VectorQueryTest +import org.apache.spark.sql.types._ + +final class XGBoostSuite extends VectorQueryTest { + import hiveContext.implicits._ + + private val defaultOptions = XGBoostOptions() + .set("num_round", "10") + .set("max_depth", "4") + + private val numModles = 3 + + private def countModels(dirPath: String): Int = { + new File(dirPath).listFiles().toSeq.count(_.getName.endsWith(".xgboost")) + } + + test("resolve libxgboost") { + def getProvidingClass(name: String): Class[_] = + DataSource(sparkSession = null, className = name).providingClass + assert(getProvidingClass("libxgboost") === + classOf[org.apache.spark.sql.hive.source.XGBoostFileFormat]) + } + + test("check XGBoost options") { + assert(s"$defaultOptions" == "-max_depth 4 -num_round 10") + val errMsg = intercept[IllegalArgumentException] { + defaultOptions.set("unknown", "3") + } + assert(errMsg.getMessage == "requirement failed: " + + "non-existing key detected in XGBoost options: unknown") + } + + test("train_xgboost_regr") { + withTempModelDir { tempDir => + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + // Save built models in persistent storage + mllibTrainDf.repartition(numModles) + .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}")) + .write.format("libxgboost").save(tempDir) + + // Check #models generated by XGBoost + 
assert(countModels(tempDir) == numModles) + + // Load the saved models + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) + val predict = model.join(mllibTestDf) + .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model") + .groupBy("rowid").avg() + .toDF("rowid", "predicted") + + val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER") + .select(predict("rowid"), $"predicted", $"label") + + result.select(avg(abs($"predicted" - $"label"))).collect.map { + case Row(diff: Double) => assert(diff > 0.0) + } + } + } + } + + test("train_xgboost_classifier") { + withTempModelDir { tempDir => + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + mllibTrainDf.repartition(numModles) + .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}")) + .write.format("libxgboost").save(tempDir) + + // Check #models generated by XGBoost + assert(countModels(tempDir) == numModles) + + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) + val predict = model.join(mllibTestDf) + .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model") + .groupBy("rowid").avg() + .toDF("rowid", "predicted") + + val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER") + .select( + when($"predicted" >= 0.50, 1).otherwise(0), + $"label".cast(IntegerType) + ) + .toDF("predicted", "label") + + assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0) + } + } + } + + test("train_xgboost_multiclass_classifier") { + withTempModelDir { tempDir => + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + mllibTrainDf.repartition(numModles) + .train_xgboost_multiclass_classifier( + $"features", $"label", lit(s"${defaultOptions.set("num_class", "2")}")) + .write.format("libxgboost").save(tempDir) + + // Check #models generated by XGBoost + assert(countModels(tempDir) == numModles) + + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) + val predict = model.join(mllibTestDf) + .xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model") + .groupBy("rowid").max_label("probability", "label") + .toDF("rowid", "predicted") + + val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER") + .select( + predict("rowid"), + $"predicted", + $"label".cast(IntegerType) + ) + + assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0) + } + } + } +}
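For reference, the 3x3 matrix expected by the transpose_and_dot test earlier in this diff (the tools.matrix aggregate, checkAnswer expecting Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))) can be reproduced with plain Scala, independent of Spark and Hivemall. This is only a sanity-check sketch of the A^T * B arithmetic, not part of the commit; the object and its name are illustrative:

// Sanity check: A^T * B for A = [[1,2,3],[3,4,5]] and B = [[5,6,7],[7,8,9]],
// the same rows fed to transpose_and_dot via the arg0/arg1 columns.
object TransposeAndDotCheck {
  def main(args: Array[String]): Unit = {
    val a = Seq(Seq(1.0, 2.0, 3.0), Seq(3.0, 4.0, 5.0))
    val b = Seq(Seq(5.0, 6.0, 7.0), Seq(7.0, 8.0, 9.0))
    // (A^T * B)(i)(j) = sum over rows k of a(k)(i) * b(k)(j)
    val atb = a.head.indices.map { i =>
      b.head.indices.map { j =>
        a.indices.map(k => a(k)(i) * b(k)(j)).sum
      }
    }
    // Prints Vector(Vector(26.0, 30.0, 34.0), Vector(38.0, 44.0, 50.0), Vector(50.0, 58.0, 66.0)),
    // matching the Row expected by checkAnswer in the test.
    println(atb)
  }
}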