http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala new file mode 100644 index 0000000..28653a5 --- /dev/null +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -0,0 +1,1368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.hive + +import java.util.UUID + +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.HivemallFeature +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{EachTopK, Expression, Literal, NamedExpression, UserDefinedGenerator} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Hivemall wrapper and some utility functions for DataFrame. 
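+ *
+ * A typical call chain looks like the following (an illustrative sketch, not
+ * part of this commit; it assumes a DataFrame `trainDf` with an `array<string>`
+ * column `features` and a numeric column `label`):
+ * {{{
+ *   import org.apache.spark.sql.hive.HivemallOps._
+ *
+ *   // each train_* function plans a Hive UDTF and emits (feature, weight) rows
+ *   val model = trainDf
+ *     .train_logregr(add_bias($"features"), $"label")
+ *     .groupBy("feature").agg("weight" -> "avg")
+ * }}}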
+ * + * @groupname regression + * @groupname classifier + * @groupname classifier.multiclass + * @groupname xgboost + * @groupname anomaly + * @groupname knn.similarity + * @groupname knn.distance + * @groupname knn.lsh + * @groupname ftvec + * @groupname ftvec.amplify + * @groupname ftvec.hashing + * @groupname ftvec.scaling + * @groupname ftvec.conv + * @groupname ftvec.trans + * @groupname misc + */ +final class HivemallOps(df: DataFrame) extends Logging { + import internal.HivemallOpsImpl._ + + private[this] val _sparkSession = df.sparkSession + private[this] val _analyzer = _sparkSession.sessionState.analyzer + + /** + * @see [[hivemall.regression.AdaDeltaUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_adadelta(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AdaDeltaUDTF", + "train_adadelta", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.AdaGradUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_adagrad(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AdaGradUDTF", + "train_adagrad", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.AROWRegressionUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_arow_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AROWRegressionUDTF", + "train_arow_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.regression.AROWRegressionUDTF.AROWe]] + * @group regression + */ + @scala.annotation.varargs + def train_arowe_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AROWRegressionUDTF$AROWe", + "train_arowe_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.regression.AROWRegressionUDTF.AROWe2]] + * @group regression + */ + @scala.annotation.varargs + def train_arowe2_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AROWRegressionUDTF$AROWe2", + "train_arowe2_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.regression.LogressUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_logregr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.LogressUDTF", + "train_logregr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_pa1_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF", + "train_pa1_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA1a]] + * @group regression + */ + @scala.annotation.varargs + def train_pa1a_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF$PA1a", + "train_pa1a_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see 
[[hivemall.regression.PassiveAggressiveRegressionUDTF.PA2]] + * @group regression + */ + @scala.annotation.varargs + def train_pa2_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF$PA2", + "train_pa2_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA2a]] + * @group regression + */ + @scala.annotation.varargs + def train_pa2a_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a", + "train_pa2a_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.smile.regression.RandomForestRegressionUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_randomforest_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.smile.regression.RandomForestRegressionUDTF", + "train_randomforest_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests") + ) + } + + /** + * @see [[hivemall.classifier.PerceptronUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_perceptron(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PerceptronUDTF", + "train_perceptron", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PassiveAggressiveUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_pa(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PassiveAggressiveUDTF", + "train_pa", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PassiveAggressiveUDTF.PA1]] + * @group classifier + */ + @scala.annotation.varargs + def train_pa1(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PassiveAggressiveUDTF$PA1", + "train_pa1", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PassiveAggressiveUDTF.PA2]] + * @group classifier + */ + @scala.annotation.varargs + def train_pa2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PassiveAggressiveUDTF$PA2", + "train_pa2", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.ConfidenceWeightedUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_cw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.ConfidenceWeightedUDTF", + "train_cw", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.AROWClassifierUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_arow(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.AROWClassifierUDTF", + "train_arow", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.AROWClassifierUDTF.AROWh]] + * @group classifier + */ + @scala.annotation.varargs + def train_arowh(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + 
"hivemall.classifier.AROWClassifierUDTF$AROWh", + "train_arowh", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.SoftConfideceWeightedUDTF.SCW1]] + * @group classifier + */ + @scala.annotation.varargs + def train_scw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.SoftConfideceWeightedUDTF$SCW1", + "train_scw", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.SoftConfideceWeightedUDTF.SCW1]] + * @group classifier + */ + @scala.annotation.varargs + def train_scw2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.SoftConfideceWeightedUDTF$SCW2", + "train_scw2", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.AdaGradRDAUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_adagrad_rda(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.AdaGradRDAUDTF", + "train_adagrad_rda", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.smile.classification.RandomForestClassifierUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_randomforest_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.smile.classification.RandomForestClassifierUDTF", + "train_randomforest_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPerceptronUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_perceptron(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPerceptronUDTF", + "train_multiclass_perceptron", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_pa(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF", + "train_multiclass_pa", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF.PA1]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_pa1(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA1", + "train_multiclass_pa1", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF.PA2]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_pa2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA2", + "train_multiclass_pa2", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF]] + * 
@group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_cw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF", + "train_multiclass_cw", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_arow(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF", + "train_multiclass_arow", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF.SCW1]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_scw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW1", + "train_multiclass_scw", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF.SCW2]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_scw2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW2", + "train_multiclass_scw2", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.regression.XGBoostRegressionUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.regression.XGBoostRegressionUDTFWrapper", + "train_xgboost_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "pred_model") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTFWrapper", + "train_xgboost_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "pred_model") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper", + "train_xgboost_multiclass_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "pred_model") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.tools.XGBoostPredictUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def xgboost_predict(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.tools.XGBoostPredictUDTF", + "xgboost_predict", + setMixServs(toHivemallFeatures(exprs)), + Seq("rowid", "predicted") + ) + } + + /** + * :: Experimental :: + * @see 
[[hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def xgboost_multiclass_predict(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF", + "xgboost_multiclass_predict", + setMixServs(toHivemallFeatures(exprs)), + Seq("rowid", "label", "probability") + ) + } + + /** + * @see [[hivemall.knn.lsh.MinHashUDTF]] + * @group knn.lsh + */ + @scala.annotation.varargs + def minhash(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.knn.lsh.MinHashUDTF", + "minhash", + setMixServs(toHivemallFeatures(exprs)), + Seq("clusterid", "item") + ) + } + + /** + * @see [[hivemall.ftvec.amplify.AmplifierUDTF]] + * @group ftvec.amplify + */ + @scala.annotation.varargs + def amplify(exprs: Column*): DataFrame = withTypedPlan { + // The output columns mirror the input columns (the first argument `xtimes` excluded) + val outputNames = exprs.drop(1).map { + case Column(expr: NamedExpression) => expr.name + case Column(expr: Expression) => expr.simpleString + } + planHiveGenericUDTF( + df, + "hivemall.ftvec.amplify.AmplifierUDTF", + "amplify", + setMixServs(toHivemallFeatures(exprs)), + outputNames + ) + } + + /** + * @see [[hivemall.ftvec.amplify.RandomAmplifierUDTF]] + * @group ftvec.amplify + */ + @scala.annotation.varargs + def rand_amplify(exprs: Column*): DataFrame = withTypedPlan { + throw new UnsupportedOperationException("`rand_amplify` not supported yet") + } + + /** + * Amplifies and shuffles data inside partitions. + * @group ftvec.amplify + */ + def part_amplify(xtimes: Column): DataFrame = { + val xtimesInt = xtimes.expr match { + case Literal(v: Any, IntegerType) => v.asInstanceOf[Int] + case e => throw new AnalysisException("`xtimes` must be an integer, however " + e) + } + val rdd = df.rdd.mapPartitions({ iter => + val elems = iter.flatMap{ row => + Seq.fill[Row](xtimesInt)(row) + } + // Need to check how this shuffling affects results + scala.util.Random.shuffle(elems) + }, true) + df.sqlContext.createDataFrame(rdd, df.schema) + } + + /** + * Quantifies input columns. + * @see [[hivemall.ftvec.conv.QuantifyColumnsUDTF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def quantify(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.conv.QuantifyColumnsUDTF", + "quantify", + setMixServs(toHivemallFeatures(exprs)), + (0 until exprs.size - 1).map(i => s"c$i") + ) + } + + /** + * @see [[hivemall.ftvec.trans.BinarizeLabelUDTF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def binarize_label(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.trans.BinarizeLabelUDTF", + "binarize_label", + setMixServs(toHivemallFeatures(exprs)), + (0 until exprs.size - 1).map(i => s"c$i") + ) + } + + /** + * @see [[hivemall.ftvec.trans.QuantifiedFeaturesUDTF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def quantified_features(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.trans.QuantifiedFeaturesUDTF", + "quantified_features", + setMixServs(toHivemallFeatures(exprs)), + Seq("features") + ) + } + + /** + * Splits Seq[String] into pieces. + * @group ftvec + */ + def explode_array(expr: Column): DataFrame = { + df.explode(expr) { case Row(v: Seq[_]) => + // Type erasure removes the component type in Seq + v.map(s => HivemallFeature(s.asInstanceOf[String])) + } + } + + /** + * Splits [[Vector]] into pieces. 
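+ * For example (an illustrative sketch; it assumes a [[VectorUDT]] column `fv`):
+ * {{{
+ *   // yields one (feature, weight) row per vector element
+ *   df.explode_vector($"fv").select($"feature", $"weight")
+ * }}}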
+ * @group ftvec + */ + def explode_vector(expr: Column): DataFrame = { + val elementSchema = StructType( + StructField("feature", StringType) :: StructField("weight", DoubleType) :: Nil) + val explodeFunc: Row => TraversableOnce[InternalRow] = (row: Row) => { + row.get(0) match { + case dv: DenseVector => + dv.values.zipWithIndex.map { + case (value, index) => + InternalRow(UTF8String.fromString(s"$index"), value) + } + case sv: SparseVector => + sv.values.zip(sv.indices).map { + case (value, index) => + InternalRow(UTF8String.fromString(s"$index"), value) + } + } + } + withTypedPlan { + Generate( + UserDefinedGenerator(elementSchema, explodeFunc, expr.expr :: Nil), + join = true, outer = false, None, + generatorOutput = Nil, + df.logicalPlan) + } + } + + /** + * Returns `top-k` records for each `group`. + * @group misc + */ + def each_top_k(k: Column, group: Column, score: Column): DataFrame = withTypedPlan { + val kInt = k.expr match { + case Literal(v: Any, IntegerType) => v.asInstanceOf[Int] + case e => throw new AnalysisException("`k` must be an integer, however " + e) + } + val clusterDf = df.repartition(group).sortWithinPartitions(group) + val child = clusterDf.logicalPlan + val logicalPlan = Project(group.named +: score.named +: child.output, child) + _analyzer.execute(logicalPlan) match { + case Project(group :: score :: origCols, c) => + Generate( + EachTopK(kInt, group, score, c.output), + join = false, outer = false, None, + (Seq("rank") ++ origCols.map(_.name)).map(UnresolvedAttribute(_)), + clusterDf.logicalPlan + ) + } + } + + /** + * @see [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]] + * @group misc + */ + @scala.annotation.varargs + def lr_datagen(exprs: Column*): Dataset[Row] = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper", + "lr_datagen", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "features") + ) + } + + /** + * Returns all the columns in this [[DataFrame]] as Seq[Column]. + */ + private[sql] def cols: Seq[Column] = { + df.schema.fields.map(col => df.col(col.name)).toSeq + } + + /** + * :: Experimental :: + * If the '-mix' parameter does not exist in the 3rd argument, + * set it from the environment variable + * 'HIVEMALL_MIX_SERVERS'. + * + * TODO: This only works if '--deploy-mode' is 'client'; + * otherwise, we need to set HIVEMALL_MIX_SERVERS + * on all possible Spark workers. + */ + @Experimental + private[this] def setMixServs(exprs: Seq[Column]): Seq[Column] = { + val mixes = System.getenv("HIVEMALL_MIX_SERVERS") + if (mixes != null && !mixes.isEmpty()) { + val groupId = df.sqlContext.sparkContext.applicationId + "-" + UUID.randomUUID + logInfo(s"set '${mixes}' as default mix servers (session: ${groupId})") + exprs.size match { + case 2 => exprs :+ Column( + Literal.create(s"-mix ${mixes} -mix_session ${groupId}", StringType)) + /** TODO: Add code for the case where exprs.size == 3. */ + case _ => exprs + } + } else { + exprs + } + } + + /** + * If the input is a [[Vector]], transform it into Hivemall features. + */ + @inline private[this] def toHivemallFeatures(exprs: Seq[Column]): Seq[Column] = { + df.select(exprs: _*).queryExecution.analyzed.schema.zip(exprs).map { + case (StructField(_, _: VectorUDT, _, _), c) => HivemallUtils.to_hivemall_features(c) + case (_, c) => c + } + } + + /** + * A convenient function to wrap a logical plan and produce a DataFrame. 
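+ * Each train_* wrapper above follows the same shape (illustrative):
+ * {{{
+ *   withTypedPlan { planHiveGenericUDTF(df, className, funcName, args, outputNames) }
+ * }}}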
+ */ + @inline private[this] def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = { + val queryExecution = df.sparkSession.sessionState.executePlan(logicalPlan) + val outputSchema = queryExecution.sparkPlan.schema + new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema)) + } +} + +object HivemallOps { + import internal.HivemallOpsImpl._ + + /** + * Implicitly inject the [[HivemallOps]] into [[DataFrame]]. + */ + implicit def dataFrameToHivemallOps(df: DataFrame): HivemallOps = + new HivemallOps(df) + + /** + * @see [[hivemall.HivemallVersionUDF]] + * @group misc + */ + def hivemall_version(): Column = withExpr { + planHiveUDF( + "hivemall.HivemallVersionUDF", + "hivemall_version", + Nil + ) + } + + /** + * @see [[hivemall.anomaly.SingularSpectrumTransformUDF]] + * @group anomaly + */ + @scala.annotation.varargs + def sst(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.anomaly.SingularSpectrumTransformUDF", + "sst", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.CosineSimilarityUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def cosine_sim(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.similarity.CosineSimilarityUDF", + "cosine_sim", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.JaccardIndexUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def jaccard(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.similarity.JaccardIndexUDF", + "jaccard", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.AngularSimilarityUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def angular_similarity(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.similarity.AngularSimilarityUDF", + "angular_similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.EuclidSimilarity]] + * @group knn.similarity + */ + @scala.annotation.varargs + def euclid_similarity(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.similarity.EuclidSimilarity", + "euclid_similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.Distance2SimilarityUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def distance2similarity(exprs: Column*): Column = withExpr { + // TODO: Need a wrapper class because of using unsupported types + planHiveGenericUDF( + "hivemall.knn.similarity.Distance2SimilarityUDF", + "distance2similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.HammingDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def hamming_distance(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.HammingDistanceUDF", + "hamming_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.PopcountUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def popcnt(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.PopcountUDF", + "popcnt", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.KLDivergenceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def kld(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.KLDivergenceUDF", + "kld", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.EuclidDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def euclid_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.EuclidDistanceUDF", + "euclid_distance", + exprs + ) + } + + /** + 
* @see [[hivemall.knn.distance.CosineDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def cosine_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.CosineDistanceUDF", + "cosine_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.AngularDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def angular_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.AngularDistanceUDF", + "angular_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.ManhattanDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def manhattan_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.ManhattanDistanceUDF", + "manhattan_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.MinkowskiDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def minkowski_distance (exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.MinkowskiDistanceUDF", + "minkowski_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.lsh.bBitMinHashUDF]] + * @group knn.lsh + */ + @scala.annotation.varargs + def bbit_minhash(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.lsh.bBitMinHashUDF", + "bbit_minhash", + exprs + ) + } + + /** + * @see [[hivemall.knn.lsh.MinHashesUDFWrapper]] + * @group knn.lsh + */ + @scala.annotation.varargs + def minhashes(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.lsh.MinHashesUDFWrapper", + "minhashes", + exprs + ) + } + + /** + * Returns new features with `1.0` (bias) appended to the input features. + * @see [[hivemall.ftvec.AddBiasUDFWrapper]] + * @group ftvec + */ + def add_bias(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.AddBiasUDFWrapper", + "add_bias", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.ExtractFeatureUDFWrapper]] + * @group ftvec + * + * TODO: This throws java.lang.ClassCastException because + * HiveInspectors.toInspector has a bug in spark. + * Need to fix it later. + */ + def extract_feature(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.ExtractFeatureUDFWrapper", + "extract_feature", + expr :: Nil + ) + }.as("feature") + + /** + * @see [[hivemall.ftvec.ExtractWeightUDFWrapper]] + * @group ftvec + * + * TODO: This throws java.lang.ClassCastException because + * HiveInspectors.toInspector has a bug in spark. + * Need to fix it later. 
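+ *
+ * Intended usage once the issue above is fixed (illustrative):
+ * {{{
+ *   df.select(extract_feature($"fv"), extract_weight($"fv"))
+ * }}}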
+ */ + def extract_weight(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.ExtractWeightUDFWrapper", + "extract_weight", + expr :: Nil + ) + }.as("value") + + /** + * @see [[hivemall.ftvec.AddFeatureIndexUDFWrapper]] + * @group ftvec + */ + def add_feature_index(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.AddFeatureIndexUDFWrapper", + "add_feature_index", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.SortByFeatureUDFWrapper]] + * @group ftvec + */ + def sort_by_feature(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.SortByFeatureUDFWrapper", + "sort_by_feature", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.hashing.MurmurHash3UDF]] + * @group ftvec.hashing + */ + def mhash(expr: Column): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.hashing.MurmurHash3UDF", + "mhash", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.hashing.Sha1UDF]] + * @group ftvec.hashing + */ + def sha1(expr: Column): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.hashing.Sha1UDF", + "sha1", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.hashing.ArrayHashValuesUDF]] + * @group ftvec.hashing + */ + @scala.annotation.varargs + def array_hash_values(exprs: Column*): Column = withExpr { + // TODO: Need a wrapper class because of using unsupported types + planHiveUDF( + "hivemall.ftvec.hashing.ArrayHashValuesUDF", + "array_hash_values", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF]] + * @group ftvec.hashing + */ + @scala.annotation.varargs + def prefixed_hash_values(exprs: Column*): Column = withExpr { + // TODO: Need a wrapper class because of using unsupported types + planHiveUDF( + "hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF", + "prefixed_hash_values", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.scaling.RescaleUDF]] + * @group ftvec.scaling + */ + def rescale(value: Column, max: Column, min: Column): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.scaling.RescaleUDF", + "rescale", + value.cast(FloatType) :: max :: min :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.scaling.ZScoreUDF]] + * @group ftvec.scaling + */ + @scala.annotation.varargs + def zscore(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.scaling.ZScoreUDF", + "zscore", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.scaling.L2NormalizationUDFWrapper]] + * @group ftvec.scaling + */ + def normalize(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.scaling.L2NormalizationUDFWrapper", + "normalize", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.selection.ChiSquareUDF]] + * @group ftvec.selection + */ + def chi2(observed: Column, expected: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.selection.ChiSquareUDF", + "chi2", + Seq(observed, expected) + ) + } + + /** + * @see [[hivemall.ftvec.conv.ToDenseFeaturesUDF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def to_dense_features(exprs: Column*): Column = withExpr { + // TODO: Need a wrapper class because of using unsupported types + planHiveGenericUDF( + "hivemall.ftvec.conv.ToDenseFeaturesUDF", + "to_dense_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.conv.ToSparseFeaturesUDF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def to_sparse_features(exprs: Column*): Column = withExpr { + // TODO: Need a wrapper class because of using unsupported types + planHiveGenericUDF( + 
"hivemall.ftvec.conv.ToSparseFeaturesUDF", + "to_sparse_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.VectorizeFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def vectorize_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.VectorizeFeaturesUDF", + "vectorize_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.CategoricalFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def categorical_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.CategoricalFeaturesUDF", + "categorical_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.IndexedFeatures]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def indexed_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.IndexedFeatures", + "indexed_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.QuantitativeFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def quantitative_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.QuantitativeFeaturesUDF", + "quantitative_features", + exprs + ) + } + + /** + * @see [[hivemall.smile.tools.TreePredictUDF]] + * @group misc + */ + @scala.annotation.varargs + def tree_predict(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.smile.tools.TreePredictUDF", + "tree_predict", + exprs + ) + } + + /** + * @see [[hivemall.tools.array.SelectKBestUDF]] + * @group tools.array + */ + def select_k_best(X: Column, importanceList: Column, k: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.array.SelectKBestUDF", + "select_k_best", + Seq(X, importanceList, k) + ) + } + + /** + * @see [[hivemall.tools.math.SigmoidGenericUDF]] + * @group misc + */ + def sigmoid(expr: Column): Column = { + val one: () => Literal = () => Literal.create(1.0, DoubleType) + Column(one()) / (Column(one()) + exp(-expr)) + } + + /** + * @see [[hivemall.tools.mapred.RowIdUDFWrapper]] + * @group misc + */ + def rowid(): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.mapred.RowIdUDFWrapper", + "rowid", + Nil + ) + }.as("rowid") + + /** + * A convenient function to wrap an expression and produce a Column. + */ + @inline private def withExpr(expr: Expression): Column = Column(expr) +}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala new file mode 100644 index 0000000..056d6d6 --- /dev/null +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.hive + +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +object HivemallUtils { + + // # of maximum dimensions for feature vectors + private[this] val maxDims = 100000000 + + /** + * Check whether the given schema contains a column of the required data type. 
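+ * For example, `checkColumnType(df.schema, "feature", StringType)` requires a
+ * string-typed `feature` column (an illustrative call).
+ *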
+ * @param colName column name + * @param dataType required column data type + */ + private[this] def checkColumnType(schema: StructType, colName: String, dataType: DataType) + : Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + def to_vector_func(dense: Boolean, dims: Int): Seq[String] => Vector = { + if (dense) { + // Dense features + i: Seq[String] => { + val features = new Array[Double](dims) + i.map { ft => + val s = ft.split(":").ensuring(_.size == 2) + features(s(0).toInt) = s(1).toDouble + } + Vectors.dense(features) + } + } else { + // Sparse features + i: Seq[String] => { + val features = i.map { ft => + // val s = ft.split(":").ensuring(_.size == 2) + val s = ft.split(":") + (s(0).toInt, s(1).toDouble) + } + Vectors.sparse(dims, features) + } + } + } + + def to_hivemall_features_func(): Vector => Array[String] = { + case dv: DenseVector => + dv.values.zipWithIndex.map { + case (value, index) => s"$index:$value" + } + case sv: SparseVector => + sv.values.zip(sv.indices).map { + case (value, index) => s"$index:$value" + } + case v => + throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}") + } + + def append_bias_func(): Vector => Vector = { + case dv: DenseVector => + val inputValues = dv.values + val inputLength = inputValues.length + val outputValues = Array.ofDim[Double](inputLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputLength) + outputValues(inputLength) = 1.0 + Vectors.dense(outputValues) + case sv: SparseVector => + val inputValues = sv.values + val inputIndices = sv.indices + val inputValuesLength = inputValues.length + val dim = sv.size + val outputValues = Array.ofDim[Double](inputValuesLength + 1) + val outputIndices = Array.ofDim[Int](inputValuesLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength) + System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength) + outputValues(inputValuesLength) = 1.0 + outputIndices(inputValuesLength) = dim + Vectors.sparse(dim + 1, outputIndices, outputValues) + case v => + throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}") + } + + /** + * Transforms Hivemall features into a [[Vector]]. + */ + def to_vector(dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = { + udf(to_vector_func(dense, dims)) + } + + /** + * Transforms a [[Vector]] into Hivemall features. + */ + def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func) + + /** + * Returns a new [[Vector]] with `1.0` (bias) appended to the input [[Vector]]. 
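+ * For example (illustrative), `df.select(append_bias($"fv"))` maps a vector
+ * `[0.5, 0.2]` to `[0.5, 0.2, 1.0]`.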
+ * @group ftvec + */ + def append_bias: UserDefinedFunction = udf(append_bias_func) + + /** + * Builds a [[Vector]]-based model from a table of Hivemall models + */ + def vectorized_model(df: DataFrame, dense: Boolean = false, dims: Int = maxDims) + : UserDefinedFunction = { + checkColumnType(df.schema, "feature", StringType) + checkColumnType(df.schema, "weight", DoubleType) + + import df.sqlContext.implicits._ + val intercept = df + .where($"feature" === "0") + .select($"weight") + .map { case Row(weight: Double) => weight} + .reduce(_ + _) + val weights = to_vector_func(dense, dims)( + df.select($"feature", $"weight") + .where($"feature" !== "0") + .map { case Row(label: String, feature: Double) => s"${label}:$feature"} + .collect.toSeq) + + udf((input: Vector) => BLAS.dot(input, weights) + intercept) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala new file mode 100644 index 0000000..ab5c5fb --- /dev/null +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.hive.internal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan} +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper + +/** + * This is an implementation class for [[org.apache.spark.sql.hive.HivemallOps]]. + * This class mainly uses the internal Spark classes (e.g., `Generate` and `HiveGenericUDTF`) that + * have unstable interfaces (so, these interfaces may evolve in upcoming releases). + * Therefore, the objective of this class is to extract these unstable parts + * from [[org.apache.spark.sql.hive.HivemallOps]]. 
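+ *
+ * For instance, `HivemallOps.train_logregr` reduces to a single call here
+ * (illustrative):
+ * {{{
+ *   planHiveGenericUDTF(df, "hivemall.regression.LogressUDTF", "train_logregr",
+ *     argumentExprs, Seq("feature", "weight"))
+ * }}}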
+ */ +private[hive] object HivemallOpsImpl extends Logging { + + def planHiveUDF( + className: String, + funcName: String, + argumentExprs: Seq[Column]): Expression = { + HiveSimpleUDF( + name = funcName, + funcWrapper = new HiveFunctionWrapper(className), + children = argumentExprs.map(_.expr) + ) + } + + def planHiveGenericUDF( + className: String, + funcName: String, + argumentExprs: Seq[Column]): Expression = { + HiveGenericUDF( + name = funcName, + funcWrapper = new HiveFunctionWrapper(className), + children = argumentExprs.map(_.expr) + ) + } + + def planHiveGenericUDTF( + df: DataFrame, + className: String, + funcName: String, + argumentExprs: Seq[Column], + outputAttrNames: Seq[String]): LogicalPlan = { + Generate( + generator = HiveGenericUDTF( + name = funcName, + funcWrapper = new HiveFunctionWrapper(className), + children = argumentExprs.map(_.expr) + ), + join = false, + outer = false, + qualifier = None, + generatorOutput = outputAttrNames.map(UnresolvedAttribute(_)), + child = df.logicalPlan) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala new file mode 100644 index 0000000..9f2cb64 --- /dev/null +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.hive.source + +import java.io.File +import java.io.IOException +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path} +import org.apache.hadoop.io.IOUtils +import org.apache.hadoop.io.compress.GzipCodec +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.util.ReflectionUtils + +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +private[source] final class XGBoostOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private val hadoopConf = new SerializableConfiguration(new Configuration()) + + override def write(row: Row): Unit = { + val model = row.get(1).asInstanceOf[Array[Byte]] + val filePath = new Path(new URI(s"$path")) + val fs = filePath.getFileSystem(hadoopConf.value) + val outputFile = fs.create(filePath) + outputFile.write(model) + outputFile.close() + } + + override def close(): Unit = {} +} + +object XGBoostOutputWriter { + + /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */ + def getCompressionExtension(context: TaskAttemptContext): String = { + if (FileOutputFormat.getCompressOutput(context)) { + val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec]) + ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension + } else { + "" + } + } +} + +final class XGBoostFileFormat extends FileFormat with DataSourceRegister { + + override def shortName(): String = "libxgboost" + + override def toString: String = "XGBoost" + + private def verifySchema(dataSchema: StructType): Unit = { + if ( + dataSchema.size != 2 || + !dataSchema(0).dataType.sameType(StringType) || + !dataSchema(1).dataType.sameType(BinaryType) + ) { + throw new IOException(s"Illegal schema for XGBoost data, schema=$dataSchema") + } + } + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + Some( + StructType( + StructField("model_id", StringType, nullable = false) :: + StructField("pred_model", BinaryType, nullable = false) :: Nil) + ) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new XGBoostOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + XGBoostOutputWriter.getCompressionExtension(context) + ".xgboost" + } + } + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + verifySchema(dataSchema) + val 
broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val model = new Array[Byte](file.length.asInstanceOf[Int]) + val filePath = new Path(new URI(file.filePath)) + val fs = filePath.getFileSystem(broadcastedHadoopConf.value.value) + + var in: FSDataInputStream = null + try { + in = fs.open(filePath) + IOUtils.readFully(in, model, 0, model.length) + } finally { + IOUtils.closeStream(in) + } + + val converter = RowEncoder(dataSchema) + val fullOutput = dataSchema.map { f => + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + } + val requiredOutput = fullOutput.filter { a => + requiredSchema.fieldNames.contains(a.name) + } + val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) + (requiredColumns( + converter.toRow(Row(new File(file.filePath).getName, model))) + :: Nil + ).toIterator + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/data/files/README.md ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/data/files/README.md b/spark/spark-2.1/src/test/resources/data/files/README.md new file mode 100644 index 0000000..0fd0299 --- /dev/null +++ b/spark/spark-2.1/src/test/resources/data/files/README.md @@ -0,0 +1,3 @@ +The files in this dir exist to prevent exceptions in o.a.s.sql.hive.test.TestHive. +We need to fix this issue in the future. + http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/data/files/complex.seq ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/data/files/complex.seq b/spark/spark-2.1/src/test/resources/data/files/complex.seq new file mode 100644 index 0000000..e69de29 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/data/files/episodes.avro ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/data/files/episodes.avro b/spark/spark-2.1/src/test/resources/data/files/episodes.avro new file mode 100644 index 0000000..e69de29 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/data/files/json.txt ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/data/files/json.txt b/spark/spark-2.1/src/test/resources/data/files/json.txt new file mode 100644 index 0000000..e69de29 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/data/files/kv1.txt ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/data/files/kv1.txt b/spark/spark-2.1/src/test/resources/data/files/kv1.txt new file mode 100644 index 0000000..e69de29 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/data/files/kv3.txt ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/data/files/kv3.txt b/spark/spark-2.1/src/test/resources/data/files/kv3.txt new file mode 100644 index 0000000..e69de29 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/resources/log4j.properties 
---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/resources/log4j.properties b/spark/spark-2.1/src/test/resources/log4j.properties new file mode 100644 index 0000000..1db11f0 --- /dev/null +++ b/spark/spark-2.1/src/test/resources/log4j.properties @@ -0,0 +1,7 @@ +# Set everything to be logged to the console +log4j.rootCategory=FATAL, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/hivemall/mix/server/MixServerSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.1/src/test/scala/hivemall/mix/server/MixServerSuite.scala new file mode 100644 index 0000000..dbb818b --- /dev/null +++ b/spark/spark-2.1/src/test/scala/hivemall/mix/server/MixServerSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.mix.server + +import java.util.Random +import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.logging.Logger + +import hivemall.mix.MixMessage.MixEventName +import hivemall.mix.client.MixClient +import hivemall.mix.server.MixServer.ServerState +import hivemall.model.{DenseModel, PredictionModel, WeightValue} +import hivemall.utils.io.IOUtils +import hivemall.utils.lang.CommandLineUtils +import hivemall.utils.net.NetUtils +import org.scalatest.{BeforeAndAfter, FunSuite} + +class MixServerSuite extends FunSuite with BeforeAndAfter { + + private[this] var server: MixServer = _ + private[this] var executor: ExecutorService = _ + private[this] var port: Int = _ + + private[this] val rand = new Random(43) + private[this] val counter = Stream.from(0).iterator + + private[this] val eachTestTime = 100 + private[this] val logger = + Logger.getLogger(classOf[MixServerSuite].getName) + + before { + this.port = NetUtils.getAvailablePort + this.server = new MixServer( + CommandLineUtils.parseOptions( + Array("-port", s"${port}", "-sync_threshold", "3"), + MixServer.getOptions() + ) + ) + this.executor = Executors.newSingleThreadExecutor + this.executor.submit(server) + var retry = 0 + while (server.getState() != ServerState.RUNNING && retry < 50) { + Thread.sleep(1000L) + retry += 1 + } + assert(server.getState == ServerState.RUNNING) + } + + after { this.executor.shutdown() } + + private[this] def clientDriver( + groupId: String, model: PredictionModel, numMsg: Int = 1000000): Unit = { + var client: MixClient = null + try { + client = new MixClient(MixEventName.average, groupId, s"localhost:${port}", false, 2, model) + model.configureMix(client, false) + model.configureClock() + + for (_ <- 0 until numMsg) { + val feature = Integer.valueOf(rand.nextInt(model.size)) + model.set(feature, new WeightValue(1.0f)) + } + + while (true) { Thread.sleep(eachTestTime * 1000 + 100L) } + assert(model.getNumMixed > 0) + } finally { + IOUtils.closeQuietly(client) + } + } + + private[this] def fixedGroup: (String, () => String) = + ("fixed", () => "fixed") + private[this] def uniqueGroup: (String, () => String) = + ("unique", () => s"${counter.next}") + + Seq(65536).map { ndims => + Seq(4).map { nclient => + Seq(fixedGroup, uniqueGroup).map { id => + val testName = s"dense-dim:${ndims}-client:${nclient}-${id._1}" + ignore(testName) { + val clients = Executors.newCachedThreadPool() + val numClients = nclient + val models = (0 until numClients).map(i => new DenseModel(ndims, false)) + (0 until numClients).map { i => + clients.submit(new Runnable() { + override def run(): Unit = { + try { + clientDriver( + s"${testName}-${id._2}", + models(i) + ) + } catch { + case e: InterruptedException => + assert(false, e.getMessage) + } + } + }) + } + clients.awaitTermination(eachTestTime, TimeUnit.SECONDS) + clients.shutdown() + val nMixes = models.map(d => d.getNumMixed).reduce(_ + _) + logger.info(s"${testName} --> ${(nMixes + 0.0) / eachTestTime} mixes/s") + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala b/spark/spark-2.1/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala new file mode 100644 index 0000000..8c06837 --- /dev/null +++ 
b/spark/spark-2.1/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools + +import org.scalatest.FunSuite + +import org.apache.spark.sql.hive.test.TestHive + +class RegressionDatagenSuite extends FunSuite { + + test("datagen") { + val df = RegressionDatagen.exec( + TestHive, min_examples = 10000, n_features = 100, n_dims = 65536, dense = false, cl = true) + assert(df.count() >= 10000) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/SparkFunSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/SparkFunSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/SparkFunSuite.scala new file mode 100644 index 0000000..0b101c8 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark + +// scalastyle:off +import org.scalatest.{FunSuite, Outcome} + +import org.apache.spark.internal.Logging + +/** + * Base abstract class for all unit tests in Spark for handling common functionality. + */ +private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +// scalastyle:on + + /** + * Log the suite name and the test name before and after each test. + * + * Subclasses should never override this method. If they wish to run + * custom code before and after each test, they should mix in the + * {{org.scalatest.BeforeAndAfter}} trait instead. 
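+   * For instance (an illustrative sketch only; the helper names are hypothetical):
+   * {{{
+   *   class MySuite extends SparkFunSuite with BeforeAndAfter {
+   *     before { initTestState() }   // runs before each test
+   *     after { cleanupTestState() } // runs after each test
+   *     test("example") { assert(1 + 1 === 2) }
+   *   }
+   * }}}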
+ */ + final protected override def withFixture(test: NoArgTest): Outcome = { + val testName = test.text + val suiteName = this.getClass.getName + val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") + try { + logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") + test() + } finally { + logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala new file mode 100644 index 0000000..f57983f --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite + +class HivemallLabeledPointSuite extends SparkFunSuite { + + test("toString") { + val lp = HivemallLabeledPoint(1.0f, Seq("1:0.5", "3:0.3", "8:0.1")) + assert(lp.toString === "1.0,[1:0.5,3:0.3,8:0.1]") + } + + test("parse") { + val lp = HivemallLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]") + assert(lp.label === 1.0) + assert(lp.features === Seq("1:0.5", "3:0.3", "8:0.1")) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/QueryTest.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/QueryTest.scala new file mode 100644 index 0000000..14c8f1b --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql
+
+import java.util.{ArrayDeque, Locale, TimeZone}
+
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.streaming.MemoryPlan
+import org.apache.spark.sql.types.{Metadata, ObjectType}
+
+
+abstract class QueryTest extends PlanTest {
+
+  protected def spark: SparkSession
+
+  // The timezone is fixed to America/Los_Angeles for timezone-sensitive tests (timestamp_*)
+  TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
+  // Add Locale setting
+  Locale.setDefault(Locale.US)
+
+  /**
+   * Runs the plan and makes sure the answer contains all of the keywords.
+   */
+  def checkKeywordsExist(df: DataFrame, keywords: String*): Unit = {
+    val outputs = df.collect().map(_.mkString).mkString
+    for (key <- keywords) {
+      assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in the result)")
+    }
+  }
+
+  /**
+   * Runs the plan and makes sure the answer does NOT contain any of the keywords.
+   */
+  def checkKeywordsNotExist(df: DataFrame, keywords: String*): Unit = {
+    val outputs = df.collect().map(_.mkString).mkString
+    for (key <- keywords) {
+      assert(!outputs.contains(key), s"Failed for $df ($key exists in the result)")
+    }
+  }
+
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer.
+   */
+  protected def checkDataset[T](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
+    val result = getResult(ds)
+
+    if (!compare(result.toSeq, expectedAnswer)) {
+      fail(
+        s"""
+           |Decoded objects do not match expected objects:
+           |expected: $expectedAnswer
+           |actual:   ${result.toSeq}
+           |${ds.exprEnc.deserializer.treeString}
+         """.stripMargin)
+    }
+  }
+
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer, after sorting both sides.
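+   * For example (an illustrative sketch; assumes spark.implicits._ is in scope):
+   * {{{
+   *   checkDatasetUnorderly(Seq(3, 1, 2).toDS(), 1, 2, 3)
+   * }}}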
+ */ + protected def checkDatasetUnorderly[T : Ordering]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { + val result = getResult(ds) + + if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail( + s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + |${ds.exprEnc.deserializer.treeString} + """.stripMargin) + } + } + + private def getResult[T](ds: => Dataset[T]): Array[T] = { + val analyzedDS = try ds catch { + case ae: AnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + """.stripMargin) + } else { + throw ae + } + } + assertEmptyMissingInput(analyzedDS) + + try ds.collect() catch { + case e: Exception => + fail( + s""" + |Exception collecting dataset as objects + |${ds.exprEnc} + |${ds.exprEnc.deserializer.treeString} + |${ds.queryExecution} + """.stripMargin, e) + } + } + + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a, b) => a == b + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } + } + + assertEmptyMissingInput(analyzedDF) + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * + * @param dataFrame the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param absTol the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Seq[Row], + absTol: Double): Unit = { + // TODO: catch exceptions in data frame execution + val actualAnswer = dataFrame.collect() + require(actualAnswer.length == expectedAnswer.length, + s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") + + actualAnswer.zip(expectedAnswer).foreach { + case (actualRow, expectedRow) => + QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) + } + } + + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Row, + absTol: Double): Unit = { + checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) + } + + /** + * Asserts that a given [[Dataset]] will be executed using the given number of cached results. 
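+   * For example (an illustrative sketch only):
+   * {{{
+   *   val df = spark.range(10).toDF().cache()
+   *   assertCached(df.filter("id > 5"))
+   * }}}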
+ */
+  def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = {
+    val planWithCaching = query.queryExecution.withCachedData
+    val cachedData = planWithCaching collect {
+      case cached: InMemoryRelation => cached
+    }
+
+    assert(
+      cachedData.size == numCachedTables,
+      s"Expected the query to contain $numCachedTables cached relation(s), " +
+        s"but it actually had ${cachedData.size}\n" + planWithCaching)
+  }
+
+  /**
+   * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
+   */
+  def assertEmptyMissingInput(query: Dataset[_]): Unit = {
+    assert(query.queryExecution.analyzed.missingInput.isEmpty,
+      s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}")
+    assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
+      s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}")
+    assert(query.queryExecution.executedPlan.missingInput.isEmpty,
+      s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
+  }
+}
+
+object QueryTest {
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * If there was an exception during the execution, or the contents of the DataFrame do not
+   * match the expected result, an error message will be returned. Otherwise, [[None]] will
+   * be returned.
+   *
+   * @param df the [[DataFrame]] to be executed
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
+   */
+  def checkAnswer(
+      df: DataFrame,
+      expectedAnswer: Seq[Row],
+      checkToRDD: Boolean = true): Option[String] = {
+    val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+    if (checkToRDD) {
+      df.rdd.count()  // Also attempt to deserialize as an RDD [SPARK-15791]
+    }
+
+    val sparkAnswer = try df.collect().toSeq catch {
+      case e: Exception =>
+        val errorMessage =
+          s"""
+            |Exception thrown while executing query:
+            |${df.queryExecution}
+            |== Exception ==
+            |$e
+            |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+          """.stripMargin
+        return Some(errorMessage)
+    }
+
+    sameRows(expectedAnswer, sparkAnswer, isSorted).map { results =>
+      s"""
+        |Results do not match for query:
+        |Timezone: ${TimeZone.getDefault}
+        |Timezone Env: ${sys.env.getOrElse("TZ", "")}
+        |
+        |${df.queryExecution}
+        |== Results ==
+        |$results
+      """.stripMargin
+    }
+  }
+
+
+  def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
+    // Converts data to types that we can do equality comparison using Scala collections.
+    // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+    // Java's java.math.BigDecimal.compareTo).
+    // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+    // equality test.
+    val converted: Seq[Row] = answer.map(prepareRow)
+    if (!isSorted) converted.sortBy(_.toString()) else converted
+  }
+
+  // We need to call prepareRow recursively to handle schemas with struct types.
+  def prepareRow(row: Row): Row = {
+    Row.fromSeq(row.toSeq.map {
+      case null => null
+      case d: java.math.BigDecimal => BigDecimal(d)
+      // Convert array to Seq for easy equality check.
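+      // (Array equality in Scala is reference equality inherited from Java, whereas
+      // Seq equality is element-wise, so the conversion makes the comparison structural.)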
+      case b: Array[_] => b.toSeq
+      case r: Row => prepareRow(r)
+      case o => o
+    })
+  }
+
+  def sameRows(
+      expectedAnswer: Seq[Row],
+      sparkAnswer: Seq[Row],
+      isSorted: Boolean = false): Option[String] = {
+    if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) {
+      val errorMessage =
+        s"""
+         |== Results ==
+         |${sideBySide(
+           s"== Correct Answer - ${expectedAnswer.size} ==" +:
+             prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
+           s"== Spark Answer - ${sparkAnswer.size} ==" +:
+             prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")}
+        """.stripMargin
+      return Some(errorMessage)
+    }
+    None
+  }
+
+  /**
+   * Runs the plan and makes sure the answer is within absTol of the expected result.
+   *
+   * @param actualAnswer the actual result in a [[Row]].
+   * @param expectedAnswer the expected result in a [[Row]].
+   * @param absTol the absolute tolerance between actual and expected answers.
+   */
+  protected def checkAggregatesWithTol(
+      actualAnswer: Row, expectedAnswer: Row, absTol: Double): Unit = {
+    require(actualAnswer.length == expectedAnswer.length,
+      s"actual answer length ${actualAnswer.length} != " +
+        s"expected answer length ${expectedAnswer.length}")
+
+    // TODO: support other numeric types besides Double
+    // TODO: support struct types?
+    actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+      case (actual: Double, expected: Double) =>
+        assert(math.abs(actual - expected) < absTol,
+          s"actual answer $actual not within $absTol of correct answer $expected")
+      case (actual, expected) =>
+        assert(actual == expected, s"$actual did not equal $expected")
+    }
+  }
+
+  def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+    checkAnswer(df, expectedAnswer.asScala) match {
+      case Some(errorMessage) => errorMessage
+      case None => null
+    }
+  }
+} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c837e51a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala new file mode 100644 index 0000000..8672bf2 --- /dev/null +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -0,0 +1,108 @@ +/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * Provides helper methods for comparing plans.
+ */
+abstract class PlanTest extends SparkFunSuite with PredicateHelper {
+  /**
+   * Since attribute references are given globally unique ids during analysis,
+   * we must normalize them to check if two different queries are identical.
+   */
+  protected def normalizeExprIds(plan: LogicalPlan): LogicalPlan = {
+    plan transformAllExpressions {
+      case s: ScalarSubquery =>
+        s.copy(exprId = ExprId(0))
+      case e: Exists =>
+        e.copy(exprId = ExprId(0))
+      case l: ListQuery =>
+        l.copy(exprId = ExprId(0))
+      case p: PredicateSubquery =>
+        p.copy(exprId = ExprId(0))
+      case a: AttributeReference =>
+        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+      case a: Alias =>
+        Alias(a.child, a.name)(exprId = ExprId(0))
+      case ae: AggregateExpression =>
+        ae.copy(resultId = ExprId(0))
+    }
+  }
+
+  /**
+   * Normalizes plans:
+   * - Filter conditions are split, rewritten and re-sorted, so that, for instance,
+   *   ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3) and
+   *   (expr 3 && (expr 1 && expr 2)) all become equivalent.
+   * - In Sample, the seed is replaced by 0L.
+   * - Join conditions are re-sorted by hashCode.
+   */
+  private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+    plan transform {
+      case filter @ Filter(condition: Expression, child: LogicalPlan) =>
+        Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
+          .reduce(And), child)
+      case sample: Sample =>
+        sample.copy(seed = 0L)(true)
+      case join @ Join(left, right, joinType, condition) if condition.isDefined =>
+        val newCondition =
+          splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
+            .reduce(And)
+        Join(left, right, joinType, Some(newCondition))
+    }
+  }
+
+  /**
+   * Rewrites the [[EqualTo]] and [[EqualNullSafe]] operators to impose a canonical operand
+   * order, so that the following pairs become equivalent:
+   * 1. (a = b), (b = a);
+   * 2. (a <=> b), (b <=> a).
+   */
+  private def rewriteEqual(condition: Expression): Expression = condition match {
+    case EqualTo(l: Expression, r: Expression) =>
+      Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
+    case EqualNullSafe(l: Expression, r: Expression) =>
+      Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
+    case _ => condition // Don't reorder.
+  }
+
+  /** Fails the test if the two plans do not match. */
+  protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
+    val normalized1 = normalizePlan(normalizeExprIds(plan1))
+    val normalized2 = normalizePlan(normalizeExprIds(plan2))
+    if (normalized1 != normalized2) {
+      fail(
+        s"""
+          |== FAIL: Plans do not match ==
+          |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
        """.stripMargin)
+    }
+  }
+
+  /** Fails the test if the two expressions do not match. */
+  protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
+    comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation))
+  }
+}
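A minimal usage sketch for the plan-comparison helpers above (the suite name is illustrative, not part of this patch): because rewriteEqual sorts the operands of EqualTo by hashCode before comparison, commuted equality predicates should compare as equal.

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo}
  import org.apache.spark.sql.types.IntegerType

  class NormalizationSuite extends PlanTest {
    test("(a = b) and (b = a) normalize to the same plan") {
      val a = AttributeReference("a", IntegerType)()
      val b = AttributeReference("b", IntegerType)()
      // compareExpressions wraps both sides in Filter(_, OneRowRelation)
      // and normalizes them, so operand order is irrelevant.
      compareExpressions(EqualTo(a, b), EqualTo(b, a))
    }
  }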