Close #137: [HIVEMALL-179][SPARK] Support spark-v2.3
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/bd143146 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/bd143146 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/bd143146 Branch: refs/heads/master Commit: bd1431467db278b0a1c844a7186d14dc70db5a00 Parents: fc881c3 Author: Takeshi Yamamuro <[email protected]> Authored: Thu Mar 29 08:41:54 2018 +0900 Committer: Takeshi Yamamuro <[email protected]> Committed: Thu Mar 29 08:41:54 2018 +0900 ---------------------------------------------------------------------- bin/run_travis_tests.sh | 6 +- spark/pom.xml | 6 +- spark/spark-2.3/bin/mvn-zinc | 99 + spark/spark-2.3/extra-src/README.md | 20 + .../org/apache/spark/sql/hive/HiveShim.scala | 279 +++ spark/spark-2.3/pom.xml | 187 ++ .../java/hivemall/xgboost/XGBoostOptions.scala | 59 + ....apache.spark.sql.sources.DataSourceRegister | 1 + .../src/main/resources/log4j.properties | 29 + .../hivemall/tools/RegressionDatagen.scala | 67 + .../sql/catalyst/expressions/EachTopK.scala | 133 ++ .../sql/catalyst/plans/logical/JoinTopK.scala | 68 + .../utils/InternalRowPriorityQueue.scala | 76 + .../sql/execution/UserProvidedPlanner.scala | 83 + .../datasources/csv/csvExpressions.scala | 169 ++ .../joins/ShuffledHashJoinTopKExec.scala | 405 ++++ .../spark/sql/hive/HivemallGroupedDataset.scala | 636 +++++ .../org/apache/spark/sql/hive/HivemallOps.scala | 2247 ++++++++++++++++++ .../apache/spark/sql/hive/HivemallUtils.scala | 146 ++ .../sql/hive/internal/HivemallOpsImpl.scala | 79 + .../sql/hive/source/XGBoostFileFormat.scala | 163 ++ .../spark/streaming/HivemallStreamingOps.scala | 47 + .../src/test/resources/data/files/README.md | 22 + .../src/test/resources/data/files/complex.seq | 0 .../src/test/resources/data/files/episodes.avro | 0 .../src/test/resources/data/files/json.txt | 0 .../src/test/resources/data/files/kv1.txt | 0 
.../src/test/resources/data/files/kv3.txt | 0 .../src/test/resources/log4j.properties | 24 + .../hivemall/mix/server/MixServerSuite.scala | 124 + .../hivemall/tools/RegressionDatagenSuite.scala | 33 + .../ml/feature/HivemallLabeledPointSuite.scala | 36 + .../benchmark/BenchmarkBaseAccessor.scala | 23 + .../apache/spark/sql/hive/HiveUdfSuite.scala | 161 ++ .../spark/sql/hive/HivemallOpsSuite.scala | 1393 +++++++++++ .../spark/sql/hive/ModelMixingSuite.scala | 286 +++ .../apache/spark/sql/hive/XGBoostSuite.scala | 151 ++ .../sql/hive/benchmark/MiscBenchmark.scala | 268 +++ .../hive/test/HivemallFeatureQueryTest.scala | 102 + .../apache/spark/sql/test/VectorQueryTest.scala | 89 + .../streaming/HivemallOpsWithFeatureSuite.scala | 155 ++ .../scala/org/apache/spark/test/TestUtils.scala | 65 + 42 files changed, 7932 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/bin/run_travis_tests.sh ---------------------------------------------------------------------- diff --git a/bin/run_travis_tests.sh b/bin/run_travis_tests.sh index d0ad8c2..9693520 100755 --- a/bin/run_travis_tests.sh +++ b/bin/run_travis_tests.sh @@ -35,14 +35,12 @@ cd $HIVEMALL_HOME/spark export MAVEN_OPTS="-XX:MaxPermSize=256m" -mvn -q scalastyle:check -Pspark-2.0 -pl spark-2.0 -am test -Dtest=none - -mvn -q scalastyle:check clean -Pspark-2.1 -pl spark-2.1 -am test -Dtest=none +mvn -q scalastyle:check -pl spark-2.0,spark-2.1 -am test # spark-2.2 runs on Java 8+ if [[ ! 
-z "$(java -version 2>&1 | grep 1.8)" ]]; then mvn -q scalastyle:check clean -Djava.source.version=1.8 -Djava.target.version=1.8 \ - -Pspark-2.2 -pl spark-2.2 -am test -Dtest=none + -pl spark-2.2,spark-2.3 -am test fi exit 0 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/pom.xml ---------------------------------------------------------------------- diff --git a/spark/pom.xml b/spark/pom.xml index 8279df1..27bb6db 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -35,6 +35,7 @@ <module>spark-2.0</module> <module>spark-2.1</module> <module>spark-2.2</module> + <module>spark-2.3</module> </modules> <properties> @@ -157,7 +158,10 @@ <include>org.apache.hivemall:hivemall-spark-common</include> <!-- hivemall-core --> <include>org.apache.hivemall:hivemall-core</include> - <include>io.netty:netty-all</include> + <!-- + Since `netty-all` is bundled in Spark, we don't need to include it here + <include>io.netty:netty-all</include> + --> <include>com.github.haifengl:smile-core</include> <include>com.github.haifengl:smile-math</include> <include>com.github.haifengl:smile-data</include> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/bin/mvn-zinc ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/bin/mvn-zinc b/spark/spark-2.3/bin/mvn-zinc new file mode 100755 index 0000000..759b0a5 --- /dev/null +++ b/spark/spark-2.3/bin/mvn-zinc @@ -0,0 +1,99 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyed from commit 48682f6bf663e54cb63b7e95a4520d34b6fa890b in Apache Spark + +# Determine the current working directory +_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Preserve the calling directory +_CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" + +# Installs any application tarball given a URL, the expected tarball name, +# and, optionally, a checkable binary path to determine if the binary has +# already been installed +## Arg1 - URL +## Arg2 - Tarball Name +## Arg3 - Checkable Binary +install_app() { + local remote_tarball="$1/$2" + local local_tarball="${_DIR}/$2" + local binary="${_DIR}/$3" + local curl_opts="--progress-bar -L" + local wget_opts="--progress=bar:force ${wget_opts}" + + if [ -z "$3" -o ! -f "$binary" ]; then + # check if we already have the tarball + # check if we have curl installed + # download application + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ + curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" + # if the file still doesn't exist, lets try `wget` and cross our fingers + [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ + wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" + # if both were unsuccessful, exit + [ ! -f "${local_tarball}" ] && \ + echo -n "ERROR: Cannot download $2 with cURL or wget; " && \ + echo "please install manually and try again." 
&& \ + exit 2 + cd "${_DIR}" && tar -xzf "$2" + rm -rf "$local_tarball" + fi +} + +# Install zinc under the bin/ folder +install_zinc() { + local zinc_path="zinc-0.3.9/bin/zinc" + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + install_app \ + "http://downloads.typesafe.com/zinc/0.3.9" \ + "zinc-0.3.9.tgz" \ + "${zinc_path}" + ZINC_BIN="${_DIR}/${zinc_path}" +} + +# Setup healthy defaults for the Zinc port if none were provided from +# the environment +ZINC_PORT=${ZINC_PORT:-"3030"} + +# Install Zinc for the bin/ +install_zinc + +# Reset the current working directory +cd "${_CALLING_DIR}" + +# Now that zinc is ensured to be installed, check its status and, if its +# not running or just installed, start it +if [ ! -f "${ZINC_BIN}" ]; then + exit -1 +fi +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} &>/dev/null +fi + +# Set any `mvn` options if not already present +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} + +# Last, call the `mvn` command as usual +mvn -DzincPort=${ZINC_PORT} "$@" http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/extra-src/README.md ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/extra-src/README.md b/spark/spark-2.3/extra-src/README.md new file mode 100644 index 0000000..0c622a2 --- /dev/null +++ b/spark/spark-2.3/extra-src/README.md @@ -0,0 +1,20 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +Copyed from the spark v2.3.0 release. http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 0000000..11afe1a --- /dev/null +++ b/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.google.common.base.Objects +import org.apache.avro.Schema +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable +import org.apache.hive.com.esotericsoftware.kryo.Kryo +import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. 
UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" + + /* + * This function in hive-0.13 become private, but we have to do this to work around hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null) { + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) + } + if (names != null) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. 
+ if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. + * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + * @param instance optional UDF instance which contains additional information (for macro) + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String, + private var instance: AnyRef = null) extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + override def hashCode(): Int = { + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody()) + } else { + functionClassName.hashCode() + } + } + + override def equals(other: Any): Boolean = other match { + case a: HiveFunctionWrapper if functionClassName == a.functionClassName => + // In case of udf macro, check to make sure they point to the same underlying UDF + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + a.instance.asInstanceOf[GenericUDFMacro].getBody() == + instance.asInstanceOf[GenericUDFMacro].getBody() + } else { + true + } + case _ => false + } + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, 
clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.readFully(functionInBytes) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if 
(!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/pom.xml ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/pom.xml b/spark/spark-2.3/pom.xml new file mode 100644 index 0000000..cfa6457 --- /dev/null +++ b/spark/spark-2.3/pom.xml @@ -0,0 +1,187 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>org.apache.hivemall</groupId> + <artifactId>hivemall-spark</artifactId> + <version>0.5.1-incubating-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> + + <artifactId>hivemall-spark2.3</artifactId> + <name>Hivemall on Spark 2.3</name> + <packaging>jar</packaging> + + <properties> + <main.basedir>${project.parent.parent.basedir}</main.basedir> + <spark.version>2.3.0</spark.version> + <spark.binary.version>2.3</spark.binary.version> + <hadoop.version>2.6.5</hadoop.version> + <scalatest.jvm.opts>-ea -Xms768m -Xmx2g -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=512m</scalatest.jvm.opts> + <maven.compiler.source>1.8</maven.compiler.source> + <maven.compiler.target>1.8</maven.compiler.target> + </properties> + + <dependencies> + <!-- compile scope --> + <dependency> + <groupId>org.apache.hivemall</groupId> + <artifactId>hivemall-core</artifactId> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>org.apache.hivemall</groupId> + <artifactId>hivemall-xgboost</artifactId> + 
<scope>compile</scope> + </dependency> + <dependency> + <groupId>org.apache.hivemall</groupId> + <artifactId>hivemall-spark-common</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + + <!-- provided scope --> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-streaming_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-mllib_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + + <!-- test dependencies --> + <dependency> + <groupId>org.apache.hivemall</groupId> + <artifactId>hivemall-mixserv</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + 
<artifactId>spark-streaming_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-catalyst_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <!-- hivemall-spark_xx-xx-with-dependencies.jar including minimum dependencies --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + </plugin> + <!-- disable surefire because there is no java test --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <configuration> + <skipTests>true</skipTests> + </configuration> + </plugin> + <!-- then, enable scalatest --> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <executions> + <execution> + <id>test</id> + <goals> + <goal>test</goal> + </goals> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <configuration> + <environmentVariables> + <JAVA_HOME>${env.JAVA8_HOME}</JAVA_HOME> + <PATH>${env.JAVA8_HOME}/bin:${env.PATH}</PATH> + </environmentVariables> + </configuration> + </plugin> + </plugins> + </build> +</project> 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala b/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala new file mode 100644 index 0000000..3e0f274 --- /dev/null +++ b/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.xgboost + +import scala.collection.mutable + +import org.apache.commons.cli.Options +import org.apache.spark.annotation.AlphaComponent + +/** + * :: AlphaComponent :: + * An utility class to generate a sequence of options used in XGBoost. + */ +@AlphaComponent +case class XGBoostOptions() { + private val params: mutable.Map[String, String] = mutable.Map.empty + private val options: Options = { + new XGBoostUDTF() { + def options(): Options = super.getOptions() + }.options() + } + + private def isValidKey(key: String): Boolean = { + // TODO: Is there another way to handle all the XGBoost options? 
+ options.hasOption(key) || key == "num_class" + } + + def set(key: String, value: String): XGBoostOptions = { + require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}") + params.put(key, value) + this + } + + def help(): Unit = { + import scala.collection.JavaConversions._ + options.getOptions.map { case option => println(option) } + } + + override def toString(): String = { + params.map { case (key, value) => s"-$key $value" }.mkString(" ") + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..b49e20a --- /dev/null +++ b/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.source.XGBoostFileFormat http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/resources/log4j.properties b/spark/spark-2.3/src/main/resources/log4j.properties new file mode 100644 index 0000000..ef4f606 --- /dev/null +++ b/spark/spark-2.3/src/main/resources/log4j.properties @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=INFO +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala b/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala new file mode 100644 index 0000000..a2b7f60 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.tools + +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.types._ + +object RegressionDatagen { + + /** + * Generate data for regression/classification. + * See [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]] + * for the details of arguments below. + */ + def exec(sc: SQLContext, + n_partitions: Int = 2, + min_examples: Int = 1000, + n_features: Int = 10, + n_dims: Int = 200, + seed: Int = 43, + dense: Boolean = false, + prob_one: Float = 0.6f, + sort: Boolean = false, + cl: Boolean = false): DataFrame = { + + require(n_partitions > 0, "Non-negative #n_partitions required.") + require(min_examples > 0, "Non-negative #min_examples required.") + require(n_features > 0, "Non-negative #n_features required.") + require(n_dims > 0, "Non-negative #n_dims required.") + + // Calculate #examples to generate in each partition + val n_examples = (min_examples + n_partitions - 1) / n_partitions + + val df = sc.createDataFrame( + sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions), + StructType( + StructField("data", IntegerType, true) :: + Nil) + ) + import sc.implicits._ + df.lr_datagen( + lit(s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one" + + (if (dense) " -dense" else "") + + (if (sort) " -sort" else "") + + (if (cl) " -cl" else "")) + ).select($"label".cast(DoubleType).as("label"), $"features") + } +} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala new file mode 100644 index 0000000..cac2a5d --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue +import org.apache.spark.sql.types._ + +trait TopKHelper { + + def k: Int + def scoreType: DataType + + @transient val ScoreTypes = TypeCollection( + ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType + ) + + protected case class ScoreWriter(writer: UnsafeRowWriter, ordinal: Int) { + + def write(v: Any): Unit = scoreType match { + case ByteType => writer.write(ordinal, v.asInstanceOf[Byte]) + case ShortType => writer.write(ordinal, v.asInstanceOf[Short]) + case IntegerType => writer.write(ordinal, v.asInstanceOf[Int]) + case LongType => writer.write(ordinal, v.asInstanceOf[Long]) + case FloatType => writer.write(ordinal, v.asInstanceOf[Float]) + case DoubleType => writer.write(ordinal, v.asInstanceOf[Double]) + case d: DecimalType => writer.write(ordinal, v.asInstanceOf[Decimal], d.precision, d.scale) + } + } + + protected lazy val scoreOrdering = { + val ordering = TypeUtils.getInterpretedOrdering(scoreType) + if (k > 0) ordering else ordering.reverse + } + + protected lazy val reverseScoreOrdering = scoreOrdering.reverse + + protected lazy val queue: InternalRowPriorityQueue = { + new InternalRowPriorityQueue(Math.abs(k), (x: Any, y: Any) => scoreOrdering.compare(x, y)) + } +} + +case class EachTopK( + k: Int, + scoreExpr: Expression, + groupExprs: Seq[Expression], + elementSchema: StructType, + children: Seq[Attribute]) + extends Generator with TopKHelper with CodegenFallback { + + override val scoreType: DataType = scoreExpr.dataType + + private lazy val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupExprs) + private lazy val scoreProjection: UnsafeProjection = 
UnsafeProjection.create(scoreExpr :: Nil) + + // The grouping key of the current partition + private var currentGroupingKeys: UnsafeRow = _ + + override def checkInputDataTypes(): TypeCheckResult = { + if (!ScoreTypes.acceptsType(scoreExpr.dataType)) { + TypeCheckResult.TypeCheckFailure(s"$scoreExpr must have a comparable type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + val topKRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, index) + topKRow.setTotalSize(bufferHolder.totalSize()) + new JoinedRow(topKRow, row) + } + } else { + Seq.empty + } + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val groupingKeys = groupingProjection(input) + val ret = if (currentGroupingKeys != groupingKeys) { + val topKRows = topKRowsForGroup() + currentGroupingKeys = groupingKeys.copy() + queue.clear() + topKRows + } else { + Iterator.empty + } + queue += Tuple2(scoreProjection(input).get(0, scoreType), input) + ret + } + + override def terminate(): TraversableOnce[InternalRow] = { + if (queue.size > 0) { + val topKRows = topKRowsForGroup() + queue.clear() + topKRows + } else { + Iterator.empty + } + } + + // TODO: Need to support codegen + // protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala 
---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala new file mode 100644 index 0000000..556cdc3 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +case class JoinTopK( + k: Int, + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression])( + val scoreExpr: NamedExpression, + private[sql] val rankAttr: Seq[Attribute] = AttributeReference("rank", IntegerType)() :: Nil) + extends BinaryNode with PredicateHelper { + + override def output: Seq[Attribute] = joinType match { + case Inner => rankAttr ++ Seq(scoreExpr.toAttribute) ++ left.output ++ right.output + } + + override def references: AttributeSet = { + AttributeSet((expressions ++ Seq(scoreExpr)).flatMap(_.references)) + } + + override protected def validConstraints: Set[Expression] = joinType match { + case Inner if condition.isDefined => + left.constraints.union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + } + + override protected final def otherCopyArgs: Seq[AnyRef] = { + scoreExpr :: rankAttr :: Nil + } + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + lazy val resolvedExceptNatural: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + duplicateResolved && + condition.forall(_.dataType == BooleanType) + } + + override lazy val resolved: Boolean = joinType match { + case Inner => resolvedExceptNatural + case tpe => throw new AnalysisException(s"Unsupported using join type $tpe") + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala 
b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala new file mode 100644 index 0000000..12c20fb --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.utils + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} + +import scala.collection.JavaConverters._ +import scala.collection.generic.Growable + +import org.apache.spark.sql.catalyst.InternalRow + +private[sql] class InternalRowPriorityQueue( + maxSize: Int, + compareFunc: (Any, Any) => Int + ) extends Iterable[(Any, InternalRow)] with Growable[(Any, InternalRow)] with Serializable { + + private[this] val ordering = new Ordering[(Any, InternalRow)] { + override def compare(x: (Any, InternalRow), y: (Any, InternalRow)): Int = + compareFunc(x._1, y._1) + } + + private val underlying = new JPriorityQueue[(Any, InternalRow)](maxSize, ordering) + + override def iterator: Iterator[(Any, InternalRow)] = underlying.iterator.asScala + + override def size: Int = underlying.size + + override def ++=(xs: TraversableOnce[(Any, InternalRow)]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: (Any, InternalRow)): this.type = { + if (size < maxSize) { + underlying.offer((elem._1, elem._2.copy())) + } else { + maybeReplaceLowest(elem) + } + this + } + + override def +=(elem1: (Any, InternalRow), elem2: (Any, InternalRow), elems: (Any, InternalRow)*) + : this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: (Any, InternalRow)): Boolean = { + val head = underlying.peek() + if (head != null && ordering.gt(a, head)) { + underlying.poll() + underlying.offer((a._1, a._2.copy())) + } else { + false + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala 
b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala new file mode 100644 index 0000000..09d60a6 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.{JoinTopK, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf + +private object ExtractJoinTopKKeys extends Logging with PredicateHelper { + /** (k, scoreExpr, joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + type ReturnType = + (Int, NamedExpression, Seq[Attribute], JoinType, Seq[Expression], Seq[Expression], + Option[Expression], LogicalPlan, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case join @ JoinTopK(k, left, right, joinType, condition) => + logDebug(s"Considering join on: $condition") + val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) + val joinKeys = predicates.flatMap { + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r)) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l)) + // Replace null with default value for joining key, then those rows with null in it could + // be joined together + case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => + Some((Coalesce(Seq(l, Literal.default(l.dataType))), + Coalesce(Seq(r, Literal.default(r.dataType))))) + case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => + Some((Coalesce(Seq(r, Literal.default(r.dataType))), + Coalesce(Seq(l, Literal.default(l.dataType))))) + case other => None + } + val otherPredicates = predicates.filterNot { + case EqualTo(l, r) => + canEvaluate(l, left) && canEvaluate(r, right) || + canEvaluate(l, right) && canEvaluate(r, left) + case other => false + } + + if (joinKeys.nonEmpty) { + val (leftKeys, rightKeys) = joinKeys.unzip + logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") + Some((k, 
join.scoreExpr, join.rankAttr, joinType, leftKeys, rightKeys, + otherPredicates.reduceOption(And), left, right)) + } else { + None + } + + case p => + None + } +} + +private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy { + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractJoinTopKKeys( + k, scoreExpr, rankAttr, _, leftKeys, rightKeys, condition, left, right) => + Seq(joins.ShuffledHashJoinTopKExec( + k, leftKeys, rightKeys, condition, planLater(left), planLater(right))(scoreExpr, rankAttr)) + case _ => + Nil + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala new file mode 100644 index 0000000..1f56c90 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import com.univocity.parsers.csv.CsvWriter + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Converts a csv input string to a [[StructType]] with the specified schema. + * + * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+ + */ +case class CsvToStruct( + schema: StructType, + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + + def this(schema: StructType, options: Map[String, String], child: Expression) = + this(schema, options, child, None) + + override def nullable: Boolean = true + + @transient private lazy val csvOptions = new CSVOptions(options, timeZoneId.get) + @transient private lazy val csvParser = new UnivocityParser(schema, schema, csvOptions) + + private def parse(input: String): InternalRow = csvParser.parse(input) + + override def dataType: DataType = schema + + override def nullSafeEval(csv: Any): Any = { + try parse(csv.toString) catch { case _: RuntimeException => null } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) +} + +private class CsvGenerator(schema: StructType, options: CSVOptions) { + + // A `ValueConverter` is responsible for converting a 
value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. + private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private val valueConverters: Array[ValueConverter] = + schema.map(_.dataType).map(makeConverter).toArray + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: Int) => + row.get(ordinal, dt).toString + } + + def convertRow(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = options.nullValue + } + i += 1 + } + values + } +} + +/** + * Converts a [[StructType]] to a csv output string. 
+ */ +case class StructToCsv( + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + private lazy val params = new CSVOptions(options, timeZoneId.get) + + @transient + private lazy val dataSchema = child.dataType.asInstanceOf[StructType] + + @transient + private lazy val writer = new CsvGenerator(dataSchema, params) + + override def dataType: DataType = StringType + + private def verifySchema(schema: StructType): Unit = { + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") + } + + schema.foreach(field => verifyType(field.dataType)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + override def nullSafeEval(row: Any): Any = { + val rowStr = writer.convertRow(row.asInstanceOf[InternalRow]) + .mkString(params.delimiter.toString) + UTF8String.fromString(rowStr) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) +} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bd143146/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala new file mode 100644 index 0000000..f628b78 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ + +abstract class PriorityQueueShim { + + def insert(score: Any, row: InternalRow): Unit + def get(): Iterator[InternalRow] + def clear(): Unit +} + +case class ShuffledHashJoinTopKExec( + k: Int, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan)( + scoreExpr: NamedExpression, + rankAttr: Seq[Attribute]) + extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override val scoreType: DataType = scoreExpr.dataType + override val joinType: JoinType = Inner + override val buildSide: BuildSide = BuildRight // Only support `BuildRight` + + private lazy val scoreProjection: UnsafeProjection = + UnsafeProjection.create(scoreExpr :: Nil, left.output ++ right.output) + + private lazy val boundCondition = if (condition.isDefined) { + (r: InternalRow) => newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval(r) + } else { + (r: InternalRow) => true + } + + private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute + + private lazy val _priorityQueue = new PriorityQueueShim { + + private val q: InternalRowPriorityQueue = queue + private val joinedRow = new JoinedRow + + override def insert(score: Any, row: InternalRow): Unit = { + q += Tuple2(score, row) + } + 
+ override def get(): Iterator[InternalRow] = { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + val topKRow = new UnsafeRow(2) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2) + val scoreWriter = ScoreWriter(unsafeRowWriter, 1) + outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, index) + scoreWriter.write(score) + topKRow.setTotalSize(bufferHolder.totalSize()) + joinedRow.apply(topKRow, row) + }.iterator + } + + override def clear(): Unit = q.clear() + } + + override def output: Seq[Attribute] = joinType match { + case Inner => topKAttr ++ left.output ++ right.output + } + + override protected final def otherCopyArgs: Seq[AnyRef] = { + scoreExpr :: rankAttr :: Nil + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + val context = TaskContext.get() + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + context.addTaskCompletionListener(_ => relation.close()) + relation + } + + override protected def createResultProjection(): (InternalRow) => InternalRow = joinType match { + case Inner => + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (topKAttr ++ streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) + } + + protected def InnerJoin( + streamedIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + numOutputRows: SQLMetric): Iterator[InternalRow] = { + val joinRow = 
new JoinedRow + val joinKeysProj = streamSideKeyGenerator() + val joinedIter = streamedIter.flatMap { srow => + joinRow.withLeft(srow) + val joinKeys = joinKeysProj(srow) // `joinKeys` is also a grouping key + val matches = hashedRelation.get(joinKeys) + if (matches != null) { + matches.map(joinRow.withRight).filter(boundCondition).foreach { resultRow => + _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow) + } + val iter = _priorityQueue.get() + _priorityQueue.clear() + iter + } else { + Seq.empty + } + } + val resultProj = createResultProjection() + (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) + .map(_._2)).map { r => + resultProj(r) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => + val hashed = buildHashedRelation(buildIter) + InnerJoin(streamIter, hashed, null) + } + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + // Accessor for generated code + def priorityQueue(): PriorityQueueShim = _priorityQueue + + /** + * Add a state of HashedRelation and return the variable name for it. + */ + private def prepareHashedRelation(ctx: CodegenContext): String = { + // create a name for HashedRelation + val joinExec = ctx.addReferenceObj("joinExec", this) + val relationTerm = ctx.freshName("relation") + val clsName = HashedRelation.getClass.getName.replace("$", "") + ctx.addMutableState(clsName, relationTerm, + v => s""" + | $v = ($clsName) $joinExec.buildHashedRelation(inputs[1]); + | incPeakExecutionMemory($v.estimatedSize()); + """.stripMargin) + relationTerm + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. 
+ */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, _ => "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, _ => "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. 
+ */ + private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = { + ctx.INPUT_ROW = leftRow + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = { + ctx.INPUT_ROW = row + BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx) + } + + private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = resultRow + output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(resultRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, _ => "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, _ => "") + val code = + s""" + |$isNull = $resultRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. 
   */
  private def splitVarsByCondition(
      attributes: Seq[Attribute],
      variables: Seq[ExprCode]): (String, String) = {
    if (condition.isDefined) {
      // Partition the per-column codegen snippets into those the join condition
      // references (evaluated eagerly) and the rest (deferred until a row passes).
      val condRefs = condition.get.references
      val (used, notUsed) = attributes.zip(variables).partition { case (a, ev) =>
        condRefs.contains(a)
      }
      val beforeCond = evaluateVariables(used.map(_._2))
      val afterCond = evaluateVariables(notUsed.map(_._2))
      (beforeCond, afterCond)
    } else {
      // No condition: evaluate everything up front; nothing to defer.
      (evaluateVariables(variables), "")
    }
  }

  // The emitted rows come out of a reused priority queue / JoinedRow, so downstream
  // operators must copy each result row before buffering it.
  // NOTE(review): presumably why this is unconditionally true — confirm against
  // CodegenSupport.needCopyResult semantics in Spark 2.3.
  override def needCopyResult: Boolean = true

  /**
   * Produces the whole-stage-codegen Java source for the top-K hash join: for each
   * streamed (left) row, looks up matches in the build-side HashedRelation, scores
   * each joined row, keeps the top-K in a priority queue, and emits the queued rows.
   */
  override def doProduce(ctx: CodegenContext): String = {
    // Reference back to this operator so generated code can reach priorityQueue().
    val topKJoin = ctx.addReferenceObj("topKJoin", this)

    // Prepare a priority queue for top-K computing
    val pQueue = ctx.freshName("queue")
    // NOTE(review): "$v= " is missing a space after $v in the generated init code;
    // harmless to javac but worth normalizing.
    ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue,
      v => s"$v= $topKJoin.priorityQueue();")

    // Prepare variables for a left side
    val leftIter = ctx.freshName("leftIter")
    ctx.addMutableState("scala.collection.Iterator", leftIter, v => s"$v = inputs[0];")
    val leftRow = ctx.freshName("leftRow")
    ctx.addMutableState("InternalRow", leftRow, v => "")
    val leftVars = createLeftVars(ctx, leftRow)

    // Prepare variables for a right side
    val rightRow = ctx.freshName("rightRow")
    val rightVars = createRightVar(ctx, rightRow)

    // Build a hashed relation from a right side
    val buildRelation = prepareHashedRelation(ctx)

    // Project join keys from a left side
    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow)

    // Prepare variables for joined rows
    val joinedRow = ctx.freshName("joinedRow")
    val joinedRowCls = classOf[JoinedRow].getName
    ctx.addMutableState(joinedRowCls, joinedRow, v => s"$v = new $joinedRowCls();")

    // Project score values from joined rows
    val scoreVar = createScoreVar(ctx, joinedRow)

    // Prepare variables for output rows
    val resultRow = ctx.freshName("resultRow")
    val resultVars = createResultVars(ctx, resultRow)

    val (beforeLoop, condCheck) = if (condition.isDefined) {
      // Split the code of creating variables based on whether it's used by condition or not.
      val loaded = ctx.freshName("loaded")
      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
      // Generate code for condition
      // `currentVars` must be set BEFORE genCode so the condition reuses the
      // column variables declared above instead of re-reading the rows.
      ctx.currentVars = leftVars ++ rightVars
      // NOTE(review): condition is bound against `output` while createScoreVar binds
      // against `left.output ++ right.output`; these coincide for this join's output
      // but the asymmetry is worth confirming intentional.
      val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
      // evaluate the columns those used by condition before loop
      val before = s"""
         |boolean $loaded = false;
         |$leftBefore
       """.stripMargin

      // Inside the match loop: evaluate right-side condition columns, test the
      // condition, lazily load remaining left columns once, then the rest of right.
      val checking = s"""
         |$rightBefore
         |${cond.code}
         |if (${cond.isNull} || !${cond.value}) continue;
         |if (!$loaded) {
         |  $loaded = true;
         |  $leftAfter
         |}
         |$rightAfter
       """.stripMargin
      (before, checking)
    } else {
      (evaluateVariables(leftVars), "")
    }

    val numOutput = metricTerm(ctx, "numOutputRows")

    val matches = ctx.freshName("matches")
    val topKRows = ctx.freshName("topKRows")
    val iteratorCls = classOf[Iterator[UnsafeRow]].getName

    // Generated driver loop: per left row, probe the hashed relation, score each
    // joined row into the queue, then drain the queue as the top-K output.
    s"""
       |$leftRow = null;
       |while ($leftIter.hasNext()) {
       |  $leftRow = (InternalRow) $leftIter.next();
       |
       |  // Generate join key for stream side
       |  ${keyEv.code}
       |
       |  // Find matches from HashedRelation
       |  $iteratorCls $matches = $anyNull? null : ($iteratorCls)$buildRelation.get(${keyEv.value});
       |  if ($matches == null) continue;
       |
       |  // Join top-K right rows
       |  while ($matches.hasNext()) {
       |    ${beforeLoop.trim}
       |    InternalRow $rightRow = (InternalRow) $matches.next();
       |    ${condCheck.trim}
       |    InternalRow row = $joinedRow.apply($leftRow, $rightRow);
       |    // Compute a score for the `row`
       |    ${scoreVar.code}
       |    $pQueue.insert(${scoreVar.value}, row);
       |  }
       |
       |  // Get top-K rows
       |  $iteratorCls $topKRows = $pQueue.get();
       |  $pQueue.clear();
       |
       |  // Output top-K rows
       |  while ($topKRows.hasNext()) {
       |    InternalRow $resultRow = (InternalRow) $topKRows.next();
       |    $numOutput.add(1);
       |    ${consume(ctx, resultVars)}
       |  }
       |
       |  if (shouldStop()) return;
       |}
     """.stripMargin
  }
}
