http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala deleted file mode 100644 index 8583e1c..0000000 --- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ /dev/null @@ -1,1125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.hive - -import java.util.UUID - -import org.apache.spark.Logging -import org.apache.spark.ml.feature.HivemallFeature -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -import org.apache.spark.sql.types._ - -/** - * Hivemall wrapper and some utility functions for DataFrame. 
- * - * @groupname regression - * @groupname classifier - * @groupname classifier.multiclass - * @groupname ensemble - * @groupname knn.similarity - * @groupname knn.distance - * @groupname knn.lsh - * @groupname ftvec - * @groupname ftvec.amplify - * @groupname ftvec.hashing - * @groupname ftvec.scaling - * @groupname ftvec.conv - * @groupname ftvec.trans - * @groupname misc - */ -final class HivemallOps(df: DataFrame) extends Logging { - - /** - * An implicit conversion to avoid doing annoying transformation. - */ - @inline - private[this] implicit def toDataFrame(logicalPlan: LogicalPlan) = - DataFrame(df.sqlContext, logicalPlan) - - /** - * @see hivemall.regression.AdaDeltaUDTF - * @group regression - */ - @scala.annotation.varargs - def train_adadelta(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.AdaDeltaUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.AdaGradUDTF - * @group regression - */ - @scala.annotation.varargs - def train_adagrad(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.AdaGradUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.AROWRegressionUDTF - * @group regression - */ - @scala.annotation.varargs - def train_arow_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.AROWRegressionUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.AROWRegressionUDTF$AROWe - * @group regression - */ - @scala.annotation.varargs - def 
train_arowe_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.AROWRegressionUDTF$AROWe"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.AROWRegressionUDTF$AROWe2 - * @group regression - */ - @scala.annotation.varargs - def train_arowe2_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.AROWRegressionUDTF$AROWe2"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.LogressUDTF - * @group regression - */ - @scala.annotation.varargs - def train_logregr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.LogressUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.PassiveAggressiveRegressionUDTF - * @group regression - */ - @scala.annotation.varargs - def train_pa1_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.PassiveAggressiveRegressionUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.PassiveAggressiveRegressionUDTF.PA1a - * @group regression - */ - @scala.annotation.varargs - def train_pa1a_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.PassiveAggressiveRegressionUDTF$PA1a"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", 
"weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.PassiveAggressiveRegressionUDTF.PA2 - * @group regression - */ - @scala.annotation.varargs - def train_pa2_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.PassiveAggressiveRegressionUDTF$PA2"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.regression.PassiveAggressiveRegressionUDTF.PA2a - * @group regression - */ - @scala.annotation.varargs - def train_pa2a_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.smile.regression.RandomForestRegressionUDTF - * @group regression - */ - @scala.annotation.varargs - def train_randomforest_regr(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.smile.regression.RandomForestRegressionUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests") - .map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.PerceptronUDTF - * @group classifier - */ - @scala.annotation.varargs - def train_perceptron(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.PerceptronUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.PassiveAggressiveUDTF - * @group classifier - */ - @scala.annotation.varargs 
- def train_pa(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.PassiveAggressiveUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.PassiveAggressiveUDTF$PA1 - * @group classifier - */ - @scala.annotation.varargs - def train_pa1(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.PassiveAggressiveUDTF$PA1"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.PassiveAggressiveUDTF$PA2 - * @group classifier - */ - @scala.annotation.varargs - def train_pa2(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.PassiveAggressiveUDTF$PA2"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.ConfidenceWeightedUDTF - * @group classifier - */ - @scala.annotation.varargs - def train_cw(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.ConfidenceWeightedUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.AROWClassifierUDTF - * @group classifier - */ - @scala.annotation.varargs - def train_arow(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.AROWClassifierUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * 
@see hivemall.classifier.AROWClassifierUDTF$AROWh - * @group classifier - */ - @scala.annotation.varargs - def train_arowh(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.AROWClassifierUDTF$AROWh"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.SoftConfideceWeightedUDTF$SCW1 - * @group classifier - */ - @scala.annotation.varargs - def train_scw(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.SoftConfideceWeightedUDTF$SCW1"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.SoftConfideceWeightedUDTF$SCW2 - * @group classifier - */ - @scala.annotation.varargs - def train_scw2(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.SoftConfideceWeightedUDTF$SCW2"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.AdaGradRDAUDTF - * @group classifier - */ - @scala.annotation.varargs - def train_adagrad_rda(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.AdaGradRDAUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.smile.classification.RandomForestClassifierUDTF - * @group classifier - */ - @scala.annotation.varargs - def train_randomforest_classifier(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new 
HiveFunctionWrapper("hivemall.smile.classification.RandomForestClassifierUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests") - .map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassPerceptronUDTF - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_perceptron(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.multiclass.MulticlassPerceptronUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_pa(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA1 - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_pa1(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper( - "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA1"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA2 - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_pa2(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new 
HiveFunctionWrapper( - "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA2"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_cw(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_arow(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW1 - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_scw(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper( - "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW1"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW2 - * @group classifier.multiclass - */ - @scala.annotation.varargs - def train_multiclass_scw2(exprs: Column*): DataFrame = { - 
Generate(HiveGenericUDTF( - new HiveFunctionWrapper( - "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW2"), - setMixServs(exprs: _*).map(_.expr)), - join = false, outer = false, None, - Seq("label", "feature", "weight", "conv").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedDataEx]] for all the available aggregate functions. - * - * TODO: This class bypasses the original GroupedData - * so as to support user-defined aggregations. - * Need a smarter injection into existing DataFrame APIs. - * - * A list of added Hivemall UDAF: - * - voted_avg - * - weight_voted_avg - * - argmin_kld - * - max_label - * - maxrow - * - f1score - * - mae - * - mse - * - rmse - * - * @group ensemble - */ - @scala.annotation.varargs - def groupby(cols: Column*): GroupedDataEx = { - new GroupedDataEx(df, cols.map(_.expr), GroupedData.GroupByType) - } - - @scala.annotation.varargs - def groupby(col1: String, cols: String*): GroupedDataEx = { - val colNames: Seq[String] = col1 +: cols - new GroupedDataEx(df, colNames.map(colName => df(colName).expr), GroupedData.GroupByType) - } - - /** - * @see hivemall.knn.lsh.MinHashUDTF - * @group knn.lsh - */ - @scala.annotation.varargs - def minhash(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.knn.lsh.MinHashUDTF"), - exprs.map(_.expr)), - join = false, outer = false, None, - Seq("clusterid", "item").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.ftvec.amplify.AmplifierUDTF - * @group ftvec.amplify - */ - @scala.annotation.varargs - def amplify(exprs: Column*): DataFrame = { - val outputAttr = exprs.drop(1).map { - case Column(expr: NamedExpression) => UnresolvedAttribute(expr.name) - case Column(expr: Expression) => UnresolvedAttribute(expr.prettyString) - } - Generate(HiveGenericUDTF( - new 
HiveFunctionWrapper("hivemall.ftvec.amplify.AmplifierUDTF"), - exprs.map(_.expr)), - join = false, outer = false, None, - outputAttr, - df.logicalPlan) - } - - /** - * @see hivemall.ftvec.amplify.RandomAmplifierUDTF - * @group ftvec.amplify - */ - @scala.annotation.varargs - def rand_amplify(exprs: Column*): DataFrame = { - val outputAttr = exprs.drop(2).map { - case Column(expr: NamedExpression) => UnresolvedAttribute(expr.name) - case Column(expr: Expression) => UnresolvedAttribute(expr.prettyString) - } - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.ftvec.amplify.RandomAmplifierUDTF"), - exprs.map(_.expr)), - join = false, outer = false, None, - outputAttr, - df.logicalPlan) - } - - /** - * Amplifies and shuffle data inside partitions. - * @group ftvec.amplify - */ - def part_amplify(xtimes: Int): DataFrame = { - val rdd = df.rdd.mapPartitions({ iter => - val elems = iter.flatMap{ row => - Seq.fill[Row](xtimes)(row) - } - // Need to check how this shuffling affects results - scala.util.Random.shuffle(elems) - }, true) - df.sqlContext.createDataFrame(rdd, df.schema) - } - - /** - * Quantifies input columns. 
- * @see hivemall.ftvec.conv.QuantifyColumnsUDTF - * @group ftvec.conv - */ - @scala.annotation.varargs - def quantify(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.ftvec.conv.QuantifyColumnsUDTF"), - exprs.map(_.expr)), - join = false, outer = false, None, - (0 until exprs.size - 1).map(i => s"c$i").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.ftvec.trans.BinarizeLabelUDTF - * @group ftvec.trans - */ - @scala.annotation.varargs - def binarize_label(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.ftvec.trans.BinarizeLabelUDTF"), - exprs.map(_.expr)), - join = false, outer = false, None, - (0 until exprs.size - 1).map(i => s"c$i").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * @see hivemall.ftvec.trans.QuantifiedFeaturesUDTF - * @group ftvec.trans - */ - @scala.annotation.varargs - def quantified_features(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.ftvec.trans.QuantifiedFeaturesUDTF"), - exprs.map(_.expr)), - join = false, outer = false, None, - Seq("features").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * Splits Seq[String] into pieces. - * @group ftvec - */ - def explode_array(expr: Column): DataFrame = { - df.explode(expr) { case Row(v: Seq[_]) => - // Type erasure removes the component type in Seq - v.map(s => HivemallFeature(s.asInstanceOf[String])) - } - } - - def explode_array(expr: String): DataFrame = - this.explode_array(df(expr)) - - /** - * Returns a top-`k` records for each `group`. 
- * @group misc - */ - def each_top_k(k: Column, group: Column, value: Column, args: Column*): DataFrame = { - val clusterDf = df.repartition(group).sortWithinPartitions(group) - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.tools.EachTopKUDTF"), - (Seq(k, group, value) ++ args).map(_.expr)), - join = false, outer = false, None, - (Seq("rank", "key") ++ args.map(_.named.name)).map(UnresolvedAttribute(_)), - clusterDf.logicalPlan) - } - - /** - * Returns a new [[DataFrame]] with columns renamed. - * This is a wrapper for DataFrame#toDF. - * @group misc - */ - @scala.annotation.varargs - def as(colNames: String*): DataFrame = df.toDF(colNames: _*) - - /** - * Returns all the columns as Seq[Column] in this [[DataFrame]]. - * @group misc - */ - def cols: Seq[Column] = { - df.schema.fields.map(col => df.col(col.name)).toSeq - } - - /** - * @see hivemall.dataset.LogisticRegressionDataGeneratorUDTF - * @group misc - */ - @scala.annotation.varargs - def lr_datagen(exprs: Column*): DataFrame = { - Generate(HiveGenericUDTF( - new HiveFunctionWrapper("hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper"), - exprs.map(_.expr)), - join = false, outer = false, None, - Seq("label", "features").map(UnresolvedAttribute(_)), - df.logicalPlan) - } - - /** - * :: Experimental :: - * If a parameter '-mix' does not exist in a 3rd argument, - * set it from an environmental variable - * 'HIVEMALL_MIX_SERVERS'. - * - * TODO: This could work if '--deploy-mode' has 'client'; - * otherwise, we need to set HIVEMALL_MIX_SERVERS - * in all possible spark workers. 
- */ - private[this] def setMixServs(exprs: Column*): Seq[Column] = { - val mixes = System.getenv("HIVEMALL_MIX_SERVERS") - if (mixes != null && !mixes.isEmpty()) { - val groupId = df.sqlContext.sparkContext.applicationId + "-" + UUID.randomUUID - logInfo(s"set '${mixes}' as default mix servers (session: ${groupId})") - exprs.size match { - case 2 => - exprs :+ Column(Literal.create(s"-mix ${mixes} -mix_session ${groupId}", StringType)) - /** TODO: Add codes in the case where exprs.size == 3. */ - case _ => - exprs - } - } else { - exprs - } - } -} - -object HivemallOps { - - /** - * Implicitly inject the [[HivemallOps]] into [[DataFrame]]. - */ - implicit def dataFrameToHivemallOps(df: DataFrame): HivemallOps = - new HivemallOps(df) - - /** - * An implicit conversion to avoid doing annoying transformation. - */ - @inline private implicit def toColumn(expr: Expression) = Column(expr) - - /** - * @see hivemall.HivemallVersionUDF - * @group misc - */ - def hivemall_version(): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.HivemallVersionUDF"), Nil) - } - - /** - * @see hivemall.knn.similarity.CosineSimilarityUDF - * @group knn.similarity - */ - @scala.annotation.varargs - def cosine_sim(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.similarity.CosineSimilarityUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.similarity.JaccardIndexUDF - * @group knn.similarity - */ - @scala.annotation.varargs - def jaccard(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.knn.similarity.JaccardIndexUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.similarity.AngularSimilarityUDF - * @group knn.similarity - */ - @scala.annotation.varargs - def angular_similarity(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.similarity.AngularSimilarityUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.similarity.EuclidSimilarity - * @group 
knn.similarity - */ - @scala.annotation.varargs - def euclid_similarity(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.similarity.EuclidSimilarity"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.similarity.Distance2SimilarityUDF - * @group knn.similarity - */ - @scala.annotation.varargs - def distance2similarity(exprs: Column*): Column = { - // TODO: Need a wrapper class because of using unsupported types - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.similarity.Distance2SimilarityUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.HammingDistanceUDF - * @group knn.distance - */ - @scala.annotation.varargs - def hamming_distance(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.HammingDistanceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.PopcountUDF - * @group knn.distance - */ - @scala.annotation.varargs - def popcnt(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.PopcountUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.KLDivergenceUDF - * @group knn.distance - */ - @scala.annotation.varargs - def kld(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.KLDivergenceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.EuclidDistanceUDF - * @group knn.distance - */ - @scala.annotation.varargs - def euclid_distance(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.EuclidDistanceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.CosineDistanceUDF - * @group knn.distance - */ - @scala.annotation.varargs - def cosine_distance(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.CosineDistanceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.AngularDistanceUDF - * @group knn.distance - */ 
- @scala.annotation.varargs - def angular_distance(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.AngularDistanceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.ManhattanDistanceUDF - * @group knn.distance - */ - @scala.annotation.varargs - def manhattan_distance(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.ManhattanDistanceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.distance.MinkowskiDistanceUDF - * @group knn.distance - */ - @scala.annotation.varargs - def minkowski_distance (exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.distance.MinkowskiDistanceUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.lsh.bBitMinHashUDF - * @group knn.lsh - */ - @scala.annotation.varargs - def bbit_minhash(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.knn.lsh.bBitMinHashUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.knn.lsh.MinHashesUDF - * @group knn.lsh - */ - @scala.annotation.varargs - def minhashes(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.knn.lsh.MinHashesUDFWrapper"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.AddBiasUDF - * @group ftvec - */ - @scala.annotation.varargs - def add_bias(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.AddBiasUDFWrapper"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.ExtractFeatureUdf - * @group ftvec - * - * TODO: This throws java.lang.ClassCastException because - * HiveInspectors.toInspector has a bug in spark. - * Need to fix it later. 
- */ - def extract_feature(expr: Column): Column = { - val hiveUdf = HiveGenericUDF( - new HiveFunctionWrapper("hivemall.ftvec.ExtractFeatureUDFWrapper"), - expr.expr :: Nil) - Column(hiveUdf).as("feature") - } - - /** - * @see hivemall.ftvec.ExtractWeightUdf - * @group ftvec - * - * TODO: This throws java.lang.ClassCastException because - * HiveInspectors.toInspector has a bug in spark. - * Need to fix it later. - */ - def extract_weight(expr: Column): Column = { - val hiveUdf = HiveGenericUDF( - new HiveFunctionWrapper("hivemall.ftvec.ExtractWeightUDFWrapper"), - expr.expr :: Nil) - Column(hiveUdf).as("value") - } - - /** - * @see hivemall.ftvec.AddFeatureIndexUDFWrapper - * @group ftvec - */ - def add_feature_index(expr: Column): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.AddFeatureIndexUDFWrapper"), expr.expr :: Nil) - } - - /** - * @see hivemall.ftvec.SortByFeatureUDF - * @group ftvec - */ - def sort_by_feature(expr: Column): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.SortByFeatureUDFWrapper"), expr.expr :: Nil) - } - - /** - * @see hivemall.ftvec.hashing.MurmurHash3UDF - * @group ftvec.hashing - */ - def mhash(expr: Column): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.ftvec.hashing.MurmurHash3UDF"), expr.expr :: Nil) - } - - /** - * @see hivemall.ftvec.hashing.Sha1UDF - * @group ftvec.hashing - */ - def sha1(expr: Column): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.ftvec.hashing.Sha1UDF"), expr.expr :: Nil) - } - - /** - * @see hivemall.ftvec.hashing.ArrayHashValuesUDF - * @group ftvec.hashing - */ - @scala.annotation.varargs - def array_hash_values(exprs: Column*): Column = { - // TODO: Need a wrapper class because of using unsupported types - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.ftvec.hashing.ArrayHashValuesUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF - * @group ftvec.hashing - */ - 
@scala.annotation.varargs - def prefixed_hash_values(exprs: Column*): Column = { - // TODO: Need a wrapper class because of using unsupported types - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.scaling.RescaleUDF - * @group ftvec.scaling - */ - @scala.annotation.varargs - def rescale(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.ftvec.scaling.RescaleUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.scaling.ZScoreUDF - * @group ftvec.scaling - */ - @scala.annotation.varargs - def zscore(exprs: Column*): Column = { - HiveSimpleUDF(new HiveFunctionWrapper( - "hivemall.ftvec.scaling.ZScoreUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.scaling.L2NormalizationUDF - * @group ftvec.scaling - */ - def normalize(expr: Column): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.scaling.L2NormalizationUDFWrapper"), expr.expr :: Nil) - } - - /** - * @see hivemall.ftvec.selection.ChiSquareUDF - * @group ftvec.selection - */ - def chi2(observed: Column, expected: Column): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.selection.ChiSquareUDF"), Seq(observed.expr, expected.expr)) - } - - /** - * @see hivemall.ftvec.conv.ToDenseFeaturesUDF - * @group ftvec.conv - */ - @scala.annotation.varargs - def to_dense_features(exprs: Column*): Column = { - // TODO: Need a wrapper class because of using unsupported types - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.conv.ToDenseFeaturesUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.conv.ToSparseFeaturesUDF - * @group ftvec.conv - */ - @scala.annotation.varargs - def to_sparse_features(exprs: Column*): Column = { - // TODO: Need a wrapper class because of using unsupported types - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.conv.ToSparseFeaturesUDF"), exprs.map(_.expr)) - } - - /** - * @see 
hivemall.ftvec.trans.VectorizeFeaturesUDF - * @group ftvec.trans - */ - @scala.annotation.varargs - def vectorize_features(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.trans.VectorizeFeaturesUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.trans.CategoricalFeaturesUDF - * @group ftvec.trans - */ - @scala.annotation.varargs - def categorical_features(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.trans.CategoricalFeaturesUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.trans.IndexedFeatures - * @group ftvec.trans - */ - @scala.annotation.varargs - def indexed_features(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.trans.IndexedFeatures"), exprs.map(_.expr)) - } - - /** - * @see hivemall.ftvec.trans.QuantitativeFeaturesUDF - * @group ftvec.trans - */ - @scala.annotation.varargs - def quantitative_features(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.ftvec.trans.QuantitativeFeaturesUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.smile.tools.TreePredictUDF - * @group misc - */ - @scala.annotation.varargs - def tree_predict(exprs: Column*): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.smile.tools.TreePredictUDF"), exprs.map(_.expr)) - } - - /** - * @see hivemall.tools.array.SelectKBestUDF - * @group tools.array - */ - def select_k_best(X: Column, importanceList: Column, k: Column): Column = { - HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.tools.array.SelectKBestUDF"), Seq(X.expr, importanceList.expr, k.expr)) - } - - /** - * @see hivemall.tools.math.SigmoidUDF - * @group misc - */ - @scala.annotation.varargs - def sigmoid(exprs: Column*): Column = { - /** - * TODO: SigmodUDF only accepts floating-point types in spark-v1.5.0? 
- */ - val value = exprs.head - val one: () => Literal = () => Literal.create(1.0, DoubleType) - Column(one()) / (Column(one()) + exp(-value)) - } - - /** - * @see hivemall.tools.mapred.RowIdUDF - * @group misc - */ - def rowid(): Column = { - val hiveUdf = HiveGenericUDF(new HiveFunctionWrapper( - "hivemall.tools.mapred.RowIdUDFWrapper"), Nil) - hiveUdf.as("rowid") - } -}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala deleted file mode 100644 index dff62b3..0000000 --- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.hive - -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} -import org.apache.spark.sql.{Column, DataFrame, Row, UserDefinedFunction} -import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ - -object HivemallUtils { - - // # of maximum dimensions for feature vectors - val maxDims = 100000000 - - /** - * An implicit conversion to avoid doing annoying transformation. - * This class must be in o.a.spark.sql._ because - * a Column class is private. 
- */ - @inline implicit def toBooleanLiteral(i: Boolean): Column = Column(Literal.create(i, BooleanType)) - @inline implicit def toIntLiteral(i: Int): Column = Column(Literal.create(i, IntegerType)) - @inline implicit def toFloatLiteral(i: Float): Column = Column(Literal.create(i, FloatType)) - @inline implicit def toDoubleLiteral(i: Double): Column = Column(Literal.create(i, DoubleType)) - @inline implicit def toStringLiteral(i: String): Column = Column(Literal.create(i, StringType)) - @inline implicit def toIntArrayLiteral(i: Seq[Int]): Column = - Column(Literal.create(i, ArrayType(IntegerType))) - @inline implicit def toStringArrayLiteral(i: Seq[String]): Column = - Column(Literal.create(i, ArrayType(StringType))) - - /** - * Check whether the given schema contains a column of the required data type. - * @param colName column name - * @param dataType required column data type - */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.") - } - - /** - * Make up a function object from a Hivemall model. - */ - def funcModel(df: DataFrame, dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = { - checkColumnType(df.schema, "feature", StringType) - checkColumnType(df.schema, "weight", DoubleType) - - import df.sqlContext.implicits._ - val intercept = df - .where($"feature" === "0") - .select($"weight") - .map { case Row(weight: Double) => weight} - .reduce(_ + _) - val weights = funcVectorizerImpl(dense, dims)( - df.select($"feature", $"weight") - .where($"feature" !== "0") - .map { case Row(label: String, feature: Double) => s"${label}:$feature"} - .collect.toSeq) - - udf((input: Vector) => BLAS.dot(input, weights) + intercept) - } - - /** - * Make up a function object to transform Hivemall features into Vector. 
- */ - def funcVectorizer(dense: Boolean = false, dims: Int = maxDims) - : UserDefinedFunction = { - udf(funcVectorizerImpl(dense, dims)) - } - - private def funcVectorizerImpl(dense: Boolean, dims: Int) - : Seq[String] => Vector = { - if (dense) { - // Dense features - i: Seq[String] => { - val features = new Array[Double](dims) - i.map { ft => - val s = ft.split(":").ensuring(_.size == 2) - features(s(0).toInt) = s(1).toDouble - } - Vectors.dense(features) - } - } else { - // Sparse features - i: Seq[String] => { - val features = i.map { ft => - // val s = ft.split(":").ensuring(_.size == 2) - val s = ft.split(":") - (s(0).toInt, s(1).toDouble) - } - Vectors.sparse(dims, features) - } - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/resources/log4j.properties b/spark/spark-1.6/src/test/resources/log4j.properties deleted file mode 100644 index 1db11f0..0000000 --- a/spark/spark-1.6/src/test/resources/log4j.properties +++ /dev/null @@ -1,7 +0,0 @@ -# Set everything to be logged to the console -log4j.rootCategory=FATAL, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/hivemall/mix/server/MixServerSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-1.6/src/test/scala/hivemall/mix/server/MixServerSuite.scala deleted file mode 100644 index dbb818b..0000000 --- a/spark/spark-1.6/src/test/scala/hivemall/mix/server/MixServerSuite.scala +++ 
/dev/null
@@ -1,123 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.mix.server

import java.util.Random
import java.util.concurrent.{Executors, ExecutorService, TimeUnit}
import java.util.logging.Logger

import hivemall.mix.MixMessage.MixEventName
import hivemall.mix.client.MixClient
import hivemall.mix.server.MixServer.ServerState
import hivemall.model.{DenseModel, PredictionModel, WeightValue}
import hivemall.utils.io.IOUtils
import hivemall.utils.lang.CommandLineUtils
import hivemall.utils.net.NetUtils
import org.scalatest.{BeforeAndAfter, FunSuite}

/**
 * Throughput tests that start a [[MixServer]] on a free port before each test and
 * drive it with several concurrent [[MixClient]]s.
 */
class MixServerSuite extends FunSuite with BeforeAndAfter {

  private[this] var server: MixServer = _
  private[this] var executor : ExecutorService = _
  private[this] var port: Int = _

  private[this] val rand = new Random(43)
  private[this] val counter = Stream.from(0).iterator

  // Duration of each test in seconds
  private[this] val eachTestTime = 100
  private[this] val logger =
    Logger.getLogger(classOf[MixServerSuite].getName)

  before {
    this.port = NetUtils.getAvailablePort
    this.server = new MixServer(
      CommandLineUtils.parseOptions(
        Array("-port", s"${port}", "-sync_threshold", "3"),
        MixServer.getOptions()
      )
    )
    this.executor = Executors.newSingleThreadExecutor
    this.executor.submit(server)
    // Poll once per second, for at most 50 seconds, until the server is running.
    var retry = 0
    while (server.getState() != ServerState.RUNNING && retry < 50) {
      Thread.sleep(1000L)
      retry += 1
    }
    assert(server.getState == ServerState.RUNNING)
  }

  // NOTE(review): shutdown() does not interrupt the running server task; confirm
  // whether the server needs to be stopped explicitly between tests.
  after { this.executor.shutdown() }

  /**
   * Connects `model` to the server under `groupId`, sets `numMsg` randomly chosen
   * weights and then parks the calling thread until it is interrupted.
   * The client is closed when the thread leaves this method.
   */
  private[this] def clientDriver(
      groupId: String, model: PredictionModel, numMsg: Int = 1000000): Unit = {
    var client: MixClient = null
    try {
      client = new MixClient(MixEventName.average, groupId, s"localhost:${port}", false, 2, model)
      model.configureMix(client, false)
      model.configureClock()

      for (_ <- 0 until numMsg) {
        val feature = Integer.valueOf(rand.nextInt(model.size))
        model.set(feature, new WeightValue(1.0f))
      }

      // Keep the client open until the owning pool interrupts this thread.
      // The mix count is asserted by the test itself once the run is over;
      // an assertion placed after this loop could never be reached.
      while (true) { Thread.sleep(eachTestTime * 1000 + 100L) }
    } finally {
      IOUtils.closeQuietly(client)
    }
  }

  private[this] def fixedGroup: (String, () => String) =
    ("fixed", () => "fixed")
  private[this] def uniqueGroup: (String, () => String) =
    ("unique", () => s"${counter.next}")

  Seq(65536).foreach { ndims =>
    Seq(4).foreach { nclient =>
      Seq(fixedGroup, uniqueGroup).foreach { id =>
        val testName = s"dense-dim:${ndims}-clinet:${nclient}-${id._1}"
        ignore(testName) {
          val clients = Executors.newCachedThreadPool()
          val numClients = nclient
          val models = (0 until numClients).map(i => new DenseModel(ndims, false))
          (0 until numClients).foreach { i =>
            // Call the generator (instead of interpolating the function object) and do it
            // on the test thread, because the shared counter is not thread-safe.
            val groupId = s"${testName}-${id._2()}"
            clients.submit(new Runnable() {
              override def run(): Unit = {
                try {
                  clientDriver(groupId, models(i))
                } catch {
                  case _: InterruptedException =>
                    // Expected: the pool is stopped with shutdownNow() below.
                    Thread.currentThread().interrupt()
                }
              }
            })
          }
          // The pool is still open here, so this simply lets the clients run for
          // `eachTestTime` seconds before they are interrupted.
          clients.awaitTermination(eachTestTime, TimeUnit.SECONDS)
          clients.shutdownNow()
          val nMixes = models.map(d => d.getNumMixed).reduce(_ + _)
          logger.info(s"${testName} --> ${(nMixes + 0.0) / eachTestTime} mixes/s")
          models.foreach(m => assert(m.getNumMixed > 0))
        }
      }
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala b/spark/spark-1.6/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala deleted file mode 100644 index f203fc2..0000000 --- a/spark/spark-1.6/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.tools - -import org.scalatest.FunSuite - -import org.apache.spark.sql.hive.test.TestHive - -class RegressionDatagenSuite extends FunSuite { - - test("datagen") { - val df = RegressionDatagen.exec( - TestHive, min_examples = 10000, n_features = 100, n_dims = 65536, dense = false, cl = true) - assert(df.count() >= 10000) - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/org/apache/spark/SparkFunSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/SparkFunSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/SparkFunSuite.scala deleted file mode 100644 index 991e46f..0000000 --- a/spark/spark-1.6/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark - -// scalastyle:off -import org.scalatest.{FunSuite, Outcome} - -/** - * Base abstract class for all unit tests in Spark for handling common functionality. 
- */ -private[spark] abstract class SparkFunSuite extends FunSuite with Logging { -// scalastyle:on - - /** - * Log the suite name and the test name before and after each test. - * - * Subclasses should never override this method. If they wish to run - * custom code before and after each test, they should mix in the - * {{org.scalatest.BeforeAndAfter}} trait instead. - */ - final protected override def withFixture(test: NoArgTest): Outcome = { - val testName = test.text - val suiteName = this.getClass.getName - val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") - try { - logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") - test() - } finally { - logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala deleted file mode 100644 index f57983f..0000000 --- a/spark/spark-1.6/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.ml.feature - -import org.apache.spark.SparkFunSuite - -class HivemallLabeledPointSuite extends SparkFunSuite { - - test("toString") { - val lp = HivemallLabeledPoint(1.0f, Seq("1:0.5", "3:0.3", "8:0.1")) - assert(lp.toString === "1.0,[1:0.5,3:0.3,8:0.1]") - } - - test("parse") { - val lp = HivemallLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]") - assert(lp.label === 1.0) - assert(lp.features === Seq("1:0.5", "3:0.3", "8:0.1")) - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/QueryTest.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index ef520ae..0000000 --- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,295 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql - -import java.util.{Locale, TimeZone} - -import scala.collection.JavaConverters._ -import scala.util.control.NonFatal - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{LogicalRDD, Queryable} -import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - -abstract class QueryTest extends PlanTest { - - protected def sqlContext: SQLContext - - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) - - /** - * Runs the plan and makes sure the answer contains all of the keywords, or the - * none of keywords are listed in the answer - * @param df the [[DataFrame]] to be executed - * @param exists true for make sure the keywords are listed in the output, otherwise - * to make sure none of the keyword are not listed in the output - * @param keywords keyword in string array - */ - def checkExistence(df: DataFrame, exists: Boolean, keywords: String*) { - val outputs = df.collect().map(_.mkString).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), 
s"Failed for $df ($key doesn't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)") - } - } - } - - /** - * Evaluates a dataset to make sure that the result of calling collect matches the given - * expected answer. - * - Special handling is done based on whether the query plan should be expected to return - * the results in sorted order. - * - This function also checks to make sure that the schema for serializing the expected answer - * matches that produced by the dataset (i.e. does manual construction of object match - * the constructed encoder for cases like joins, etc). Note that this means that it will fail - * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead - * which performs a subset of the checks done by this function. - */ - protected def checkAnswer[T]( - ds: Dataset[T], - expectedAnswer: T*): Unit = { - checkAnswer( - ds.toDF(), - sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) - - checkDecoding(ds, expectedAnswer: _*) - } - - protected def checkDecoding[T]( - ds: => Dataset[T], - expectedAnswer: T*): Unit = { - val decoded = try ds.collect().toSet catch { - case e: Exception => - fail( - s""" - |Exception collecting dataset as objects - |${ds.resolvedTEncoder} - |${ds.resolvedTEncoder.fromRowExpression.treeString} - |${ds.queryExecution} - """.stripMargin, e) - } - - if (decoded != expectedAnswer.toSet) { - val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted - val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted - - val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") - fail( - s"""Decoded objects do not match expected objects: - |$comparision - |${ds.resolvedTEncoder.fromRowExpression.treeString} - """.stripMargin) - } - } - - /** - * Runs the plan and makes sure the answer matches the expected result. 
- * @param df the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - val analyzedDF = try df catch { - case ae: AnalysisException => - val currentValue = sqlContext.conf.dataFrameEagerAnalysis - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - val partiallyAnalzyedPlan = df.queryExecution.analyzed - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) - fail( - s""" - |Failed to analyze query: $ae - |$partiallyAnalzyedPlan - | - |${stackTraceToString(ae)} - |""".stripMargin) - } - - assertEmptyMissingInput(df) - - QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - - protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(df, Seq(expectedAnswer)) - } - - protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { - checkAnswer(df, expectedAnswer.collect()) - } - - /** - * Runs the plan and makes sure the answer is within absTol of the expected result. - * @param dataFrame the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param absTol the absolute tolerance between actual and expected answers. 
- */ - protected def checkAggregatesWithTol(dataFrame: DataFrame, - expectedAnswer: Seq[Row], - absTol: Double): Unit = { - // TODO: catch exceptions in data frame execution - val actualAnswer = dataFrame.collect() - require(actualAnswer.length == expectedAnswer.length, - s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") - - actualAnswer.zip(expectedAnswer).foreach { - case (actualRow, expectedRow) => - QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) - } - } - - protected def checkAggregatesWithTol(dataFrame: DataFrame, - expectedAnswer: Row, - absTol: Double): Unit = { - checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) - } - - /** - * Asserts that a given [[Queryable]] will be executed using the given number of cached results. - */ - def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - - /** - * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans. - */ - def assertEmptyMissingInput(query: Queryable): Unit = { - assert(query.queryExecution.analyzed.missingInput.isEmpty, - s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") - assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, - s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}") - assert(query.queryExecution.executedPlan.missingInput.isEmpty, - s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}") - } -} - -object QueryTest { - /** - * Runs the plan and makes sure the answer matches the expected result. 
- * If there was exception during the execution or the contents of the DataFrame does not - * match the expected result, an error message will be returned. Otherwise, a [[None]] will - * be returned. - * @param df the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { - val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case d: java.math.BigDecimal => BigDecimal(d) - // Convert array to Seq for easy equality check. - case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - case o => o - }) - } - - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. 
- val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } - val sparkAnswer = try df.collect().toSeq catch { - case e: Exception => - val errorMessage = - s""" - |Exception thrown while executing query: - |${df.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - |Results do not match for query: - |${df.queryExecution} - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) - } - - return None - } - - /** - * Runs the plan and makes sure the answer is within absTol of the expected result. - * @param actualAnswer the actual result in a [[Row]]. - * @param expectedAnswer the expected result in a[[Row]]. - * @param absTol the absolute tolerance between actual and expected answers. - */ - protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { - require(actualAnswer.length == expectedAnswer.length, - s"actual answer length ${actualAnswer.length} != " + - s"expected answer length ${expectedAnswer.length}") - - // TODO: support other numeric types besides Double - // TODO: support struct types? 
- actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { - case (actual: Double, expected: Double) => - assert(math.abs(actual - expected) < absTol, - s"actual answer $actual not within $absTol of correct answer $expected") - case (actual, expected) => - assert(actual == expected, s"$actual did not equal $expected") - } - } - - def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(df, expectedAnswer.asScala) match { - case Some(errorMessage) => errorMessage - case None => null - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 816576e..0000000 --- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.util._

/**
 * Provides helper methods for comparing plans.
 */
class PlanTest extends SparkFunSuite {

  /**
   * Since attribute references are given globally unique ids during analysis,
   * we must normalize them to check if two different queries are identical.
   *
   * Every [[AttributeReference]] is rebuilt from its name, data type and nullability, and
   * every [[Alias]] from its child and name; both are given `ExprId(0)`. Only those fields
   * are carried over into the rebuilt expressions.
   *
   * The result type is deliberately left to inference so that the signature callers see is
   * exactly what it was before.
   *
   * @param plan the plan whose expression ids should be normalized
   * @return the plan with the rewritten expressions
   */
  protected def normalizeExprIds(plan: LogicalPlan) = {
    plan transformAllExpressions {
      case a: AttributeReference =>
        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
      case a: Alias =>
        Alias(a.child, a.name)(exprId = ExprId(0))
    }
  }

  /**
   * Fails the test if the two plans do not match.
   *
   * Both plans go through [[normalizeExprIds]] first, so the comparison ignores the
   * expression ids they carry. On a mismatch the failure message shows the two
   * normalized plan trees side by side.
   *
   * @param plan1 first plan to compare
   * @param plan2 second plan to compare
   */
  protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
    val normalized1 = normalizeExprIds(plan1)
    val normalized2 = normalizeExprIds(plan2)
    if (normalized1 != normalized2) {
      fail(
        s"""
          |== FAIL: Plans do not match ===
          |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
         """.stripMargin)
    }
  }

  /**
   * Fails the test if the two expressions do not match.
   *
   * Each expression is wrapped as the condition of a [[Filter]] over [[OneRowRelation]],
   * and the two resulting plans are handed to [[comparePlans]].
   *
   * @param e1 first expression to compare
   * @param e2 second expression to compare
   */
  protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
    comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation))
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c53b9ff9/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
deleted file mode 100644
index ded94ba..0000000
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * 
Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
package org.apache.spark.sql.hive

import org.apache.spark.sql.Row
import org.apache.spark.test.HivemallQueryTest

/**
 * Exercises Hivemall functions through plain SQL: each test issues a
 * `CREATE TEMPORARY FUNCTION` statement for the function it needs and then
 * checks the rows that a query using that function returns.
 */
final class HiveUdfSuite extends HivemallQueryTest {

  import hiveContext.implicits._
  import hiveContext._

  /**
   * Issues the `CREATE TEMPORARY FUNCTION` statement that binds `name` to the
   * fully-qualified class name of `clazz`.
   */
  private def createTemporaryFunction(name: String, clazz: Class[_]): Unit = {
    sql(s"""
         | CREATE TEMPORARY FUNCTION $name
         | AS '${clazz.getName}'
       """.stripMargin)
  }

  test("hivemall_version") {
    createTemporaryFunction("hivemall_version", classOf[hivemall.HivemallVersionUDF])

    val reportedVersion = sql("SELECT DISTINCT hivemall_version()")
    checkAnswer(reportedVersion, Row("0.4.2-rc.2"))

    // Cleanup is intentionally left disabled:
    // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version")
    // reset()
  }

  test("train_logregr") {
    TinyTrainData.registerTempTable("TinyTrainData")
    createTemporaryFunction("train_logregr", classOf[hivemall.regression.LogressUDTF])
    createTemporaryFunction("add_bias", classOf[hivemall.ftvec.AddBiasUDFWrapper])

    // Train on the bias-augmented features, then average the weights per feature.
    val trainingQuery =
      s"""
        | SELECT feature, AVG(weight) AS weight
        | FROM (
        | SELECT train_logregr(add_bias(features), label) AS (feature, weight)
        | FROM TinyTrainData
        | ) t
        | GROUP BY feature
       """.stripMargin
    val model = sql(trainingQuery)

    // Only the set of feature names is checked, not the learned weights.
    val expectedFeatures = Seq("0", "1", "2").map(name => Row(name))
    checkAnswer(model.select($"feature"), expectedFeatures)

    // TODO: Why 'train_logregr' is not registered in HiveMetaStore?
    // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException
    // (message:Function default.train_logregr does not exist))
    //
    // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr")
    // hiveContext.reset()
  }

  test("each_top_k") {
    val inputDf = Seq(
      ("a", "1", 0.5, Array(0, 1, 2)),
      ("b", "5", 0.1, Array(3)),
      ("a", "3", 0.8, Array(2, 5)),
      ("c", "6", 0.3, Array(1, 3)),
      ("b", "4", 0.3, Array(2)),
      ("a", "2", 0.6, Array(1))
    ).toDF("key", "value", "score", "data")

    import inputDf.sqlContext.implicits._

    // The table is registered with rows repartitioned by key and sorted by key
    // within each partition.
    val groupedByKey = inputDf.repartition($"key").sortWithinPartitions($"key")
    groupedByKey.registerTempTable("TestData")
    createTemporaryFunction("each_top_k", classOf[hivemall.tools.EachTopKUDTF])

    // Runs each_top_k with the given k and returns the collected rows as a set.
    def eachTopK(k: Int): Set[Row] = {
      val rows = sql(s"SELECT each_top_k($k, key, score, key, value) FROM TestData").collect()
      rows.toSet
    }

    // Compute top-1 rows for each group
    val expectedTop1 = Set(
      Row(1, 0.8, "a", "3"),
      Row(1, 0.3, "b", "4"),
      Row(1, 0.3, "c", "6")
    )
    assert(eachTopK(1) === expectedTop1)

    // Compute reverse top-1 rows for each group
    val expectedReverseTop1 = Set(
      Row(1, 0.5, "a", "1"),
      Row(1, 0.1, "b", "5"),
      Row(1, 0.3, "c", "6")
    )
    assert(eachTopK(-1) === expectedReverseTop1)
  }
}