Close #122: [HIVEMALL-133][SPARK] Support spark-v2.2 in the hivemall-spark module
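For reference, a minimal sketch of how the new profile can be exercised locally, mirroring the bin/run_travis_tests.sh script added in this commit (assumes a Java 8 toolchain, which the spark-2.2 profile enforces via maven-enforcer):

    # Run scalastyle checks and the spark-2.2 module tests only (Java 8 required)
    mvn -q scalastyle:check clean -Djava.source.version=1.8 -Djava.target.version=1.8 \
        -Pspark-2.2 -pl spark/spark-2.2 -am test -Dtest=none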
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8bf6dd9e Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8bf6dd9e Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8bf6dd9e Branch: refs/heads/master Commit: 8bf6dd9e760b1d4bfdf9046fdf09e62f46f97d37 Parents: 688daa5 Author: Takeshi Yamamuro <[email protected]> Authored: Wed Sep 13 21:18:06 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Sep 13 21:18:06 2017 +0900 ---------------------------------------------------------------------- .travis.yml | 4 +- bin/run_travis_tests.sh | 47 + pom.xml | 44 + spark/spark-2.2/bin/mvn-zinc | 99 ++ spark/spark-2.2/extra-src/README.md | 1 + .../org/apache/spark/sql/hive/HiveShim.scala | 279 ++++ spark/spark-2.2/pom.xml | 269 +++ .../java/hivemall/xgboost/XGBoostOptions.scala | 59 + ....apache.spark.sql.sources.DataSourceRegister | 1 + .../src/main/resources/log4j.properties | 12 + .../hivemall/tools/RegressionDatagen.scala | 67 + .../sql/catalyst/expressions/EachTopK.scala | 133 ++ .../sql/catalyst/plans/logical/JoinTopK.scala | 68 + .../utils/InternalRowPriorityQueue.scala | 76 + .../sql/execution/UserProvidedPlanner.scala | 83 + .../datasources/csv/csvExpressions.scala | 169 ++ .../joins/ShuffledHashJoinTopKExec.scala | 405 +++++ .../spark/sql/hive/HivemallGroupedDataset.scala | 304 ++++ .../org/apache/spark/sql/hive/HivemallOps.scala | 1538 ++++++++++++++++++ .../apache/spark/sql/hive/HivemallUtils.scala | 146 ++ .../sql/hive/internal/HivemallOpsImpl.scala | 79 + .../sql/hive/source/XGBoostFileFormat.scala | 163 ++ .../src/test/resources/data/files/README.md | 3 + .../src/test/resources/data/files/complex.seq | 0 .../src/test/resources/data/files/episodes.avro | 0 .../src/test/resources/data/files/json.txt | 0 .../src/test/resources/data/files/kv1.txt | 0 .../src/test/resources/data/files/kv3.txt | 0 .../src/test/resources/log4j.properties | 7 + .../hivemall/mix/server/MixServerSuite.scala | 124 ++ .../hivemall/tools/RegressionDatagenSuite.scala | 33 + .../scala/org/apache/spark/SparkFunSuite.scala | 51 + .../ml/feature/HivemallLabeledPointSuite.scala | 36 + .../scala/org/apache/spark/sql/QueryTest.scala | 360 ++++ .../spark/sql/catalyst/plans/PlanTest.scala | 137 ++ .../sql/execution/benchmark/BenchmarkBase.scala | 56 + .../apache/spark/sql/hive/HiveUdfSuite.scala | 161 ++ .../spark/sql/hive/HivemallOpsSuite.scala | 961 +++++++++++ .../spark/sql/hive/ModelMixingSuite.scala | 286 ++++ .../apache/spark/sql/hive/XGBoostSuite.scala | 151 ++ .../sql/hive/benchmark/MiscBenchmark.scala | 268 +++ .../hive/test/HivemallFeatureQueryTest.scala | 113 ++ .../spark/sql/hive/test/TestHiveSingleton.scala | 39 + .../org/apache/spark/sql/test/SQLTestData.scala | 315 ++++ .../apache/spark/sql/test/SQLTestUtils.scala | 336 ++++ .../apache/spark/sql/test/VectorQueryTest.scala | 89 + .../streaming/HivemallOpsWithFeatureSuite.scala | 155 ++ .../scala/org/apache/spark/test/TestUtils.scala | 65 + 48 files changed, 7789 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/.travis.yml ---------------------------------------------------------------------- diff --git a/.travis.yml b/.travis.yml index 96f8f4e..c64c5ff 100644 --- a/.travis.yml +++ b/.travis.yml @@ -34,9 +34,7 @@ notifications: email: false script: - - mvn -q scalastyle:check test -Pspark-2.1 - # 
test the spark-2.0 modules only in the following runs - - mvn -q scalastyle:check clean -Pspark-2.0 -pl spark/spark-2.0 -am test -Dtest=none + - ./bin/run_travis_tests.sh after_success: - mvn clean cobertura:cobertura coveralls:report http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/bin/run_travis_tests.sh ---------------------------------------------------------------------- diff --git a/bin/run_travis_tests.sh b/bin/run_travis_tests.sh new file mode 100755 index 0000000..f1bffec --- /dev/null +++ b/bin/run_travis_tests.sh @@ -0,0 +1,47 @@ +#!/bin/sh +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +if [ "$HIVEMALL_HOME" = "" ]; then + if [ -e ../bin/${0##*/} ]; then + HIVEMALL_HOME=".." + elif [ -e ./bin/${0##*/} ]; then + HIVEMALL_HOME="." + else + echo "env HIVEMALL_HOME not defined" + exit 1 + fi +fi + +set -ev + +cd $HIVEMALL_HOME + +mvn -q scalastyle:check test -Pspark-2.1 + +# Tests the spark-2.2/spark-2.0 modules only in the following runs +if [[ ! -z "$(java -version 2>&1 | grep 1.8)" ]]; then + mvn -q scalastyle:check clean -Djava.source.version=1.8 -Djava.target.version=1.8 \ + -Pspark-2.2 -pl spark/spark-2.2 -am test -Dtest=none +fi + +mvn -q scalastyle:check clean -Pspark-2.0 -pl spark/spark-2.0 -am test -Dtest=none + +exit 0 + http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 3d7040c..8a543e6 100644 --- a/pom.xml +++ b/pom.xml @@ -267,6 +267,50 @@ <profiles> <profile> + <id>spark-2.2</id> + <modules> + <module>spark/spark-2.2</module> + <module>spark/spark-common</module> + </modules> + <properties> + <spark.version>2.2.0</spark.version> + <spark.binary.version>2.2</spark.binary.version> + </properties> + <build> + <plugins> + <!-- Spark-2.2 only supports Java 8 --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-enforcer-plugin</artifactId> + <version>1.4.1</version> + <executions> + <execution> + <id>enforce-versions</id> + <phase>validate</phase> + <goals> + <goal>enforce</goal> + </goals> + <configuration> + <rules> + <requireProperty> + <property>java.source.version</property> + <regex>1.8</regex> + <regexMessage>When -Pspark-2.2 set, java.source.version must be 1.8</regexMessage> + </requireProperty> + <requireProperty> + <property>java.target.version</property> + <regex>1.8</regex> + <regexMessage>When -Pspark-2.2 set, java.target.version must be 1.8</regexMessage> + </requireProperty> + </rules> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + <profile> <id>spark-2.1</id> <modules> <module>spark/spark-2.1</module> 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/bin/mvn-zinc ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/bin/mvn-zinc b/spark/spark-2.2/bin/mvn-zinc new file mode 100755 index 0000000..759b0a5 --- /dev/null +++ b/spark/spark-2.2/bin/mvn-zinc @@ -0,0 +1,99 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyed from commit 48682f6bf663e54cb63b7e95a4520d34b6fa890b in Apache Spark + +# Determine the current working directory +_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Preserve the calling directory +_CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" + +# Installs any application tarball given a URL, the expected tarball name, +# and, optionally, a checkable binary path to determine if the binary has +# already been installed +## Arg1 - URL +## Arg2 - Tarball Name +## Arg3 - Checkable Binary +install_app() { + local remote_tarball="$1/$2" + local local_tarball="${_DIR}/$2" + local binary="${_DIR}/$3" + local curl_opts="--progress-bar -L" + local wget_opts="--progress=bar:force ${wget_opts}" + + if [ -z "$3" -o ! -f "$binary" ]; then + # check if we already have the tarball + # check if we have curl installed + # download application + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ + curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" + # if the file still doesn't exist, lets try `wget` and cross our fingers + [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ + wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" + # if both were unsuccessful, exit + [ ! -f "${local_tarball}" ] && \ + echo -n "ERROR: Cannot download $2 with cURL or wget; " && \ + echo "please install manually and try again." && \ + exit 2 + cd "${_DIR}" && tar -xzf "$2" + rm -rf "$local_tarball" + fi +} + +# Install zinc under the bin/ folder +install_zinc() { + local zinc_path="zinc-0.3.9/bin/zinc" + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + install_app \ + "http://downloads.typesafe.com/zinc/0.3.9" \ + "zinc-0.3.9.tgz" \ + "${zinc_path}" + ZINC_BIN="${_DIR}/${zinc_path}" +} + +# Setup healthy defaults for the Zinc port if none were provided from +# the environment +ZINC_PORT=${ZINC_PORT:-"3030"} + +# Install Zinc for the bin/ +install_zinc + +# Reset the current working directory +cd "${_CALLING_DIR}" + +# Now that zinc is ensured to be installed, check its status and, if its +# not running or just installed, start it +if [ ! 
-f "${ZINC_BIN}" ]; then + exit -1 +fi +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} &>/dev/null +fi + +# Set any `mvn` options if not already present +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} + +# Last, call the `mvn` command as usual +mvn -DzincPort=${ZINC_PORT} "$@" http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/extra-src/README.md ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/extra-src/README.md b/spark/spark-2.2/extra-src/README.md new file mode 100644 index 0000000..1d89d0a --- /dev/null +++ b/spark/spark-2.2/extra-src/README.md @@ -0,0 +1 @@ +Copyed from the spark v2.2.0 release. http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 0000000..9e98948 --- /dev/null +++ b/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.google.common.base.Objects +import org.apache.avro.Schema +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable +import org.apache.hive.com.esotericsoftware.kryo.Kryo +import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null) { + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) + } + if (names != null) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. + if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. 
+ * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + * @param instance optional UDF instance which contains additional information (for macro) + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String, + private var instance: AnyRef = null) extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + override def hashCode(): Int = { + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody()) + } else { + functionClassName.hashCode() + } + } + + override def equals(other: Any): Boolean = other match { + case a: HiveFunctionWrapper if functionClassName == a.functionClassName => + // In case of udf macro, check to make sure they point to the same underlying UDF + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + a.instance.asInstanceOf[GenericUDFMacro].getBody() == + instance.asInstanceOf[GenericUDFMacro].getBody() + } else { + true + } + case _ => false + } + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.readFully(functionInBytes) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. 
+ * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/pom.xml ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/pom.xml b/spark/spark-2.2/pom.xml new file mode 100644 index 0000000..85a296f --- /dev/null +++ b/spark/spark-2.2/pom.xml @@ -0,0 +1,269 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>io.github.myui</groupId> + <artifactId>hivemall</artifactId> + <version>0.4.2-rc.2</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <artifactId>hivemall-spark</artifactId> + <name>Hivemall on Spark 2.2</name> + <packaging>jar</packaging> + + <properties> + <PermGen>64m</PermGen> + <MaxPermGen>512m</MaxPermGen> + <CodeCacheSize>512m</CodeCacheSize> + <main.basedir>${project.parent.basedir}</main.basedir> + </properties> + + <dependencies> + <!-- hivemall dependencies --> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-core</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-xgboost</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-spark-common</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + + <!-- third-party dependencies --> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + <version>${scala.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-compress</artifactId> + <version>1.8</version> + <scope>compile</scope> + </dependency> + + <!-- other provided dependencies --> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-streaming_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-mllib_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + + <!-- test dependencies --> + <dependency> + <groupId>io.github.myui</groupId> + <artifactId>hivemall-mixserv</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.xerial</groupId> + <artifactId>xerial-core</artifactId> + <version>3.2.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <version>2.2.4</version> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <directory>target</directory> + <outputDirectory>target/classes</outputDirectory> + <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}</finalName> + <testOutputDirectory>target/test-classes</testOutputDirectory> + <plugins> + <!-- For incremental 
compilation --> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <version>3.2.2</version> + <executions> + <execution> + <id>scala-compile-first</id> + <phase>process-resources</phase> + <goals> + <goal>compile</goal> + </goals> + </execution> + <execution> + <id>scala-test-compile-first</id> + <phase>process-test-resources</phase> + <goals> + <goal>testCompile</goal> + </goals> + </execution> + </executions> + <configuration> + <scalaVersion>${scala.version}</scalaVersion> + <recompileMode>incremental</recompileMode> + <useZincServer>true</useZincServer> + <args> + <arg>-unchecked</arg> + <arg>-deprecation</arg> + <!-- TODO: To enable this option, we need to fix many wornings --> + <!-- <arg>-feature</arg> --> + </args> + <jvmArgs> + <jvmArg>-Xms1024m</jvmArg> + <jvmArg>-Xmx1024m</jvmArg> + <jvmArg>-XX:PermSize=${PermGen}</jvmArg> + <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg> + <jvmArg>-XX:ReservedCodeCacheSize=${CodeCacheSize}</jvmArg> + </jvmArgs> + </configuration> + </plugin> + <!-- hivemall-spark_xx-xx.jar --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>2.5</version> + <configuration> + <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}</finalName> + <outputDirectory>${project.parent.build.directory}</outputDirectory> + </configuration> + </plugin> + <!-- hivemall-spark_xx-xx-with-dependencies.jar including minimum dependencies --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <version>2.3</version> + <executions> + <execution> + <id>jar-with-dependencies</id> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}-with-dependencies</finalName> + <outputDirectory>${project.parent.build.directory}</outputDirectory> + <minimizeJar>false</minimizeJar> + <createDependencyReducedPom>false</createDependencyReducedPom> + <artifactSet> + <includes> + <include>io.github.myui:hivemall-core</include> + <include>io.github.myui:hivemall-xgboost</include> + <include>io.github.myui:hivemall-spark-common</include> + <include>com.github.haifengl:smile-core</include> + <include>com.github.haifengl:smile-math</include> + <include>com.github.haifengl:smile-data</include> + <include>ml.dmlc:xgboost4j</include> + <include>com.esotericsoftware.kryo:kryo</include> + </includes> + </artifactSet> + </configuration> + </execution> + </executions> + </plugin> + <!-- disable surefire because there is no java test --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <version>2.7</version> + <configuration> + <skipTests>true</skipTests> + </configuration> + </plugin> + <!-- then, enable scalatest --> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <version>1.0</version> + <configuration> + <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory> + <junitxml>.</junitxml> + <filereports>SparkTestSuite.txt</filereports> + <argLine>-ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine> + <stderr /> + <environmentVariables> + <SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES> + <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION> + <SPARK_TESTING>1</SPARK_TESTING> + 
<JAVA_HOME>${env.JAVA_HOME}</JAVA_HOME> + </environmentVariables> + <systemProperties> + <log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration> + <derby.system.durability>test</derby.system.durability> + <java.awt.headless>true</java.awt.headless> + <java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir> + <spark.testing>1</spark.testing> + <spark.ui.enabled>false</spark.ui.enabled> + <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress> + <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak> + <!-- Needed by sql/hive tests. --> + <test.src.tables>__not_used__</test.src.tables> + </systemProperties> + <tagsToExclude>${test.exclude.tags}</tagsToExclude> + </configuration> + <executions> + <execution> + <id>test</id> + <goals> + <goal>test</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala b/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala new file mode 100644 index 0000000..3e0f274 --- /dev/null +++ b/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.xgboost + +import scala.collection.mutable + +import org.apache.commons.cli.Options +import org.apache.spark.annotation.AlphaComponent + +/** + * :: AlphaComponent :: + * An utility class to generate a sequence of options used in XGBoost. + */ +@AlphaComponent +case class XGBoostOptions() { + private val params: mutable.Map[String, String] = mutable.Map.empty + private val options: Options = { + new XGBoostUDTF() { + def options(): Options = super.getOptions() + }.options() + } + + private def isValidKey(key: String): Boolean = { + // TODO: Is there another way to handle all the XGBoost options? 
+ options.hasOption(key) || key == "num_class" + } + + def set(key: String, value: String): XGBoostOptions = { + require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}") + params.put(key, value) + this + } + + def help(): Unit = { + import scala.collection.JavaConversions._ + options.getOptions.map { case option => println(option) } + } + + override def toString(): String = { + params.map { case (key, value) => s"-$key $value" }.mkString(" ") + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..b49e20a --- /dev/null +++ b/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.source.XGBoostFileFormat http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/resources/log4j.properties b/spark/spark-2.2/src/main/resources/log4j.properties new file mode 100644 index 0000000..72bf5b6 --- /dev/null +++ b/spark/spark-2.2/src/main/resources/log4j.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=INFO +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala b/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala new file mode 100644 index 0000000..a2b7f60 --- /dev/null +++ b/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.tools + +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.types._ + +object RegressionDatagen { + + /** + * Generate data for regression/classification. + * See [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]] + * for the details of arguments below. + */ + def exec(sc: SQLContext, + n_partitions: Int = 2, + min_examples: Int = 1000, + n_features: Int = 10, + n_dims: Int = 200, + seed: Int = 43, + dense: Boolean = false, + prob_one: Float = 0.6f, + sort: Boolean = false, + cl: Boolean = false): DataFrame = { + + require(n_partitions > 0, "Non-negative #n_partitions required.") + require(min_examples > 0, "Non-negative #min_examples required.") + require(n_features > 0, "Non-negative #n_features required.") + require(n_dims > 0, "Non-negative #n_dims required.") + + // Calculate #examples to generate in each partition + val n_examples = (min_examples + n_partitions - 1) / n_partitions + + val df = sc.createDataFrame( + sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions), + StructType( + StructField("data", IntegerType, true) :: + Nil) + ) + import sc.implicits._ + df.lr_datagen( + lit(s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one" + + (if (dense) " -dense" else "") + + (if (sort) " -sort" else "") + + (if (cl) " -cl" else "")) + ).select($"label".cast(DoubleType).as("label"), $"features") + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala new file mode 100644 index 0000000..cac2a5d --- /dev/null +++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue +import org.apache.spark.sql.types._ + +trait TopKHelper { + + def k: Int + def scoreType: DataType + + @transient val ScoreTypes = TypeCollection( + ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType + ) + + protected case class ScoreWriter(writer: UnsafeRowWriter, ordinal: Int) { + + def write(v: Any): Unit = scoreType match { + case ByteType => writer.write(ordinal, v.asInstanceOf[Byte]) + case ShortType => writer.write(ordinal, v.asInstanceOf[Short]) + case IntegerType => writer.write(ordinal, v.asInstanceOf[Int]) + case LongType => writer.write(ordinal, v.asInstanceOf[Long]) + case FloatType => writer.write(ordinal, v.asInstanceOf[Float]) + case DoubleType => writer.write(ordinal, v.asInstanceOf[Double]) + case d: DecimalType => writer.write(ordinal, v.asInstanceOf[Decimal], d.precision, d.scale) + } + } + + protected lazy val scoreOrdering = { + val ordering = TypeUtils.getInterpretedOrdering(scoreType) + if (k > 0) ordering else ordering.reverse + } + + protected lazy val reverseScoreOrdering = scoreOrdering.reverse + + protected lazy val queue: InternalRowPriorityQueue = { + new InternalRowPriorityQueue(Math.abs(k), (x: Any, y: Any) => scoreOrdering.compare(x, y)) + } +} + +case class EachTopK( + k: Int, + scoreExpr: Expression, + groupExprs: Seq[Expression], + elementSchema: StructType, + children: Seq[Attribute]) + extends Generator with TopKHelper with CodegenFallback { + + override val scoreType: DataType = scoreExpr.dataType + + private lazy val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupExprs) + private lazy val scoreProjection: UnsafeProjection = UnsafeProjection.create(scoreExpr :: Nil) + + // The grouping key of the current partition + private var currentGroupingKeys: UnsafeRow = _ + + override def checkInputDataTypes(): TypeCheckResult = { + if (!ScoreTypes.acceptsType(scoreExpr.dataType)) { + TypeCheckResult.TypeCheckFailure(s"$scoreExpr must have a comparable type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + val topKRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, index) + topKRow.setTotalSize(bufferHolder.totalSize()) + new JoinedRow(topKRow, row) + } + } else { + Seq.empty + } + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val groupingKeys = groupingProjection(input) + val ret = if (currentGroupingKeys != groupingKeys) { + val topKRows = topKRowsForGroup() + currentGroupingKeys = groupingKeys.copy() + queue.clear() + topKRows + } else { + Iterator.empty + } + queue += Tuple2(scoreProjection(input).get(0, scoreType), input) + ret + } + + override def terminate(): 
TraversableOnce[InternalRow] = { + if (queue.size > 0) { + val topKRows = topKRowsForGroup() + queue.clear() + topKRows + } else { + Iterator.empty + } + } + + // TODO: Need to support codegen + // protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala new file mode 100644 index 0000000..556cdc3 --- /dev/null +++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +case class JoinTopK( + k: Int, + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression])( + val scoreExpr: NamedExpression, + private[sql] val rankAttr: Seq[Attribute] = AttributeReference("rank", IntegerType)() :: Nil) + extends BinaryNode with PredicateHelper { + + override def output: Seq[Attribute] = joinType match { + case Inner => rankAttr ++ Seq(scoreExpr.toAttribute) ++ left.output ++ right.output + } + + override def references: AttributeSet = { + AttributeSet((expressions ++ Seq(scoreExpr)).flatMap(_.references)) + } + + override protected def validConstraints: Set[Expression] = joinType match { + case Inner if condition.isDefined => + left.constraints.union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + } + + override protected final def otherCopyArgs: Seq[AnyRef] = { + scoreExpr :: rankAttr :: Nil + } + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + lazy val resolvedExceptNatural: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + duplicateResolved && + condition.forall(_.dataType == BooleanType) + } + + override lazy val resolved: Boolean = joinType match { + case Inner => resolvedExceptNatural + case tpe => throw new AnalysisException(s"Unsupported using join type $tpe") + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala 
---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala new file mode 100644 index 0000000..12c20fb --- /dev/null +++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.utils + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} + +import scala.collection.JavaConverters._ +import scala.collection.generic.Growable + +import org.apache.spark.sql.catalyst.InternalRow + +private[sql] class InternalRowPriorityQueue( + maxSize: Int, + compareFunc: (Any, Any) => Int + ) extends Iterable[(Any, InternalRow)] with Growable[(Any, InternalRow)] with Serializable { + + private[this] val ordering = new Ordering[(Any, InternalRow)] { + override def compare(x: (Any, InternalRow), y: (Any, InternalRow)): Int = + compareFunc(x._1, y._1) + } + + private val underlying = new JPriorityQueue[(Any, InternalRow)](maxSize, ordering) + + override def iterator: Iterator[(Any, InternalRow)] = underlying.iterator.asScala + + override def size: Int = underlying.size + + override def ++=(xs: TraversableOnce[(Any, InternalRow)]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: (Any, InternalRow)): this.type = { + if (size < maxSize) { + underlying.offer((elem._1, elem._2.copy())) + } else { + maybeReplaceLowest(elem) + } + this + } + + override def +=(elem1: (Any, InternalRow), elem2: (Any, InternalRow), elems: (Any, InternalRow)*) + : this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: (Any, InternalRow)): Boolean = { + val head = underlying.peek() + if (head != null && ordering.gt(a, head)) { + underlying.poll() + underlying.offer((a._1, a._2.copy())) + } else { + false + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala new file mode 100644 index 0000000..09d60a6 --- /dev/null +++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.{JoinTopK, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf + +private object ExtractJoinTopKKeys extends Logging with PredicateHelper { + /** (k, scoreExpr, joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + type ReturnType = + (Int, NamedExpression, Seq[Attribute], JoinType, Seq[Expression], Seq[Expression], + Option[Expression], LogicalPlan, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case join @ JoinTopK(k, left, right, joinType, condition) => + logDebug(s"Considering join on: $condition") + val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) + val joinKeys = predicates.flatMap { + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r)) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l)) + // Replace null with default value for joining key, then those rows with null in it could + // be joined together + case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => + Some((Coalesce(Seq(l, Literal.default(l.dataType))), + Coalesce(Seq(r, Literal.default(r.dataType))))) + case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => + Some((Coalesce(Seq(r, Literal.default(r.dataType))), + Coalesce(Seq(l, Literal.default(l.dataType))))) + case other => None + } + val otherPredicates = predicates.filterNot { + case EqualTo(l, r) => + canEvaluate(l, left) && canEvaluate(r, right) || + canEvaluate(l, right) && canEvaluate(r, left) + case other => false + } + + if (joinKeys.nonEmpty) { + val (leftKeys, rightKeys) = joinKeys.unzip + logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") + Some((k, join.scoreExpr, join.rankAttr, joinType, leftKeys, rightKeys, + otherPredicates.reduceOption(And), left, right)) + } else { + None + } + + case p => + None + } +} + +private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy { + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractJoinTopKKeys( + k, scoreExpr, rankAttr, _, leftKeys, rightKeys, condition, left, right) => + Seq(joins.ShuffledHashJoinTopKExec( + k, leftKeys, rightKeys, condition, planLater(left), planLater(right))(scoreExpr, rankAttr)) + case _ => + Nil + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala ---------------------------------------------------------------------- diff --git 
a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala new file mode 100644 index 0000000..1f56c90 --- /dev/null +++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import com.univocity.parsers.csv.CsvWriter + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Converts a csv input string to a [[StructType]] with the specified schema. + * + * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+ + */ +case class CsvToStruct( + schema: StructType, + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + + def this(schema: StructType, options: Map[String, String], child: Expression) = + this(schema, options, child, None) + + override def nullable: Boolean = true + + @transient private lazy val csvOptions = new CSVOptions(options, timeZoneId.get) + @transient private lazy val csvParser = new UnivocityParser(schema, schema, csvOptions) + + private def parse(input: String): InternalRow = csvParser.parse(input) + + override def dataType: DataType = schema + + override def nullSafeEval(csv: Any): Any = { + try parse(csv.toString) catch { case _: RuntimeException => null } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) +} + +private class CsvGenerator(schema: StructType, options: CSVOptions) { + + // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. 
+ private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private val valueConverters: Array[ValueConverter] = + schema.map(_.dataType).map(makeConverter).toArray + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: Int) => + row.get(ordinal, dt).toString + } + + def convertRow(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = options.nullValue + } + i += 1 + } + values + } +} + +/** + * Converts a [[StructType]] to a csv output string. + */ +case class StructToCsv( + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + private lazy val params = new CSVOptions(options, timeZoneId.get) + + @transient + private lazy val dataSchema = child.dataType.asInstanceOf[StructType] + + @transient + private lazy val writer = new CsvGenerator(dataSchema, params) + + override def dataType: DataType = StringType + + private def verifySchema(schema: StructType): Unit = { + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") + } + + schema.foreach(field => verifyType(field.dataType)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + override def nullSafeEval(row: Any): Any = { + val rowStr = writer.convertRow(row.asInstanceOf[InternalRow]) + .mkString(params.delimiter.toString) + UTF8String.fromString(rowStr) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala new file mode 100644 index 0000000..0067bbb --- /dev/null +++ 
b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ + +abstract class PriorityQueueShim { + + def insert(score: Any, row: InternalRow): Unit + def get(): Iterator[InternalRow] + def clear(): Unit +} + +case class ShuffledHashJoinTopKExec( + k: Int, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan)( + scoreExpr: NamedExpression, + rankAttr: Seq[Attribute]) + extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override val scoreType: DataType = scoreExpr.dataType + override val joinType: JoinType = Inner + override val buildSide: BuildSide = BuildRight // Only support `BuildRight` + + private lazy val scoreProjection: UnsafeProjection = + UnsafeProjection.create(scoreExpr :: Nil, left.output ++ right.output) + + private lazy val boundCondition = if (condition.isDefined) { + (r: InternalRow) => newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval(r) + } else { + (r: InternalRow) => true + } + + private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute + + private lazy val _priorityQueue = new PriorityQueueShim { + + private val q: InternalRowPriorityQueue = queue + private val joinedRow = new JoinedRow + + override def insert(score: Any, row: InternalRow): Unit = { + q += Tuple2(score, row) + } + + override def get(): Iterator[InternalRow] = { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + val topKRow = new UnsafeRow(2) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2) + val scoreWriter = ScoreWriter(unsafeRowWriter, 1) + outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() 
+ unsafeRowWriter.write(0, index) + scoreWriter.write(score) + topKRow.setTotalSize(bufferHolder.totalSize()) + joinedRow.apply(topKRow, row) + }.iterator + } + + override def clear(): Unit = q.clear() + } + + override def output: Seq[Attribute] = joinType match { + case Inner => topKAttr ++ left.output ++ right.output + } + + override protected final def otherCopyArgs: Seq[AnyRef] = { + scoreExpr :: rankAttr :: Nil + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + val context = TaskContext.get() + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + context.addTaskCompletionListener(_ => relation.close()) + relation + } + + override protected def createResultProjection(): (InternalRow) => InternalRow = joinType match { + case Inner => + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (topKAttr ++ streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) + } + + protected def InnerJoin( + streamedIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + numOutputRows: SQLMetric): Iterator[InternalRow] = { + val joinRow = new JoinedRow + val joinKeysProj = streamSideKeyGenerator() + val joinedIter = streamedIter.flatMap { srow => + joinRow.withLeft(srow) + val joinKeys = joinKeysProj(srow) // `joinKeys` is also a grouping key + val matches = hashedRelation.get(joinKeys) + if (matches != null) { + matches.map(joinRow.withRight).filter(boundCondition).foreach { resultRow => + _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow) + } + val iter = _priorityQueue.get() + _priorityQueue.clear() + iter + } else { + Seq.empty + } + } + val resultProj = createResultProjection() + (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) + .map(_._2)).map { r => + resultProj(r) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => + val hashed = buildHashedRelation(buildIter) + InnerJoin(streamIter, hashed, null) + } + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + // Accessor for generated code + def priorityQueue(): PriorityQueueShim = _priorityQueue + + /** + * Add a state of HashedRelation and return the variable name for it. + */ + private def prepareHashedRelation(ctx: CodegenContext): String = { + // create a name for HashedRelation + val joinExec = ctx.addReferenceObj("joinExec", this) + val relationTerm = ctx.freshName("relation") + val clsName = HashedRelation.getClass.getName.replace("$", "") + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = ($clsName) $joinExec.buildHashedRelation(inputs[1]); + | incPeakExecutionMemory($relationTerm.estimatedSize()); + """.stripMargin) + relationTerm + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. 
+ */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = { + ctx.INPUT_ROW = leftRow + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = { + ctx.INPUT_ROW = row + BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx) + } + + private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = resultRow + output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(resultRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $resultRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. 
+ */ + private def splitVarsByCondition( + attributes: Seq[Attribute], + variables: Seq[ExprCode]): (String, String) = { + if (condition.isDefined) { + val condRefs = condition.get.references + val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => + condRefs.contains(a) + } + val beforeCond = evaluateVariables(used.map(_._2)) + val afterCond = evaluateVariables(notUsed.map(_._2)) + (beforeCond, afterCond) + } else { + (evaluateVariables(variables), "") + } + } + + override def doProduce(ctx: CodegenContext): String = { + ctx.copyResult = true + + val topKJoin = ctx.addReferenceObj("topKJoin", this) + + // Prepare a priority queue for top-K computing + val pQueue = ctx.freshName("queue") + ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue, + s"$pQueue = $topKJoin.priorityQueue();") + + // Prepare variables for a left side + val leftIter = ctx.freshName("leftIter") + ctx.addMutableState("scala.collection.Iterator", leftIter, s"$leftIter = inputs[0];") + val leftRow = ctx.freshName("leftRow") + ctx.addMutableState("InternalRow", leftRow, "") + val leftVars = createLeftVars(ctx, leftRow) + + // Prepare variables for a right side + val rightRow = ctx.freshName("rightRow") + val rightVars = createRightVar(ctx, rightRow) + + // Build a hashed relation from a right side + val buildRelation = prepareHashedRelation(ctx) + + // Project join keys from a left side + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow) + + // Prepare variables for joined rows + val joinedRow = ctx.freshName("joinedRow") + val joinedRowCls = classOf[JoinedRow].getName + ctx.addMutableState(joinedRowCls, joinedRow, s"$joinedRow = new $joinedRowCls();") + + // Project score values from joined rows + val scoreVar = createScoreVar(ctx, joinedRow) + + // Prepare variables for output rows + val resultRow = ctx.freshName("resultRow") + val resultVars = createResultVars(ctx, resultRow) + + val (beforeLoop, condCheck) = if (condition.isDefined) { + // Split the code of creating variables based on whether it's used by condition or not. + val loaded = ctx.freshName("loaded") + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + // Generate code for condition + ctx.currentVars = leftVars ++ rightVars + val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) + // evaluate the columns those used by condition before loop + val before = s""" + |boolean $loaded = false; + |$leftBefore + """.stripMargin + + val checking = s""" + |$rightBefore + |${cond.code} + |if (${cond.isNull} || !${cond.value}) continue; + |if (!$loaded) { + | $loaded = true; + | $leftAfter + |} + |$rightAfter + """.stripMargin + (before, checking) + } else { + (evaluateVariables(leftVars), "") + } + + val numOutput = metricTerm(ctx, "numOutputRows") + + val matches = ctx.freshName("matches") + val topKRows = ctx.freshName("topKRows") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + + s""" + |$leftRow = null; + |while ($leftIter.hasNext()) { + | $leftRow = (InternalRow) $leftIter.next(); + | + | // Generate join key for stream side + | ${keyEv.code} + | + | // Find matches from HashedRelation + | $iteratorCls $matches = $anyNull? 
null : ($iteratorCls)$buildRelation.get(${keyEv.value}); + | if ($matches == null) continue; + | + | // Join top-K right rows + | while ($matches.hasNext()) { + | ${beforeLoop.trim} + | InternalRow $rightRow = (InternalRow) $matches.next(); + | ${condCheck.trim} + | InternalRow row = $joinedRow.apply($leftRow, $rightRow); + | // Compute a score for the `row` + | ${scoreVar.code} + | $pQueue.insert(${scoreVar.value}, row); + | } + | + | // Get top-K rows + | $iteratorCls $topKRows = $pQueue.get(); + | $pQueue.clear(); + | + | // Output top-K rows + | while ($topKRows.hasNext()) { + | InternalRow $resultRow = (InternalRow) $topKRows.next(); + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | + | if (shouldStop()) return; + |} + """.stripMargin + } +}
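Note (illustrative, not part of the patch): the operator above joins each streamed (left) row against the hashed build (right) side, scores every matched pair with `scoreExpr`, and emits only the k highest-scored joined rows per streamed row, with equal scores sharing a rank. The collection-based sketch below restates that per-row semantics in plain Scala, without Catalyst, `HashedRelation`, or codegen; all names in it (`TopKJoinSketch`, `topKJoin`) are hypothetical.

// Hedged sketch of the per-row top-K join semantics; not the committed implementation.
object TopKJoinSketch {

  def topKJoin[K, L, R](
      k: Int,
      left: Seq[(K, L)],
      right: Seq[(K, R)],
      score: (L, R) => Double): Seq[(Int, Double, L, R)] = {
    // Hash the build (right) side by join key, mirroring buildHashedRelation with BuildRight.
    val buildSide: Map[K, Seq[R]] = right.groupBy(_._1).mapValues(_.map(_._2)).toMap
    left.flatMap { case (key, l) =>
      val matches = buildSide.getOrElse(key, Seq.empty[R])
      // Score each match and keep only the k best for this streamed row.
      val topK = matches.map(r => (score(l, r), r)).sortBy { case (s, _) => -s }.take(k)
      // Ties share a rank: the rank advances only when the score changes, which is the
      // intent of the scanLeft-based ranking in PriorityQueueShim.get().
      var rank = 0
      var prev = Double.NaN
      topK.map { case (s, r) =>
        if (s != prev) { rank += 1; prev = s }
        (rank, s, l, r)
      }
    }
  }
}

// Example (k = 2): joining one left row against three right matches keeps the two
// highest scores and ranks them 1 and 2:
//   topKJoin(2, Seq(("a", "u1")), Seq(("a", 10), ("a", 30), ("a", 20)),
//            (_: String, r: Int) => r.toDouble)
//   ==> Seq((1, 30.0, "u1", 30), (2, 20.0, "u1", 20))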

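Note (illustrative, not part of the patch): the csvExpressions.scala file added earlier in this diff defines `CsvToStruct` (CSV string to struct) and `StructToCsv` (struct to CSV string) as Catalyst expressions. The user-facing wrappers are not shown in this excerpt, so the sketch below, which wraps the expressions in a `Column` directly, is an assumption about how they could be exercised rather than the supported API; it also assumes the analyzer fills in the session time zone for these `TimeZoneAwareExpression`s during resolution.

// Hedged usage sketch; object name CsvExprSketch is hypothetical.
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CsvExprSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("csv-expr-sketch").getOrCreate()
    import spark.implicits._

    val schema = StructType(StructField("name", StringType) :: StructField("age", IntegerType) :: Nil)

    // CSV string -> struct.
    val parsed = Seq("alice,20", "bob,25").toDF("line")
      .select(new Column(CsvToStruct(schema, Map.empty[String, String], $"line".expr)).as("s"))

    // struct -> CSV string; expected to print the original lines back (e.g. "alice,20").
    val formatted = parsed.select(new Column(StructToCsv(Map.empty[String, String], $"s".expr)).as("csv"))
    formatted.show()
  }
}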