http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala deleted file mode 100644 index 056d6d6..0000000 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.hive - -import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ -
/**
 * Conversions between Hivemall-style `index:value` string features and Spark ML [[Vector]]s,
 * together with the UDF wrappers built on top of them.
 */
object HivemallUtils {

  // # of maximum dimensions for feature vectors
  private[this] val maxDims = 100000000

  /**
   * Check whether the given schema contains a column of the required data type.
   *
   * @param schema schema to inspect
   * @param colName column name
   * @param dataType required column data type
   * @throws IllegalArgumentException (via `require`) when the column has a different type
   */
  private[this] def checkColumnType(schema: StructType, colName: String, dataType: DataType)
    : Unit = {
    val actualDataType = schema(colName).dataType
    require(actualDataType.equals(dataType),
      s"Column $colName must be of type $dataType but was actually $actualDataType.")
  }

  /**
   * Builds a function that parses `index:value` strings into a [[Vector]] of size `dims`.
   *
   * @param dense when true the result is a dense vector, otherwise a sparse one
   * @param dims size of the produced vector
   */
  def to_vector_func(dense: Boolean, dims: Int): Seq[String] => Vector = {
    if (dense) {
      // Dense features
      i: Seq[String] => {
        val features = new Array[Double](dims)
        // Pure side effect on `features`, so `foreach` rather than `map`.
        i.foreach { ft =>
          val s = ft.split(":").ensuring(_.size == 2)
          features(s(0).toInt) = s(1).toDouble
        }
        Vectors.dense(features)
      }
    } else {
      // Sparse features
      i: Seq[String] => {
        // Unlike the dense branch, the number of `:`-separated parts is not checked here.
        val features = i.map { ft =>
          val s = ft.split(":")
          (s(0).toInt, s(1).toDouble)
        }
        Vectors.sparse(dims, features)
      }
    }
  }

  /**
   * Builds a function that renders a [[Vector]] as `index:value` strings.
   * Dense vectors emit every position; sparse vectors emit only their stored entries.
   */
  def to_hivemall_features_func(): Vector => Array[String] = {
    case dv: DenseVector =>
      dv.values.zipWithIndex.map {
        case (value, index) => s"$index:$value"
      }
    case sv: SparseVector =>
      sv.values.zip(sv.indices).map {
        case (value, index) => s"$index:$value"
      }
    case v =>
      throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
  }

  /**
   * Builds a function that returns a copy of the input [[Vector]], one element longer,
   * whose last element is `1.0`.
   */
  def append_bias_func(): Vector => Vector = {
    case dv: DenseVector =>
      val inputValues = dv.values
      val inputLength = inputValues.length
      val outputValues = Array.ofDim[Double](inputLength + 1)
      System.arraycopy(inputValues, 0, outputValues, 0, inputLength)
      outputValues(inputLength) = 1.0
      Vectors.dense(outputValues)
    case sv: SparseVector =>
      val inputValues = sv.values
      val inputIndices = sv.indices
      val inputValuesLength = inputValues.length
      val dim = sv.size
      val outputValues = Array.ofDim[Double](inputValuesLength + 1)
      val outputIndices = Array.ofDim[Int](inputValuesLength + 1)
      System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength)
      System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength)
      // The bias occupies the new last position, i.e. index `dim` of a (dim + 1)-sized vector.
      outputValues(inputValuesLength) = 1.0
      outputIndices(inputValuesLength) = dim
      Vectors.sparse(dim + 1, outputIndices, outputValues)
    case v =>
      throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
  }

  /**
   * Transforms Hivemall features into a [[Vector]].
   */
  def to_vector(dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = {
    udf(to_vector_func(dense, dims))
  }

  /**
   * Transforms a [[Vector]] into Hivemall features.
   */
  def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func)

  /**
   * Returns a new [[Vector]] with `1.0` (bias) appended to the input [[Vector]].
   * @group ftvec
   */
  def append_bias: UserDefinedFunction = udf(append_bias_func)

  /**
   * Builds a [[Vector]]-based model from a table of Hivemall models.
   *
   * Rows whose `feature` is `"0"` are summed into the intercept; all other rows become the
   * weight vector. The returned UDF computes `dot(input, weights) + intercept`.
   *
   * @param df model table with a string `feature` column and a double `weight` column
   * @param dense whether the weight vector is dense
   * @param dims size of the weight vector
   * @throws IllegalArgumentException when `feature` or `weight` has an unexpected type
   */
  def vectorized_model(df: DataFrame, dense: Boolean = false, dims: Int = maxDims)
    : UserDefinedFunction = {
    checkColumnType(df.schema, "feature", StringType)
    checkColumnType(df.schema, "weight", DoubleType)

    import df.sqlContext.implicits._
    // Summing the collected values yields 0.0 when the model has no "0" row, whereas
    // `reduce` on an empty result throws.
    val intercept = df
      .where($"feature" === "0")
      .select($"weight")
      .map { case Row(weight: Double) => weight }
      .collect()
      .sum
    val weights = to_vector_func(dense, dims)(
      df.select($"feature", $"weight")
        // `=!=` replaces `!==`, which is deprecated as of Spark 2.0.
        .where($"feature" =!= "0")
        .map { case Row(feature: String, weight: Double) => s"${feature}:$weight" }
        .collect.toSeq)

    udf((input: Vector) => BLAS.dot(input, weights) + intercept)
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala deleted file mode 100644 index ab5c5fb..0000000 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.hive.internal - -import org.apache.spark.internal.Logging -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan} -import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -
/**
 * Implementation side of [[org.apache.spark.sql.hive.HivemallOps]].
 *
 * Everything that touches Spark-internal classes with unstable interfaces (e.g., `Generate`
 * and `HiveGenericUDTF`) is kept here, so that such interface changes in upcoming releases
 * stay out of [[org.apache.spark.sql.hive.HivemallOps]].
 */
