Close #34, #23: [HIVEMALL-31][SPARK] Support Spark-v2.1.0
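This adds a new spark/spark-2.1 module (plus a `spark-2.1` Maven profile) so that Hivemall UDFs/UDAFs and the XGBoost wrappers can be used from Spark 2.1.0, mirroring the existing spark-1.6/2.0 modules. With the profile in place the module is built and tested via `mvn -q scalastyle:check test -Pspark-2.1`, which Travis now runs first; the spark-1.6/2.0 modules are exercised in follow-up runs.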
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/d3afb11b Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/d3afb11b Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/d3afb11b Branch: refs/heads/master Commit: d3afb11ba3a71735e035fd4dd481372653af6357 Parents: bfd129c Author: Takeshi YAMAMURO <linguin....@gmail.com> Authored: Tue Jan 31 23:04:40 2017 +0900 Committer: Takeshi YAMAMURO <linguin....@gmail.com> Committed: Tue Jan 31 23:04:40 2017 +0900 ---------------------------------------------------------------------- .travis.yml | 5 +- NOTICE | 12 +- pom.xml | 11 + spark/spark-2.1/bin/mvn-zinc | 99 ++ spark/spark-2.1/extra-src/README | 1 + .../org/apache/spark/sql/hive/HiveShim.scala | 279 ++++ spark/spark-2.1/pom.xml | 269 ++++ .../java/hivemall/xgboost/XGBoostOptions.scala | 58 + .../XGBoostBinaryClassifierUDTFWrapper.java | 47 + .../XGBoostMulticlassClassifierUDTFWrapper.java | 47 + .../XGBoostRegressionUDTFWrapper.java | 47 + ....apache.spark.sql.sources.DataSourceRegister | 1 + .../src/main/resources/log4j.properties | 12 + .../hivemall/tools/RegressionDatagen.scala | 66 + .../sql/catalyst/expressions/EachTopK.scala | 109 ++ .../spark/sql/hive/HivemallGroupedDataset.scala | 303 ++++ .../org/apache/spark/sql/hive/HivemallOps.scala | 1368 ++++++++++++++++++ .../apache/spark/sql/hive/HivemallUtils.scala | 145 ++ .../sql/hive/internal/HivemallOpsImpl.scala | 78 + .../sql/hive/source/XGBoostFileFormat.scala | 161 +++ .../src/test/resources/data/files/README.md | 3 + .../src/test/resources/data/files/complex.seq | 0 .../src/test/resources/data/files/episodes.avro | 0 .../src/test/resources/data/files/json.txt | 0 .../src/test/resources/data/files/kv1.txt | 0 .../src/test/resources/data/files/kv3.txt | 0 .../src/test/resources/log4j.properties | 7 + .../hivemall/mix/server/MixServerSuite.scala | 123 ++ .../hivemall/tools/RegressionDatagenSuite.scala | 32 + .../scala/org/apache/spark/SparkFunSuite.scala | 50 + .../ml/feature/HivemallLabeledPointSuite.scala | 35 + .../scala/org/apache/spark/sql/QueryTest.scala | 359 +++++ .../spark/sql/catalyst/plans/PlanTest.scala | 108 ++ .../apache/spark/sql/hive/HiveUdfSuite.scala | 160 ++ .../spark/sql/hive/HivemallOpsSuite.scala | 784 ++++++++++ .../spark/sql/hive/ModelMixingSuite.scala | 285 ++++ .../apache/spark/sql/hive/XGBoostSuite.scala | 150 ++ .../sql/hive/benchmark/MiscBenchmark.scala | 224 +++ .../hive/test/HivemallFeatureQueryTest.scala | 112 ++ .../spark/sql/hive/test/TestHiveSingleton.scala | 38 + .../org/apache/spark/sql/test/SQLTestData.scala | 314 ++++ .../apache/spark/sql/test/SQLTestUtils.scala | 335 +++++ .../apache/spark/sql/test/VectorQueryTest.scala | 88 ++ .../streaming/HivemallOpsWithFeatureSuite.scala | 154 ++ .../scala/org/apache/spark/test/TestUtils.scala | 64 + 45 files changed, 6538 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/.travis.yml ---------------------------------------------------------------------- diff --git a/.travis.yml b/.travis.yml index ffa529c..a4cb7ea 100644 --- a/.travis.yml +++ b/.travis.yml @@ -34,8 +34,9 @@ notifications: email: false script: - - mvn -q scalastyle:check test -Pspark-2.0 - # test the spark-1.6 module only in this second run + - mvn -q scalastyle:check test -Pspark-2.1 + # test the spark-1.6/2.0 modules only in the following runs + - 
mvn -q scalastyle:check clean -Pspark-2.0 -pl spark/spark-2.0 -am test -Dtest=none - mvn -q scalastyle:check clean -Pspark-1.6 -pl spark/spark-1.6 -am test -Dtest=none after_success: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/NOTICE ---------------------------------------------------------------------- diff --git a/NOTICE b/NOTICE index 013774b..699b055 100644 --- a/NOTICE +++ b/NOTICE @@ -38,7 +38,7 @@ Copyright notifications which have been relocated from ASF projects o hivemall/core/src/main/java/hivemall/utils/math/MathUtils.java#erfInv() - Copyright (C) 2003-2016 The Apache Software Foundation. + Copyright (C) 2003-2016 The Apache Software Foundation. http://commons.apache.org/proper/commons-math/ Licensed under the Apache License, Version 2.0 @@ -55,8 +55,14 @@ o hivemall/spark/spark-1.6/extra-src/hive/src/main/scala/org/apache/spark/sql/hi hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala - - Copyright (C) 2014-2016 The Apache Software Foundation. + hivemall/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala + hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/QueryTest.scala + hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala + hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala + hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala + hivemall/spark/spark-2.1/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala + + Copyright (C) 2014-2017 The Apache Software Foundation. http://spark.apache.org/ Licensed under the Apache License, Version 2.0 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 715e8d4..17ee4ee 100644 --- a/pom.xml +++ b/pom.xml @@ -235,6 +235,17 @@ <profiles> <profile> + <id>spark-2.1</id> + <modules> + <module>spark/spark-2.1</module> + <module>spark/spark-common</module> + </modules> + <properties> + <spark.version>2.1.0</spark.version> + <spark.binary.version>2.1</spark.binary.version> + </properties> + </profile> + <profile> <id>spark-2.0</id> <modules> <module>spark/spark-2.0</module> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/bin/mvn-zinc ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/bin/mvn-zinc b/spark/spark-2.1/bin/mvn-zinc new file mode 100755 index 0000000..759b0a5 --- /dev/null +++ b/spark/spark-2.1/bin/mvn-zinc @@ -0,0 +1,99 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Copied from commit 48682f6bf663e54cb63b7e95a4520d34b6fa890b in Apache Spark
+
+# Determine the current working directory
+_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+# Preserve the calling directory
+_CALLING_DIR="$(pwd)"
+# Options used during compilation
+_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
+
+# Installs any application tarball given a URL, the expected tarball name,
+# and, optionally, a checkable binary path to determine if the binary has
+# already been installed
+## Arg1 - URL
+## Arg2 - Tarball Name
+## Arg3 - Checkable Binary
+install_app() {
+  local remote_tarball="$1/$2"
+  local local_tarball="${_DIR}/$2"
+  local binary="${_DIR}/$3"
+  local curl_opts="--progress-bar -L"
+  local wget_opts="--progress=bar:force ${wget_opts}"
+
+  if [ -z "$3" -o ! -f "$binary" ]; then
+    # check if we already have the tarball
+    # check if we have curl installed
+    # download application
+    [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \
+      echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \
+      curl ${curl_opts} "${remote_tarball}" > "${local_tarball}"
+    # if the file still doesn't exist, let's try `wget` and cross our fingers
+    [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \
+      echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \
+      wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}"
+    # if both were unsuccessful, exit
+    [ ! -f "${local_tarball}" ] && \
+      echo -n "ERROR: Cannot download $2 with cURL or wget; " && \
+      echo "please install manually and try again." && \
+      exit 2
+    cd "${_DIR}" && tar -xzf "$2"
+    rm -rf "$local_tarball"
+  fi
+}
+
+# Install zinc under the bin/ folder
+install_zinc() {
+  local zinc_path="zinc-0.3.9/bin/zinc"
+  [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
+  install_app \
+    "http://downloads.typesafe.com/zinc/0.3.9" \
+    "zinc-0.3.9.tgz" \
+    "${zinc_path}"
+  ZINC_BIN="${_DIR}/${zinc_path}"
+}
+
+# Set up healthy defaults for the Zinc port if none were provided from
+# the environment
+ZINC_PORT=${ZINC_PORT:-"3030"}
+
+# Install Zinc for the bin/
+install_zinc
+
+# Reset the current working directory
+cd "${_CALLING_DIR}"
+
+# Now that zinc is ensured to be installed, check its status and, if it's
+# not running or just installed, start it
-f "${ZINC_BIN}" ]; then + exit -1 +fi +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} &>/dev/null +fi + +# Set any `mvn` options if not already present +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} + +# Last, call the `mvn` command as usual +mvn -DzincPort=${ZINC_PORT} "$@" http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/extra-src/README ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/extra-src/README b/spark/spark-2.1/extra-src/README new file mode 100644 index 0000000..8b5d0cd --- /dev/null +++ b/spark/spark-2.1/extra-src/README @@ -0,0 +1 @@ +Copyed from spark master [commit 908e37bcc10132bb2aa7f80ae694a9df6e40f31a] http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 0000000..9e98948 --- /dev/null +++ b/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
new file mode 100644
index 0000000..9e98948
--- /dev/null
+++ b/spark/spark-2.1/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{InputStream, OutputStream}
+import java.rmi.server.UID
+
+import scala.collection.JavaConverters._
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+
+import com.google.common.base.Objects
+import org.apache.avro.Schema
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
+import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
+import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
+import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
+import org.apache.hadoop.io.Writable
+import org.apache.hive.com.esotericsoftware.kryo.Kryo
+import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.types.Decimal
+import org.apache.spark.util.Utils
+
+private[hive] object HiveShim {
+  // Precision and scale to pass for unlimited decimals; these are the same as the precision and
+  // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
+  val UNLIMITED_DECIMAL_PRECISION = 38
+  val UNLIMITED_DECIMAL_SCALE = 18
+  val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"
+
+  /*
+   * This function became private in hive-0.13, but we have to do this to work around a Hive bug
+   */
+  private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) {
+    val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "")
+    val result: StringBuilder = new StringBuilder(old)
+    var first: Boolean = old.isEmpty
+
+    for (col <- cols) {
+      if (first) {
+        first = false
+      } else {
+        result.append(',')
+      }
+      result.append(col)
+    }
+    conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString)
+  }
+
+  /*
+   * Cannot use ColumnProjectionUtils.appendReadColumns directly if ids is null
+   */
+  def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) {
+    if (ids != null) {
+      ColumnProjectionUtils.appendReadColumns(conf, ids.asJava)
+    }
+    if (names != null) {
+      appendReadColumnNames(conf, names)
+    }
+  }
+
+  /*
+   * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that
+   * needs to be initialized before serialization.
+   */
+  def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = {
+    w match {
+      case w: AvroGenericRecordWritable =>
+        w.setRecordReaderID(new UID())
+        // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will
+        // be thrown.
+        if (w.getFileSchema() == null) {
+          serDeProps
+            .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName())
+            .foreach { kv =>
+              w.setFileSchema(new Schema.Parser().parse(kv._2))
+            }
+        }
+      case _ =>
+    }
+    w
+  }
+
+  def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+    if (hdoi.preferWritable()) {
+      Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
+        hdoi.precision(), hdoi.scale())
+    } else {
+      Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+    }
+  }
+
+  /**
+   * This class provides the UDF creation and also the UDF instance serialization and
+   * de-serialization across process boundaries.
+   *
+   * A detailed discussion can be found at https://github.com/apache/spark/pull/3640
+   *
+   * @param functionClassName UDF class name
+   * @param instance optional UDF instance which contains additional information (for macro)
+   */
+  private[hive] case class HiveFunctionWrapper(var functionClassName: String,
+    private var instance: AnyRef = null) extends java.io.Externalizable {
+
+    // for Serialization
+    def this() = this(null)
+
+    override def hashCode(): Int = {
+      if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+        Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody())
+      } else {
+        functionClassName.hashCode()
+      }
+    }
+
+    override def equals(other: Any): Boolean = other match {
+      case a: HiveFunctionWrapper if functionClassName == a.functionClassName =>
+        // In case of a UDF macro, check to make sure they point to the same underlying UDF
+        if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+          a.instance.asInstanceOf[GenericUDFMacro].getBody() ==
+            instance.asInstanceOf[GenericUDFMacro].getBody()
+        } else {
+          true
+        }
+      case _ => false
+    }
+
+    @transient
+    def deserializeObjectByKryo[T: ClassTag](
+        kryo: Kryo,
+        in: InputStream,
+        clazz: Class[_]): T = {
+      val inp = new Input(in)
+      val t: T = kryo.readObject(inp, clazz).asInstanceOf[T]
+      inp.close()
+      t
+    }
+
+    @transient
+    def serializeObjectByKryo(
+        kryo: Kryo,
+        plan: Object,
+        out: OutputStream) {
+      val output: Output = new Output(out)
+      kryo.writeObject(output, plan)
+      output.close()
+    }
+
+    def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
+      deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz)
+        .asInstanceOf[UDFType]
+    }
+
+    def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
+      serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out)
+    }
+
+    def writeExternal(out: java.io.ObjectOutput) {
+      // output the function name
+      out.writeUTF(functionClassName)
+
+      // Write a flag if instance is null or not
+      out.writeBoolean(instance != null)
+      if (instance != null) {
+        // Some of the UDFs are serializable, but some others are not
+        // Hive Utilities can handle both cases
+        val baos = new java.io.ByteArrayOutputStream()
+        serializePlan(instance, baos)
+        val functionInBytes = baos.toByteArray
+
+        // output the function bytes
+        out.writeInt(functionInBytes.length)
+        out.write(functionInBytes, 0, functionInBytes.length)
+      }
+    }
+
+    def readExternal(in: java.io.ObjectInput) {
+      // read the function name
+      functionClassName = in.readUTF()
+
+      if (in.readBoolean()) {
+        // if the instance is not null
+        // read the function in bytes
+        val functionInBytesLength = in.readInt()
+        val functionInBytes = new Array[Byte](functionInBytesLength)
+        in.readFully(functionInBytes)
+
+        // deserialize the function object via Hive Utilities
+        instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
+          Utils.getContextOrSparkClassLoader.loadClass(functionClassName))
+      }
+    }
+
+    def createFunction[UDFType <: AnyRef](): UDFType = {
+      if (instance != null) {
+        instance.asInstanceOf[UDFType]
+      } else {
+        val func = Utils.getContextOrSparkClassLoader
+          .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+        if (!func.isInstanceOf[UDF]) {
+          // We cache the function if it's not a simple UDF,
+          // as we always have to create a new instance for simple UDFs
+          instance = func
+        }
+        func
+      }
+    }
+  }
+
+  /*
+   * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
+ * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/pom.xml ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/pom.xml b/spark/spark-2.1/pom.xml new file mode 100644 index 0000000..a0f380f --- /dev/null +++ b/spark/spark-2.1/pom.xml @@ -0,0 +1,269 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>io.github.myui</groupId> + <artifactId>hivemall</artifactId> + <version>0.4.2-rc.2</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <artifactId>hivemall-spark</artifactId> + <name>Hivemall on Spark 2.1</name> + <packaging>jar</packaging> + + <properties> + <PermGen>64m</PermGen> + <MaxPermGen>512m</MaxPermGen> + <CodeCacheSize>512m</CodeCacheSize> + <main.basedir>${project.parent.basedir}</main.basedir> + </properties> + + <dependencies> + <!-- hivemall dependencies --> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-core</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-xgboost</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-spark-common</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + + <!-- third-party dependencies --> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + <version>${scala.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-compress</artifactId> + <version>1.8</version> + <scope>compile</scope> + </dependency> + + <!-- other provided dependencies --> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-streaming_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-mllib_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + + <!-- test dependencies --> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-mixserv</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.xerial</groupId> + <artifactId>xerial-core</artifactId> + <version>3.2.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <version>2.2.4</version> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <directory>target</directory> + <outputDirectory>target/classes</outputDirectory> + <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}</finalName> + <testOutputDirectory>target/test-classes</testOutputDirectory> + <plugins> + <!-- For incremental 
compilation --> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <version>3.2.2</version> + <executions> + <execution> + <id>scala-compile-first</id> + <phase>process-resources</phase> + <goals> + <goal>compile</goal> + </goals> + </execution> + <execution> + <id>scala-test-compile-first</id> + <phase>process-test-resources</phase> + <goals> + <goal>testCompile</goal> + </goals> + </execution> + </executions> + <configuration> + <scalaVersion>${scala.version}</scalaVersion> + <recompileMode>incremental</recompileMode> + <useZincServer>true</useZincServer> + <args> + <arg>-unchecked</arg> + <arg>-deprecation</arg> + <!-- TODO: To enable this option, we need to fix many wornings --> + <!-- <arg>-feature</arg> --> + </args> + <jvmArgs> + <jvmArg>-Xms1024m</jvmArg> + <jvmArg>-Xmx1024m</jvmArg> + <jvmArg>-XX:PermSize=${PermGen}</jvmArg> + <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg> + <jvmArg>-XX:ReservedCodeCacheSize=${CodeCacheSize}</jvmArg> + </jvmArgs> + </configuration> + </plugin> + <!-- hivemall-spark_xx-xx.jar --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>2.5</version> + <configuration> + <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}</finalName> + <outputDirectory>${project.parent.build.directory}</outputDirectory> + </configuration> + </plugin> + <!-- hivemall-spark_xx-xx-with-dependencies.jar including minimum dependencies --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <version>2.3</version> + <executions> + <execution> + <id>jar-with-dependencies</id> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}-with-dependencies</finalName> + <outputDirectory>${project.parent.build.directory}</outputDirectory> + <minimizeJar>false</minimizeJar> + <createDependencyReducedPom>false</createDependencyReducedPom> + <artifactSet> + <includes> + <include>io.github.myui:hivemall-core</include> + <include>io.github.myui:hivemall-xgboost</include> + <include>io.github.myui:hivemall-spark-common</include> + <include>com.github.haifengl:smile-core</include> + <include>com.github.haifengl:smile-math</include> + <include>com.github.haifengl:smile-data</include> + <include>ml.dmlc:xgboost4j</include> + <include>com.esotericsoftware.kryo:kryo</include> + </includes> + </artifactSet> + </configuration> + </execution> + </executions> + </plugin> + <!-- disable surefire because there is no java test --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <version>2.7</version> + <configuration> + <skipTests>true</skipTests> + </configuration> + </plugin> + <!-- then, enable scalatest --> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <version>1.0</version> + <configuration> + <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory> + <junitxml>.</junitxml> + <filereports>SparkTestSuite.txt</filereports> + <argLine>-ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine> + <stderr/> + <environmentVariables> + <SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES> + <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION> + <SPARK_TESTING>1</SPARK_TESTING> + 
<JAVA_HOME>${env.JAVA_HOME}</JAVA_HOME> + </environmentVariables> + <systemProperties> + <log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration> + <derby.system.durability>test</derby.system.durability> + <java.awt.headless>true</java.awt.headless> + <java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir> + <spark.testing>1</spark.testing> + <spark.ui.enabled>false</spark.ui.enabled> + <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress> + <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak> + <!-- Needed by sql/hive tests. --> + <test.src.tables>__not_used__</test.src.tables> + </systemProperties> + <tagsToExclude>${test.exclude.tags}</tagsToExclude> + </configuration> + <executions> + <execution> + <id>test</id> + <goals> + <goal>test</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/java/hivemall/xgboost/XGBoostOptions.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/XGBoostOptions.scala b/spark/spark-2.1/src/main/java/hivemall/xgboost/XGBoostOptions.scala new file mode 100644 index 0000000..48a773b --- /dev/null +++ b/spark/spark-2.1/src/main/java/hivemall/xgboost/XGBoostOptions.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.xgboost + +import scala.collection.mutable + +import org.apache.commons.cli.Options +import org.apache.spark.annotation.AlphaComponent + +/** + * :: AlphaComponent :: + * An utility class to generate a sequence of options used in XGBoost. + */ +@AlphaComponent +case class XGBoostOptions() { + private val params: mutable.Map[String, String] = mutable.Map.empty + private val options: Options = { + new XGBoostUDTF() { + def options(): Options = super.getOptions() + }.options() + } + + private def isValidKey(key: String): Boolean = { + // TODO: Is there another way to handle all the XGBoost options? 
+ options.hasOption(key) || key == "num_class" + } + + def set(key: String, value: String): XGBoostOptions = { + require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}") + params.put(key, value) + this + } + + def help(): Unit = { + import scala.collection.JavaConversions._ + options.getOptions.map { case option => println(option) } + } + + override def toString(): String = { + params.map { case (key, value) => s"-$key $value" }.mkString(" ") + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java new file mode 100644 index 0000000..310d15e --- /dev/null +++ b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.xgboost.classification; + +import java.util.UUID; + +import org.apache.hadoop.hive.ql.exec.Description; + +/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]. */ +@Description( + name = "train_xgboost_classifier", + value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" +) +public class XGBoostBinaryClassifierUDTFWrapper extends XGBoostBinaryClassifierUDTF { + private long sequence; + private long taskId; + + public XGBoostBinaryClassifierUDTFWrapper() { + this.sequence = 0L; + this.taskId = Thread.currentThread().getId(); + } + + @Override + protected String generateUniqueModelId() { + sequence++; + /** + * TODO: Check if it is unique over all tasks in executors of Spark. 
+ */ + return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java new file mode 100644 index 0000000..81e6fe8 --- /dev/null +++ b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.xgboost.classification; + +import java.util.UUID; + +import org.apache.hadoop.hive.ql.exec.Description; + +/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper]]. */ +@Description( + name = "train_multiclass_xgboost_classifier", + value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" +) +public class XGBoostMulticlassClassifierUDTFWrapper extends XGBoostMulticlassClassifierUDTF { + private long sequence; + private long taskId; + + public XGBoostMulticlassClassifierUDTFWrapper() { + this.sequence = 0L; + this.taskId = Thread.currentThread().getId(); + } + + @Override + protected String generateUniqueModelId() { + sequence++; + /** + * TODO: Check if it is unique over all tasks in executors of Spark. + */ + return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java new file mode 100644 index 0000000..b72e045 --- /dev/null +++ b/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.xgboost.regression; + +import java.util.UUID; + +import org.apache.hadoop.hive.ql.exec.Description; + +/** An alternative implementation of [[hivemall.xgboost.regression.XGBoostRegressionUDTF]]. */ +@Description( + name = "train_xgboost_regr", + value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" +) +public class XGBoostRegressionUDTFWrapper extends XGBoostRegressionUDTF { + private long sequence; + private long taskId; + + public XGBoostRegressionUDTFWrapper() { + this.sequence = 0L; + this.taskId = Thread.currentThread().getId(); + } + + @Override + protected String generateUniqueModelId() { + sequence++; + /** + * TODO: Check if it is unique over all tasks in executors of Spark. + */ + return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..b49e20a --- /dev/null +++ b/spark/spark-2.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.source.XGBoostFileFormat http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/resources/log4j.properties b/spark/spark-2.1/src/main/resources/log4j.properties new file mode 100644 index 0000000..72bf5b6 --- /dev/null +++ b/spark/spark-2.1/src/main/resources/log4j.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=INFO +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/scala/hivemall/tools/RegressionDatagen.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/hivemall/tools/RegressionDatagen.scala b/spark/spark-2.1/src/main/scala/hivemall/tools/RegressionDatagen.scala new file mode 100644 index 0000000..72a5c83 --- /dev/null +++ 
b/spark/spark-2.1/src/main/scala/hivemall/tools/RegressionDatagen.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.types._
+
+object RegressionDatagen {
+
+  /**
+   * Generate data for regression/classification.
+   * See [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]
+   * for the details of arguments below.
+   */
+  def exec(sc: SQLContext,
+           n_partitions: Int = 2,
+           min_examples: Int = 1000,
+           n_features: Int = 10,
+           n_dims: Int = 200,
+           seed: Int = 43,
+           dense: Boolean = false,
+           prob_one: Float = 0.6f,
+           sort: Boolean = false,
+           cl: Boolean = false): DataFrame = {
+
+    require(n_partitions > 0, "Positive #n_partitions required.")
+    require(min_examples > 0, "Positive #min_examples required.")
+    require(n_features > 0, "Positive #n_features required.")
+    require(n_dims > 0, "Positive #n_dims required.")
+
+    // Calculate #examples to generate in each partition
+    val n_examples = (min_examples + n_partitions - 1) / n_partitions
+
+    val df = sc.createDataFrame(
+      sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions),
+      StructType(
+        StructField("data", IntegerType, true) ::
+        Nil)
+      )
+    import sc.implicits._
+    df.lr_datagen(
+      lit(s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one"
+        + (if (dense) " -dense" else "")
+        + (if (sort) " -sort" else "")
+        + (if (cl) " -cl" else ""))
+    ).select($"label".cast(DoubleType).as("label"), $"features")
+  }
+}
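For illustration, a minimal sketch of how RegressionDatagen.exec is meant to be driven (assuming a Hive-enabled SparkSession named `spark`; this snippet is not part of the commit):

  import hivemall.tools.RegressionDatagen

  // Generates ~1000 labeled examples across the default 2 partitions;
  // each row activates 10 of 200 possible feature dimensions, with
  // classification-style labels (cl = true).
  val train = RegressionDatagen.exec(
    spark.sqlContext,
    min_examples = 1000,
    n_features = 10,
    n_dims = 200,
    cl = true)
  train.printSchema()  // label: double, features: array of feature strings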
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
new file mode 100644
index 0000000..491363d
--- /dev/null
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.util.BoundedPriorityQueue
+
+case class EachTopK(
+    k: Int,
+    groupingExpression: Expression,
+    scoreExpression: Expression,
+    children: Seq[Attribute]) extends Generator with CodegenFallback {
+  type QueueType = (AnyRef, InternalRow)
+
+  require(k != 0, "`k` must not be 0")
+
+  private[this] lazy val scoreType = scoreExpression.dataType
+  private[this] lazy val scoreOrdering = {
+    val ordering = TypeUtils.getInterpretedOrdering(scoreType)
+      .asInstanceOf[Ordering[AnyRef]]
+    if (k > 0) {
+      ordering
+    } else {
+      ordering.reverse
+    }
+  }
+  private[this] lazy val reverseScoreOrdering = scoreOrdering.reverse
+
+  private[this] val queue: BoundedPriorityQueue[QueueType] = {
+    new BoundedPriorityQueue(Math.abs(k))(new Ordering[QueueType] {
+      override def compare(x: QueueType, y: QueueType): Int =
+        scoreOrdering.compare(x._1, y._1)
+    })
+  }
+
+  private[this] lazy val groupingProjection: UnsafeProjection =
+    UnsafeProjection.create(groupingExpression :: Nil, children)
+
+  private[this] lazy val scoreProjection: UnsafeProjection =
+    UnsafeProjection.create(scoreExpression :: Nil, children)
+
+  // The grouping key of the current partition
+  private[this] var currentGroupingKey: UnsafeRow = _
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!TypeCollection.Ordered.acceptsType(scoreExpression.dataType)) {
+      TypeCheckResult.TypeCheckFailure(
+        s"$scoreExpression must have a comparable type")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def elementSchema: StructType =
+    StructType(
+      Seq(StructField("rank", IntegerType)) ++
+        children.map(d => StructField(d.prettyName, d.dataType))
+    )
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val groupingKey = groupingProjection(input)
+    val ret = if (currentGroupingKey != groupingKey) {
+      val part = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
+        .zipWithIndex.map { case ((_, row), index) =>
+          new JoinedRow(InternalRow(1 + index), row)
+        }
+      currentGroupingKey = groupingKey.copy()
+      queue.clear()
+      part
+    } else {
+      Iterator.empty
+    }
+    queue += Tuple2(scoreProjection(input).get(0, scoreType), input.copy())
+    ret
+  }
+
+  override def terminate(): TraversableOnce[InternalRow] = {
+    if (queue.size > 0) {
+      val part = queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
+        .zipWithIndex.map { case ((_, row), index) =>
+          new JoinedRow(InternalRow(1 + index), row)
+        }
+      queue.clear()
+      part
+    } else {
+      Iterator.empty
+    }
  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3afb11b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --git
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala new file mode 100644 index 0000000..bdeff98 --- /dev/null +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.RelationalGroupedDataset +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.Pivot +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper +import org.apache.spark.sql.types._ + +/** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. 
+ * + * @groupname ensemble + * @groupname ftvec.trans + * @groupname evaluation + */ +final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) { + + /** + * @see hivemall.ensemble.bagging.VotedAvgUDAF + * @group ensemble + */ + def voted_avg(weight: String): DataFrame = { + // checkType(weight, NumericType) + val udaf = HiveUDAFFunction( + "voted_avg", + new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"), + Seq(weight).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.bagging.WeightVotedAvgUDAF + * @group ensemble + */ + def weight_voted_avg(weight: String): DataFrame = { + // checkType(weight, NumericType) + val udaf = HiveUDAFFunction( + "weight_voted_avg", + new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"), + Seq(weight).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.ArgminKLDistanceUDAF + * @group ensemble + */ + def argmin_kld(weight: String, conv: String): DataFrame = { + // checkType(weight, NumericType) + // checkType(conv, NumericType) + val udaf = HiveUDAFFunction( + "argmin_kld", + new HiveFunctionWrapper("hivemall.ensemble.ArgminKLDistanceUDAF"), + Seq(weight, conv).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.MaxValueLabelUDAF" + * @group ensemble + */ + def max_label(score: String, label: String): DataFrame = { + // checkType(score, NumericType) + checkType(label, StringType) + val udaf = HiveUDAFFunction( + "max_label", + new HiveFunctionWrapper("hivemall.ensemble.MaxValueLabelUDAF"), + Seq(score, label).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.MaxRowUDAF + * @group ensemble + */ + def maxrow(score: String, label: String): DataFrame = { + // checkType(score, NumericType) + checkType(label, StringType) + val udaf = HiveUDAFFunction( + "maxrow", + new HiveFunctionWrapper("hivemall.ensemble.MaxRowUDAF"), + Seq(score, label).map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.smile.tools.RandomForestEnsembleUDAF + * @group ensemble + */ + def rf_ensemble(predict: String): DataFrame = { + // checkType(predict, NumericType) + val udaf = HiveUDAFFunction( + "rf_ensemble", + new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"), + Seq(predict).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.tools.matrix.TransposeAndDotUDAF + */ + def transpose_and_dot(X: String, Y: String): DataFrame = { + val udaf = HiveUDAFFunction( + "transpose_and_dot", + new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"), + Seq(X, Y).map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Seq(Alias(udaf, udaf.prettyName)())) + } + + /** + * @see hivemall.ftvec.trans.OnehotEncodingUDAF + * @group ftvec.trans + */ + def onehot_encoding(cols: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "onehot_encoding", + new 
HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"), + cols.map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Seq(Alias(udaf, udaf.prettyName)())) + } + + /** + * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF + */ + def snr(X: String, Y: String): DataFrame = { + val udaf = HiveUDAFFunction( + "snr", + new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"), + Seq(X, Y).map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Seq(Alias(udaf, udaf.prettyName)())) + } + + /** + * @see hivemall.evaluation.MeanAbsoluteErrorUDAF + * @group evaluation + */ + def mae(predict: String, target: String): DataFrame = { + checkType(predict, FloatType) + checkType(target, FloatType) + val udaf = HiveUDAFFunction( + "mae", + new HiveFunctionWrapper("hivemall.evaluation.MeanAbsoluteErrorUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.evaluation.MeanSquareErrorUDAF + * @group evaluation + */ + def mse(predict: String, target: String): DataFrame = { + checkType(predict, FloatType) + checkType(target, FloatType) + val udaf = HiveUDAFFunction( + "mse", + new HiveFunctionWrapper("hivemall.evaluation.MeanSquaredErrorUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.evaluation.RootMeanSquareErrorUDAF + * @group evaluation + */ + def rmse(predict: String, target: String): DataFrame = { + checkType(predict, FloatType) + checkType(target, FloatType) + val udaf = HiveUDAFFunction( + "rmse", + new HiveFunctionWrapper("hivemall.evaluation.RootMeanSquaredErrorUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.evaluation.FMeasureUDAF + * @group evaluation + */ + def f1score(predict: String, target: String): DataFrame = { + // checkType(target, ArrayType(IntegerType)) + // checkType(predict, ArrayType(IntegerType)) + val udaf = HiveUDAFFunction( + "f1score", + new HiveFunctionWrapper("hivemall.evaluation.FMeasureUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * [[RelationalGroupedDataset]] has the three values as private fields, so, to inject Hivemall + * aggregate functions, we fetch them via Java Reflections. 
+ */ + private val df = getPrivateField[DataFrame]("org$apache$spark$sql$RelationalGroupedDataset$$df") + private val groupingExprs = getPrivateField[Seq[Expression]]("groupingExprs") + private val groupType = getPrivateField[RelationalGroupedDataset.GroupType]("groupType") + + private def getPrivateField[T](name: String): T = { + val field = groupBy.getClass.getDeclaredField(name) + field.setAccessible(true) + field.get(groupBy).asInstanceOf[T] + } + + private def toDF(aggExprs: Seq[Expression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + groupingExprs ++ aggExprs + } else { + aggExprs + } + + val aliasedAgg = aggregates.map(alias) + + groupType match { + case RelationalGroupedDataset.GroupByType => + Dataset.ofRows( + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.RollupType => + Dataset.ofRows( + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.CubeType => + Dataset.ofRows( + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + Dataset.ofRows( + df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + } + } + + private def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyName)() + } + + private def checkType(colName: String, expected: DataType) = { + val dataType = df.resolve(colName).dataType + if (dataType != expected) { + throw new AnalysisException( + s""""$colName" must be $expected, however it is $dataType""") + } + } +} + +object HivemallGroupedDataset { + + /** + * Implicitly inject the [[HivemallGroupedDataset]] into [[RelationalGroupedDataset]]. + */ + implicit def relationalGroupedDatasetToHivemallOne( + groupBy: RelationalGroupedDataset): HivemallGroupedDataset = { + new HivemallGroupedDataset(groupBy) + } +}