private[hive] object HivemallOpsImpl extends Logging {

  /** Plans a call to the Hive simple UDF implemented by `className`. */
  def planHiveUDF(
      className: String,
      funcName: String,
      argumentExprs: Seq[Column]): Expression = {
    val wrapper = new HiveFunctionWrapper(className)
    val args = argumentExprs.map(_.expr)
    HiveSimpleUDF(name = funcName, funcWrapper = wrapper, children = args)
  }

  /** Plans a call to the Hive generic UDF implemented by `className`. */
  def planHiveGenericUDF(
      className: String,
      funcName: String,
      argumentExprs: Seq[Column]): Expression = {
    val wrapper = new HiveFunctionWrapper(className)
    val args = argumentExprs.map(_.expr)
    HiveGenericUDF(name = funcName, funcWrapper = wrapper, children = args)
  }

  /**
   * Plans a `Generate` node over `df` that runs the Hive generic UDTF implemented by
   * `className`; the generator output is named by `outputAttrNames`.
   */
  def planHiveGenericUDTF(
      df: DataFrame,
      className: String,
      funcName: String,
      argumentExprs: Seq[Column],
      outputAttrNames: Seq[String]): LogicalPlan = {
    val udtf = HiveGenericUDTF(
      name = funcName,
      funcWrapper = new HiveFunctionWrapper(className),
      children = argumentExprs.map(_.expr)
    )
    val outputAttrs = outputAttrNames.map(UnresolvedAttribute(_))
    Generate(
      generator = udtf,
      join = false,
      outer = false,
      qualifier = None,
      generatorOutput = outputAttrs,
      child = df.logicalPlan)
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala deleted file mode 100644 index 9cdc09f..0000000 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF)
under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.hive.source - -import java.io.File -import java.io.IOException -import java.net.URI - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path} -import org.apache.hadoop.io.IOUtils -import org.apache.hadoop.mapreduce._ - -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.SerializableConfiguration -
/**
 * Writes each row as one file named `<path>/<model id>` whose content is the row's
 * binary model (column 0 is the model id, column 1 the model bytes).
 */
private[source] final class XGBoostOutputWriter(
    path: String,
    dataSchema: StructType,
    context: TaskAttemptContext)
  extends OutputWriter {

  private val hadoopConf = new SerializableConfiguration(new Configuration())

  override def write(row: Row): Unit = {
    val modelId = row.getString(0)
    val model = row.get(1).asInstanceOf[Array[Byte]]
    val filePath = new Path(new URI(s"$path/$modelId"))
    val fs = filePath.getFileSystem(hadoopConf.value)
    val outputFile = fs.create(filePath)
    // Close the stream even when the write fails so the handle is not leaked.
    try {
      outputFile.write(model)
    } finally {
      outputFile.close()
    }
  }

  // Nothing to release here: every stream is closed inside `write`.
  override def close(): Unit = {}
}

/**
 * File format registered under the short name `libxgboost`. Each file is read as a single
 * `(model_id, pred_model)` row, where `model_id` is the file name and `pred_model` is the
 * whole file content.
 */
final class XGBoostFileFormat extends FileFormat with DataSourceRegister {

  override def shortName(): String = "libxgboost"

  override def toString: String = "XGBoost"

  /** Accepts only a two-column schema of (string, binary); anything else is an error. */
  private def verifySchema(dataSchema: StructType): Unit = {
    if (
      dataSchema.size != 2 ||
        !dataSchema(0).dataType.sameType(StringType) ||
        !dataSchema(1).dataType.sameType(BinaryType)
    ) {
      throw new IOException(s"Illegal schema for XGBoost data, schema=$dataSchema")
    }
  }

  /** The schema is fixed and does not depend on the options or the files. */
  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    Some(
      StructType(
        StructField("model_id", StringType, nullable = false) ::
        StructField("pred_model", BinaryType, nullable = false) :: Nil)
    )
  }

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    new OutputWriterFactory {
      override def newInstance(
          path: String,
          bucketId: Option[Int],
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        if (bucketId.isDefined) {
          sys.error("XGBoostFileFormat doesn't support bucketing")
        }
        new XGBoostOutputWriter(path, dataSchema, context)
      }
    }
  }

  /**
   * Returns a reader that loads a whole file into memory and emits it as a single row,
   * projected onto the columns named in `requiredSchema`.
   *
   * @throws IOException when `dataSchema` is not (string, binary), or, at read time, when a
   *                     file is too large to fit into a single byte array
   */
  override def buildReader(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    verifySchema(dataSchema)
    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    (file: PartitionedFile) => {
      // A blind narrowing of the Long length would silently truncate for files over 2 GiB.
      if (file.length > Int.MaxValue) {
        throw new IOException(
          s"XGBoost model file is too large to load, path=${file.filePath} " +
            s"length=${file.length}")
      }
      val model = new Array[Byte](file.length.toInt)
      val filePath = new Path(new URI(file.filePath))
      val fs = filePath.getFileSystem(broadcastedHadoopConf.value.value)

      var in: FSDataInputStream = null
      try {
        in = fs.open(filePath)
        IOUtils.readFully(in, model, 0, model.length)
      } finally {
        IOUtils.closeStream(in)
      }

      val converter = RowEncoder(dataSchema)
      val fullOutput = dataSchema.map { f =>
        AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
      }
      val requiredOutput = fullOutput.filter { a =>
        requiredSchema.fieldNames.contains(a.name)
      }
      val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
      // The model id is the last path segment of the file.
      val fullRow = converter.toRow(Row(new File(file.filePath).getName, model))
      Iterator.single(requiredColumns(fullRow))
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala deleted file mode 100644 index a6bbb4b..0000000 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied.
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.streaming - -import scala.reflect.ClassTag - -import org.apache.spark.ml.feature.HivemallLabeledPoint -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.streaming.dstream.DStream -
/**
 * Adds DataFrame-based prediction to a stream of [[HivemallLabeledPoint]]s.
 */
final class HivemallStreamingOps(ds: DStream[HivemallLabeledPoint]) {

  /**
   * Turns every RDD of the stream into a DataFrame, applies `f` to it and emits the
   * rows of the result.
   *
   * @param f transformation applied to each batch
   * @param sqlContext context used to build the per-batch DataFrame
   */
  def predict[U: ClassTag](f: DataFrame => DataFrame)(implicit sqlContext: SQLContext)
    : DStream[Row] = {
    ds.transform[Row] { rdd: RDD[HivemallLabeledPoint] =>
      val batch = sqlContext.createDataFrame(rdd)
      val predicted = f(batch)
      predicted.rdd
    }
  }
}

object HivemallStreamingOps {

  /**
   * Implicitly inject the [[HivemallStreamingOps]] into [[DStream]].
   */
  implicit def dataFrameToHivemallStreamingOps(ds: DStream[HivemallLabeledPoint])
    : HivemallStreamingOps = new HivemallStreamingOps(ds)
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/data/files/README.md ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/data/files/README.md b/spark/spark-2.0/src/test/resources/data/files/README.md deleted file mode 100644 index 238d472..0000000 --- a/spark/spark-2.0/src/test/resources/data/files/README.md +++ /dev/null @@ -1,22 +0,0 @@ -<!-- - Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License.
You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. ---> - -The files in this directory exist to prevent exceptions in o.a.s.sql.hive.test.TestHive. -We need to fix this issue in the future. - http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/data/files/complex.seq ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/data/files/complex.seq b/spark/spark-2.0/src/test/resources/data/files/complex.seq deleted file mode 100644 index e69de29..0000000 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/data/files/episodes.avro ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/data/files/episodes.avro b/spark/spark-2.0/src/test/resources/data/files/episodes.avro deleted file mode 100644 index e69de29..0000000 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/data/files/json.txt ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/data/files/json.txt b/spark/spark-2.0/src/test/resources/data/files/json.txt deleted file mode 100644 index e69de29..0000000 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/data/files/kv1.txt ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/data/files/kv1.txt b/spark/spark-2.0/src/test/resources/data/files/kv1.txt
deleted file mode 100644 index e69de29..0000000 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/data/files/kv3.txt ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/data/files/kv3.txt b/spark/spark-2.0/src/test/resources/data/files/kv3.txt deleted file mode 100644 index e69de29..0000000 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/resources/log4j.properties b/spark/spark-2.0/src/test/resources/log4j.properties deleted file mode 100644 index c6e4297..0000000 --- a/spark/spark-2.0/src/test/resources/log4j.properties +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -# Set everything to be logged to the console -log4j.rootCategory=FATAL, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala deleted file mode 100644 index cadc852..0000000 --- a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.mix.server - -import java.util.Random -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} -import java.util.logging.Logger - -import hivemall.mix.client.MixClient -import hivemall.mix.MixMessage.MixEventName -import hivemall.mix.server.MixServer.ServerState -import hivemall.model.{DenseModel, PredictionModel} -import hivemall.model.{NewDenseModel, PredictionModel} -import hivemall.model.WeightValue -import hivemall.utils.io.IOUtils -import hivemall.utils.lang.CommandLineUtils -import hivemall.utils.net.NetUtils -import org.scalatest.{BeforeAndAfter, FunSuite} -
/**
 * Throughput checks for [[MixServer]]: a server is started on a free port before each test
 * and several clients push random weight updates to it. The generated tests are registered
 * with `ignore`, so they do not run by default.
 */
class MixServerSuite extends FunSuite with BeforeAndAfter {

  private[this] var server: MixServer = _
  private[this] var executor : ExecutorService = _
  private[this] var port: Int = _

  private[this] val rand = new Random(43)
  private[this] val counter = Stream.from(0).iterator

  // Duration of each test, in seconds
  private[this] val eachTestTime = 100
  private[this] val logger =
    Logger.getLogger(classOf[MixServerSuite].getName)

  before {
    this.port = NetUtils.getAvailablePort
    this.server = new MixServer(
      CommandLineUtils.parseOptions(
        Array("-port", s"${port}", "-sync_threshold", "3"),
        MixServer.getOptions()
      )
    )
    this.executor = Executors.newSingleThreadExecutor
    this.executor.submit(server)
    // Poll once a second, for at most 50 seconds, until the server reports RUNNING.
    var retry = 0
    while (server.getState() != ServerState.RUNNING && retry < 50) {
      Thread.sleep(1000L)
      retry += 1
    }
    assert(server.getState == ServerState.RUNNING)
  }

  after { this.executor.shutdown() }

  /**
   * Connects `model` to the server under `groupId`, sets `numMsg` randomly chosen features
   * and then keeps the client open until the calling thread is interrupted.
   */
  private[this] def clientDriver(
      groupId: String, model: PredictionModel, numMsg: Int = 1000000): Unit = {
    var client: MixClient = null
    try {
      client = new MixClient(MixEventName.average, groupId, s"localhost:${port}", false, 2, model)
      model.configureMix(client, false)
      model.configureClock()

      for (_ <- 0 until numMsg) {
        val feature = Integer.valueOf(rand.nextInt(model.size))
        model.set(feature, new WeightValue(1.0f))
      }

      // NOTE(review): this loop only exits through an exception, so the assertion below is
      // never reached. Left as-is because the intended lifetime of the client is not clear
      // from this file -- confirm before changing.
      while (true) { Thread.sleep(eachTestTime * 1000 + 100L) }
      assert(model.getNumMixed > 0)
    } finally {
      IOUtils.closeQuietly(client)
    }
  }

  // (label used in the test name, generator of the group id suffix)
  private[this] def fixedGroup: (String, () => String) =
    ("fixed", () => "fixed")
  private[this] def uniqueGroup: (String, () => String) =
    ("unique", () => s"${counter.next}")

  // Registers one (ignored) test per combination of dimensions, client count and grouping.
  for {
    ndims <- Seq(65536)
    nclient <- Seq(4)
    id <- Seq(fixedGroup, uniqueGroup)
  } {
    val (groupLabel, nextGroupId) = id
    val testName = s"dense-dim:${ndims}-clinet:${nclient}-${groupLabel}"
    ignore(testName) {
      val clients = Executors.newCachedThreadPool()
      val numClients = nclient
      val models = (0 until numClients).map(i => new NewDenseModel(ndims, false))
      (0 until numClients).foreach { i =>
        // Call the generator: interpolating the function object itself would embed its
        // `toString` and give every client the same group id. It is evaluated on this
        // thread because the underlying iterator is shared between clients.
        val groupId = s"${testName}-${nextGroupId()}"
        clients.submit(new Runnable() {
          override def run(): Unit = {
            try {
              clientDriver(groupId, models(i))
            } catch {
              case e: InterruptedException =>
                assert(false, e.getMessage)
            }
          }
        })
      }
      clients.awaitTermination(eachTestTime, TimeUnit.SECONDS)
      clients.shutdown()
      val nMixes = models.map(d => d.getNumMixed).reduce(_ + _)
      logger.info(s"${testName} --> ${(nMixes + 0.0) / eachTestTime} mixes/s")
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala b/spark/spark-2.0/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala deleted file mode 100644 index 8c06837..0000000 --- a/spark/spark-2.0/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership.
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.tools - -import org.scalatest.FunSuite - -import org.apache.spark.sql.hive.test.TestHive -
/** Checks that [[RegressionDatagen]] produces at least the requested number of examples. */
class RegressionDatagenSuite extends FunSuite {

  test("datagen") {
    val generated = RegressionDatagen.exec(
      TestHive,
      min_examples = 10000,
      n_features = 100,
      n_dims = 65536,
      dense = false,
      cl = true)
    val numExamples = generated.count()
    assert(numExamples >= 10000)
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/org/apache/spark/SparkFunSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/SparkFunSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/SparkFunSuite.scala deleted file mode 100644 index 0b101c8..0000000 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark - -// scalastyle:off -import org.scalatest.{FunSuite, Outcome} - -import org.apache.spark.internal.Logging -
/**
 * Base abstract class for all unit tests in Spark for handling common functionality.
 */
private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
// scalastyle:on

  /**
   * Log the suite name and the test name before and after each test.
   *
   * Subclasses should never override this method. If they wish to run
   * custom code before and after each test, they should mix in the
   * [[org.scalatest.BeforeAndAfter]] trait instead.
   *
   * @param test the test to run
   * @return the outcome of running `test`
   */
  final protected override def withFixture(test: NoArgTest): Outcome = {
    val testName = test.text
    val suiteName = this.getClass.getName
    // Literal replacement: `replaceAll` would treat the dots as regex wildcards.
    val shortSuiteName = suiteName.replace("org.apache.spark", "o.a.s")
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
      test()
    } finally {
      // Logged even when the test throws.
      logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala deleted file mode 100644 index f57983f..0000000 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License.
- */ -package org.apache.spark.ml.feature - -import org.apache.spark.SparkFunSuite -
/** Round-trip checks for the text form of [[HivemallLabeledPoint]]. */
class HivemallLabeledPointSuite extends SparkFunSuite {

  test("toString") {
    val features = Seq("1:0.5", "3:0.3", "8:0.1")
    val point = HivemallLabeledPoint(1.0f, features)
    assert(point.toString === "1.0,[1:0.5,3:0.3,8:0.1]")
  }

  test("parse") {
    val parsed = HivemallLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]")
    assert(parsed.label === 1.0)
    assert(parsed.features === Seq("1:0.5", "3:0.3", "8:0.1"))
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index 8b03911..0000000 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,475 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License.
- */ -package org.apache.spark.sql - -import java.util.{ArrayDeque, Locale, TimeZone} - -import scala.collection.JavaConverters._ -import scala.util.control.NonFatal - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.LogicalRDD -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.ObjectType - -abstract class QueryTest extends PlanTest { - - protected def spark: SparkSession - - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) - - /** - * Runs the plan and makes sure the answer contains all of the keywords. - */ - def checkKeywordsExist(df: DataFrame, keywords: String*): Unit = { - val outputs = df.collect().map(_.mkString).mkString - for (key <- keywords) { - assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)") - } - } - - /** - * Runs the plan and makes sure the answer does NOT contain any of the keywords. - */ - def checkKeywordsNotExist(df: DataFrame, keywords: String*): Unit = { - val outputs = df.collect().map(_.mkString).mkString - for (key <- keywords) { - assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)") - } - } - - /** - * Evaluates a dataset to make sure that the result of calling collect matches the given - * expected answer. 
- */ - protected def checkDataset[T]( - ds: => Dataset[T], - expectedAnswer: T*): Unit = { - val result = getResult(ds) - - if (!compare(result.toSeq, expectedAnswer)) { - fail( - s""" - |Decoded objects do not match expected objects: - |expected: $expectedAnswer - |actual: ${result.toSeq} - |${ds.exprEnc.deserializer.treeString} - """.stripMargin) - } - } - - /** - * Evaluates a dataset to make sure that the result of calling collect matches the given - * expected answer, after sort. - */ - protected def checkDatasetUnorderly[T : Ordering]( - ds: => Dataset[T], - expectedAnswer: T*): Unit = { - val result = getResult(ds) - - if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) { - fail( - s""" - |Decoded objects do not match expected objects: - |expected: $expectedAnswer - |actual: ${result.toSeq} - |${ds.exprEnc.deserializer.treeString} - """.stripMargin) - } - } - - private def getResult[T](ds: => Dataset[T]): Array[T] = { - val analyzedDS = try ds catch { - case ae: AnalysisException => - if (ae.plan.isDefined) { - fail( - s""" - |Failed to analyze query: $ae - |${ae.plan.get} - | - |${stackTraceToString(ae)} - """.stripMargin) - } else { - throw ae - } - } - checkJsonFormat(analyzedDS) - assertEmptyMissingInput(analyzedDS) - - try ds.collect() catch { - case e: Exception => - fail( - s""" - |Exception collecting dataset as objects - |${ds.exprEnc} - |${ds.exprEnc.deserializer.treeString} - |${ds.queryExecution} - """.stripMargin, e) - } - } - - private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { - case (null, null) => true - case (null, _) => false - case (_, null) => false - case (a: Array[_], b: Array[_]) => - a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Iterable[_], b: Iterable[_]) => - a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a, b) => a == b - } - - /** - * Runs the plan and makes sure the answer matches the expected result. 
- * - * @param df the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - val analyzedDF = try df catch { - case ae: AnalysisException => - if (ae.plan.isDefined) { - fail( - s""" - |Failed to analyze query: $ae - |${ae.plan.get} - | - |${stackTraceToString(ae)} - |""".stripMargin) - } else { - throw ae - } - } - - checkJsonFormat(analyzedDF) - - assertEmptyMissingInput(analyzedDF) - - QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - - protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(df, Seq(expectedAnswer)) - } - - protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { - checkAnswer(df, expectedAnswer.collect()) - } - - /** - * Runs the plan and makes sure the answer is within absTol of the expected result. - * - * @param dataFrame the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param absTol the absolute tolerance between actual and expected answers. 
- */ - protected def checkAggregatesWithTol(dataFrame: DataFrame, - expectedAnswer: Seq[Row], - absTol: Double): Unit = { - // TODO: catch exceptions in data frame execution - val actualAnswer = dataFrame.collect() - require(actualAnswer.length == expectedAnswer.length, - s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") - - actualAnswer.zip(expectedAnswer).foreach { - case (actualRow, expectedRow) => - QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) - } - } - - protected def checkAggregatesWithTol(dataFrame: DataFrame, - expectedAnswer: Row, - absTol: Double): Unit = { - checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) - } - - /** - * Asserts that a given [[Dataset]] will be executed using the given number of cached results. - */ - def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - - private def checkJsonFormat(ds: Dataset[_]): Unit = { - // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that - // RDD and Data resolution does not break. - val logicalPlan = ds.queryExecution.analyzed - - // bypass some cases that we can't handle currently. 
- logicalPlan.transform { - case _: ObjectConsumer => return - case _: ObjectProducer => return - case _: AppendColumns => return - case _: LogicalRelation => return - case p if p.getClass.getSimpleName == "MetastoreRelation" => return - case _: MemoryPlan => return - }.transformAllExpressions { - case a: ImperativeAggregate => return - case _: TypedAggregateExpression => return - case Literal(_, _: ObjectType) => return - } - - // bypass hive tests before we fix all corner cases in hive module. - if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return - - val jsonString = try { - logicalPlan.toJSON - } catch { - case NonFatal(e) => - fail( - s""" - |Failed to parse logical plan to JSON: - |${logicalPlan.treeString} - """.stripMargin, e) - } - - // scala function is not serializable to JSON, use null to replace them so that we can compare - // the plans later. - val normalized1 = logicalPlan.transformAllExpressions { - case udf: ScalaUDF => udf.copy(function = null) - case gen: UserDefinedGenerator => gen.copy(function = null) - } - - // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains - // these non-serializable stuff, and use these original ones to replace the null-placeholders - // in the logical plans parsed from JSON. 
- val logicalRDDs = new ArrayDeque[LogicalRDD]() - val localRelations = new ArrayDeque[LocalRelation]() - val inMemoryRelations = new ArrayDeque[InMemoryRelation]() - def collectData: (LogicalPlan => Unit) = { - case l: LogicalRDD => - logicalRDDs.offer(l) - case l: LocalRelation => - localRelations.offer(l) - case i: InMemoryRelation => - inMemoryRelations.offer(i) - case p => - p.expressions.foreach { - _.foreach { - case s: SubqueryExpression => - s.query.foreach(collectData) - case _ => - } - } - } - logicalPlan.foreach(collectData) - - - val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) - } catch { - case NonFatal(e) => - fail( - s""" - |Failed to rebuild the logical plan from JSON: - |${logicalPlan.treeString} - | - |${logicalPlan.prettyJson} - """.stripMargin, e) - } - - def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { - case l: LogicalRDD => - val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(spark) - case l: LocalRelation => - val origin = localRelations.pop() - l.copy(data = origin.data) - case l: InMemoryRelation => - val origin = inMemoryRelations.pop() - InMemoryRelation( - l.output, - l.useCompression, - l.batchSize, - l.storageLevel, - origin.child, - l.tableName)( - origin.cachedColumnBuffers, - origin.batchStats) - case p => - p.transformExpressions { - case s: SubqueryExpression => - s.withNewPlan(s.query.transformDown(renormalize)) - } - } - val normalized2 = jsonBackPlan.transformDown(renormalize) - - assert(logicalRDDs.isEmpty) - assert(localRelations.isEmpty) - assert(inMemoryRelations.isEmpty) - - if (normalized1 != normalized2) { - fail( - s""" - |== FAIL: the logical plan parsed from json does not match the original one === - |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** - * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. 
- */ - def assertEmptyMissingInput(query: Dataset[_]): Unit = { - assert(query.queryExecution.analyzed.missingInput.isEmpty, - s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") - assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, - s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}") - assert(query.queryExecution.executedPlan.missingInput.isEmpty, - s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}") - } -} - -object QueryTest { - /** - * Runs the plan and makes sure the answer matches the expected result. - * If there was exception during the execution or the contents of the DataFrame does not - * match the expected result, an error message will be returned. Otherwise, a [[None]] will - * be returned. - * - * @param df the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. - */ - def checkAnswer( - df: DataFrame, - expectedAnswer: Seq[Row], - checkToRDD: Boolean = true): Option[String] = { - val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - if (checkToRDD) { - df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] - } - - val sparkAnswer = try df.collect().toSeq catch { - case e: Exception => - val errorMessage = - s""" - |Exception thrown while executing query: - |${df.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => - s""" - |Results do not match for query: - |${df.queryExecution} - |== Results == - |$results - """.stripMargin - } - } - - - def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. 
- // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } - - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case d: java.math.BigDecimal => BigDecimal(d) - // Convert array to Seq for easy equality check. - case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - case o => o - }) - } - - def sameRows( - expectedAnswer: Seq[Row], - sparkAnswer: Seq[Row], - isSorted: Boolean = false): Option[String] = { - if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { - val errorMessage = - s""" - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer, isSorted).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) - } - None - } - - /** - * Runs the plan and makes sure the answer is within absTol of the expected result. - * - * @param actualAnswer the actual result in a [[Row]]. - * @param expectedAnswer the expected result in a[[Row]]. - * @param absTol the absolute tolerance between actual and expected answers. - */ - protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { - require(actualAnswer.length == expectedAnswer.length, - s"actual answer length ${actualAnswer.length} != " + - s"expected answer length ${expectedAnswer.length}") - - // TODO: support other numeric types besides Double - // TODO: support struct types? 
- actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { - case (actual: Double, expected: Double) => - assert(math.abs(actual - expected) < absTol, - s"actual answer $actual not within $absTol of correct answer $expected") - case (actual, expected) => - assert(actual == expected, s"$actual did not equal $expected") - } - } - - def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(df, expectedAnswer.asScala) match { - case Some(errorMessage) => errorMessage - case None => null - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 816576e..0000000 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} -import org.apache.spark.sql.catalyst.util._ - -/** - * Provides helper methods for comparing plans. - */ -class PlanTest extends SparkFunSuite { - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - plan transformAllExpressions { - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - } - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Fails the test if the two expressions do not match */ - protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cd24be89/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala deleted file mode 100644 index 4a43afc..0000000 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * 
Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.hive - -import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.HivemallUtils._ -import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest -import org.apache.spark.test.VectorQueryTest - -final class HiveUdfWithFeatureSuite extends HivemallFeatureQueryTest { - import hiveContext.implicits._ - import hiveContext._ - - test("hivemall_version") { - sql(s""" - | CREATE TEMPORARY FUNCTION hivemall_version - | AS '${classOf[hivemall.HivemallVersionUDF].getName}' - """.stripMargin) - - checkAnswer( - sql(s"SELECT DISTINCT hivemall_version()"), - Row("0.5.1-incubating-SNAPSHOT") - ) - - // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version") - // reset() - } - - test("train_logregr") { - TinyTrainData.createOrReplaceTempView("TinyTrainData") - sql(s""" - | CREATE TEMPORARY FUNCTION train_logregr - | AS '${classOf[hivemall.regression.LogressUDTF].getName}' - """.stripMargin) - sql(s""" - | CREATE TEMPORARY FUNCTION add_bias - | AS '${classOf[hivemall.ftvec.AddBiasUDFWrapper].getName}' - """.stripMargin) - - val model = sql( - s""" - | SELECT feature, AVG(weight) AS weight - | FROM ( - | SELECT train_logregr(add_bias(features), 
label) AS (feature, weight) - | FROM TinyTrainData - | ) t - | GROUP BY feature - """.stripMargin) - - checkAnswer( - model.select($"feature"), - Seq(Row("0"), Row("1"), Row("2")) - ) - - // TODO: Why 'train_logregr' is not registered in HiveMetaStore? - // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException - // (message:Function default.train_logregr does not exist)) - // - // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr") - // hiveContext.reset() - } - - test("each_top_k") { - val testDf = Seq( - ("a", "1", 0.5, Array(0, 1, 2)), - ("b", "5", 0.1, Array(3)), - ("a", "3", 0.8, Array(2, 5)), - ("c", "6", 0.3, Array(1, 3)), - ("b", "4", 0.3, Array(2)), - ("a", "2", 0.6, Array(1)) - ).toDF("key", "value", "score", "data") - - import testDf.sqlContext.implicits._ - testDf.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("TestData") - sql(s""" - | CREATE TEMPORARY FUNCTION each_top_k - | AS '${classOf[hivemall.tools.EachTopKUDTF].getName}' - """.stripMargin) - - // Compute top-1 rows for each group - checkAnswer( - sql("SELECT each_top_k(1, key, score, key, value) FROM TestData"), - Row(1, 0.8, "a", "3") :: - Row(1, 0.3, "b", "4") :: - Row(1, 0.3, "c", "6") :: - Nil - ) - - // Compute reverse top-1 rows for each group - checkAnswer( - sql("SELECT each_top_k(-1, key, score, key, value) FROM TestData"), - Row(1, 0.5, "a", "1") :: - Row(1, 0.1, "b", "5") :: - Row(1, 0.3, "c", "6") :: - Nil - ) - } -} - -final class HiveUdfWithVectorSuite extends VectorQueryTest { - import hiveContext._ - - test("to_hivemall_features") { - mllibTrainDf.createOrReplaceTempView("mllibTrainDf") - hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) - checkAnswer( - sql( - s""" - | SELECT to_hivemall_features(features) - | FROM mllibTrainDf - """.stripMargin), - Seq( - Row(Seq("0:1.0", "2:2.0", "4:3.0")), - Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), - Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), - 
Row(Seq("1:4.0", "3:5.0", "5:6.0")) - ) - ) - } - - test("append_bias") { - mllibTrainDf.createOrReplaceTempView("mllibTrainDf") - hiveContext.udf.register("append_bias", append_bias_func) - hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) - checkAnswer( - sql( - s""" - | SELECT to_hivemall_features(append_bias(features)) - | FROM mllibTrainDF - """.stripMargin), - Seq( - Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")), - Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")), - Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")), - Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0")) - ) - ) - } - - ignore("explode_vector") { - // TODO: Spark-2.0 does not support use-defined generator function in - // `org.apache.spark.sql.UDFRegistration`. - } -}
