This is an automated email from the ASF dual-hosted git repository.
kejia pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 260e5857cc [GLUTEN-7548][VL] Optimize BHJ in velox backend (#8931)
260e5857cc is described below
commit 260e5857ccb7b68d4ce9cb09e41c5db051be0a33
Author: JiaKe <[email protected]>
AuthorDate: Mon Mar 9 11:20:41 2026 +0000
[GLUTEN-7548][VL] Optimize BHJ in velox backend (#8931)
* BHJ optimization to ensure the hash table is built once per executor
---
backends-velox/pom.xml | 4 +
.../apache/gluten/vectorized/HashJoinBuilder.java | 53 +++++
.../gluten/backendsapi/velox/VeloxBackend.scala | 11 +-
.../backendsapi/velox/VeloxListenerApi.scala | 12 +
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 5 +
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 150 ++++++++++++-
.../backendsapi/velox/VeloxTransformerApi.scala | 5 +
.../org/apache/gluten/config/VeloxConfig.scala | 23 ++
.../gluten/execution/HashJoinExecTransformer.scala | 55 ++++-
.../execution/VeloxBroadcastBuildSideCache.scala | 112 ++++++++++
.../execution/VeloxBroadcastBuildSideRDD.scala | 29 ++-
...loxBroadcastNestedLoopJoinExecTransformer.scala | 2 +-
.../listener/VeloxGlutenSQLAppStatusListener.scala | 82 +++++++
.../apache/spark/rpc/GlutenDriverEndpoint.scala | 136 ++++++++++++
.../apache/spark/rpc/GlutenExecutorEndpoint.scala | 79 +++++++
.../org/apache/spark/rpc/GlutenRpcConstants.scala} | 21 +-
.../org/apache/spark/rpc/GlutenRpcMessages.scala | 53 +++++
.../spark/sql/execution/BroadcastUtils.scala | 4 +-
.../sql/execution/ColumnarBuildSideRelation.scala | 106 ++++++++-
.../sql/execution/joins/SparkHashJoinUtils.scala | 51 +++++
.../unsafe/UnsafeColumnarBuildSideRelation.scala | 116 +++++++++-
.../org/apache/gluten/test/MockVeloxBackend.java | 2 +-
.../apache/gluten/test/VeloxBackendTestBase.java | 17 ++
.../execution/DynamicOffHeapSizingSuite.scala | 4 +
.../gluten/execution/VeloxHashJoinSuite.scala | 159 +++++++-------
.../VeloxBroadcastBuildOnceBenchmark.scala | 85 +++++++
.../UnsafeColumnarBuildSideRelationTest.scala | 26 +++
cpp/velox/CMakeLists.txt | 2 +
cpp/velox/compute/VeloxBackend.cc | 1 +
cpp/velox/compute/VeloxBackend.h | 5 +
cpp/velox/jni/JniHashTable.cc | 145 ++++++++++++
cpp/velox/jni/JniHashTable.h | 54 +++++
cpp/velox/jni/VeloxJniWrapper.cc | 203 +++++++++++++++--
cpp/velox/operators/hashjoin/HashTableBuilder.cc | 244 +++++++++++++++++++++
cpp/velox/operators/hashjoin/HashTableBuilder.h | 123 +++++++++++
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 40 ++++
docs/Configuration.md | 1 +
docs/velox-configuration.md | 1 +
.../gluten/extension/GlutenJoinKeysCapture.scala | 62 ++++++
.../org/apache/gluten/extension/JoinKeysTag.scala | 25 +--
.../apache/gluten/substrait/rel/JoinRelNode.java | 5 +
.../apache/gluten/substrait/rel/RelBuilder.java | 7 +-
.../substrait/proto/substrait/algebra.proto | 2 +
.../gluten/backendsapi/BackendSettingsApi.scala | 2 +
.../org/apache/gluten/config/GlutenConfig.scala | 10 +
.../gluten/execution/JoinExecTransformer.scala | 10 +-
.../org/apache/gluten/execution/JoinUtils.scala | 2 +
.../gluten/extension/columnar/FallbackRules.scala | 21 +-
.../execution/ColumnarBroadcastExchangeExec.scala | 4 +-
49 files changed, 2203 insertions(+), 168 deletions(-)
diff --git a/backends-velox/pom.xml b/backends-velox/pom.xml
index cd7d795861..ddf4916633 100644
--- a/backends-velox/pom.xml
+++ b/backends-velox/pom.xml
@@ -86,6 +86,10 @@
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
+ <dependency>
+ <groupId>com.github.ben-manes.caffeine</groupId>
+ <artifactId>caffeine</artifactId>
+ </dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git
a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
new file mode 100644
index 0000000000..e54909054c
--- /dev/null
+++
b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized;
+
+import org.apache.gluten.runtime.Runtime;
+import org.apache.gluten.runtime.RuntimeAware;
+
+public class HashJoinBuilder implements RuntimeAware {
+ private final Runtime runtime;
+
+ private HashJoinBuilder(Runtime runtime) {
+ this.runtime = runtime;
+ }
+
+ public static HashJoinBuilder create(Runtime runtime) {
+ return new HashJoinBuilder(runtime);
+ }
+
+ @Override
+ public long rtHandle() {
+ return runtime.getHandle();
+ }
+
+ public static native void clearHashTable(long hashTableData);
+
+ public static native long cloneHashTable(long hashTableData);
+
+ public static native long nativeBuild(
+ String buildHashTableId,
+ long[] batchHandlers,
+ String joinKeys,
+ int joinType,
+ boolean hasMixedFiltCondition,
+ boolean isExistenceJoin,
+ byte[] namedStruct,
+ boolean isNullAwareAntiJoin,
+ long bloomFilterPushdownSize,
+ int broadcastHashTableBuildThreads);
+}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 24d08a5792..2ab3af7cea 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -97,6 +97,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
val GLUTEN_VELOX_INTERNAL_UDF_LIB_PATHS = VeloxBackend.CONF_PREFIX +
".internal.udfLibraryPaths"
val GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION = VeloxBackend.CONF_PREFIX +
".udfAllowTypeConversion"
+ val GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME: String =
+ VeloxBackend.CONF_PREFIX + ("broadcast.cache.expired.time")
+ // unit: SECONDS, default 1 day
+ val GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME_DEFAULT: Int = 86400
+
override def primaryBatchType: Convention.BatchType = VeloxBatchType
override def validateScanExec(
@@ -495,13 +500,17 @@ object VeloxBackendSettings extends BackendSettingsApi {
allSupported
}
+ override def enableJoinKeysRewrite(): Boolean = false
+
override def supportColumnarShuffleExec(): Boolean = {
val conf = GlutenConfig.get
conf.enableColumnarShuffle &&
(conf.isUseGlutenShuffleManager ||
conf.shuffleManagerSupportsColumnarShuffle)
}
- override def enableJoinKeysRewrite(): Boolean = false
+ override def enableHashTableBuildOncePerExecutor(): Boolean = {
+ VeloxConfig.get.enableBroadcastBuildOncePerExecutor
+ }
override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
t =>
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index 585f6d736d..8722ae8616 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.backendsapi.ListenerApi
import
org.apache.gluten.backendsapi.arrow.ArrowBatchTypes.{ArrowJavaBatchType,
ArrowNativeBatchType}
import org.apache.gluten.config.{GlutenConfig, GlutenCoreConfig, VeloxConfig}
import org.apache.gluten.config.VeloxConfig._
+import org.apache.gluten.execution.VeloxBroadcastBuildSideCache
import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.extension.columnar.transition.Convention
@@ -35,8 +36,10 @@ import org.apache.gluten.utils._
import org.apache.spark.{HdfsConfGenerator, ShuffleDependency, SparkConf,
SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
+import org.apache.spark.listener.VeloxGlutenSQLAppStatusListener
import org.apache.spark.memory.GlobalOffHeapMemory
import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint}
import org.apache.spark.shuffle.{ColumnarShuffleDependency, LookupKey,
ShuffleManagerRegistry}
import org.apache.spark.shuffle.sort.ColumnarShuffleManager
import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
@@ -56,6 +59,8 @@ class VeloxListenerApi extends ListenerApi with Logging {
import VeloxListenerApi._
override def onDriverStart(sc: SparkContext, pc: PluginContext): Unit = {
+ GlutenDriverEndpoint.glutenDriverEndpointRef = (new
GlutenDriverEndpoint).self
+ VeloxGlutenSQLAppStatusListener.registerListener(sc)
val conf = pc.conf()
// When the Velox cache is enabled, the Velox file handle cache should
also be enabled.
@@ -138,6 +143,8 @@ class VeloxListenerApi extends ListenerApi with Logging {
override def onDriverShutdown(): Unit = shutdown()
override def onExecutorStart(pc: PluginContext): Unit = {
+ GlutenExecutorEndpoint.executorEndpoint = new
GlutenExecutorEndpoint(pc.executorID, pc.conf)
+
val conf = pc.conf()
// Static initializers for executor.
@@ -250,6 +257,11 @@ class VeloxListenerApi extends ListenerApi with Logging {
private def shutdown(): Unit = {
// TODO shutdown implementation in velox to release resources
+ VeloxBroadcastBuildSideCache.cleanAll()
+ val executorEndpoint = GlutenExecutorEndpoint.executorEndpoint
+ if (executorEndpoint != null) {
+ executorEndpoint.stop()
+ }
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 773868b0c4..1d80536290 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -66,6 +66,11 @@ object VeloxRuleApi {
injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
injector.injectOptimizerRule(RewriteCastFromArray.apply)
injector.injectOptimizerRule(RewriteUnboundedWindow.apply)
+
+ if (!BackendsApiManager.getSettings.enableJoinKeysRewrite()) {
+ injector.injectPlannerStrategy(_ =>
org.apache.gluten.extension.GlutenJoinKeysCapture())
+ }
+
if (BackendsApiManager.getSettings.supportAppendDataExec()) {
injector.injectPlannerStrategy(SparkShimLoader.getSparkShims.getRewriteCreateTableAsSelect(_))
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 69419deb1a..338bef20df 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.exception.{GlutenExceptionUtil,
GlutenNotSupportExcepti
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.aggregate.{HLLAdapter,
VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
+import org.apache.gluten.extension.JoinKeysTag
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -29,6 +30,7 @@ import org.apache.gluten.vectorized.{ColumnarBatchSerializer,
ColumnarBatchSeria
import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException}
import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec,
PullOutArrowEvalPythonPreProjectHelper}
+import org.apache.spark.internal.Logging
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
@@ -43,9 +45,10 @@ import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BuildSideRelation,
HashedRelationBroadcastMode}
+import org.apache.spark.sql.execution.joins.{BuildSideRelation,
HashedRelationBroadcastMode, SparkHashJoinUtils}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
@@ -64,8 +67,9 @@ import javax.ws.rs.core.UriBuilder
import java.util.Locale
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
-class VeloxSparkPlanExecApi extends SparkPlanExecApi {
+class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
/** Transform GetArrayItem to Substrait. */
override def genGetArrayItemTransformer(
@@ -678,9 +682,136 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
+
+ val buildKeys = mode match {
+ case mode1: HashedRelationBroadcastMode =>
+ mode1.key
+ case _ =>
+ // IdentityBroadcastMode
+ Seq.empty
+ }
+ var offload = true
+ val (newChild, newOutput, newBuildKeys) =
+ if (VeloxConfig.get.enableBroadcastBuildOncePerExecutor) {
+
+ // Try to lookup from TreeNodeTag using child's logical plan
+ // Need to recursively find logicalLink in case of AQE or other
wrappers
+ @scala.annotation.tailrec
+ def findLogicalLink(
+ plan: SparkPlan):
Option[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan] = {
+ plan.logicalLink match {
+ case some @ Some(_) => some
+ case None =>
+ plan.children match {
+ case Seq(child) => findLogicalLink(child)
+ case _ => None
+ }
+ }
+ }
+
+ val newBuildKeys = findLogicalLink(child)
+ .flatMap(_.getTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS))
+ .getOrElse {
+ if (SparkHashJoinUtils.canRewriteAsLongType(buildKeys) &&
buildKeys.nonEmpty) {
+ SparkHashJoinUtils.getOriginalKeysFromPacked(buildKeys.head)
+ } else {
+ buildKeys
+ }
+ }
+
+ val noNeedPreOp = newBuildKeys.forall {
+ case _: AttributeReference | _: BoundReference => true
+ case _ => false
+ }
+
+ if (noNeedPreOp) {
+ (child, child.output, Seq.empty[Expression])
+ } else {
+ // pre projection in case of expression join keys
+ val appendedProjections = new ArrayBuffer[NamedExpression]()
+ val preProjectionBuildKeys = newBuildKeys.zipWithIndex.map {
+ case (e, idx) =>
+ e match {
+ case b: BoundReference => child.output(b.ordinal)
+ case a: AttributeReference => a
+ case o: Expression =>
+ val newExpr = Alias(o, "col_" + idx)()
+ appendedProjections += newExpr
+ newExpr
+ }
+ }
+
+ def wrapChild(child: SparkPlan): SparkPlan = {
+ val childWithAdapter =
+
ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
+ val projectExecTransformer =
+ ProjectExecTransformer(child.output ++ appendedProjections,
childWithAdapter)
+ val validationResult = projectExecTransformer.doValidate()
+ if (validationResult.ok()) {
+ WholeStageTransformer(
+ ProjectExecTransformer(child.output ++ appendedProjections,
childWithAdapter))(
+ ColumnarCollapseTransformStages
+ .getTransformStageCounter(childWithAdapter)
+ .incrementAndGet()
+ )
+ } else {
+ offload = false
+ child
+ }
+ }
+
+ val newChild = child match {
+ case wt: WholeStageTransformer =>
+ val projectTransformer =
+ ProjectExecTransformer(child.output ++ appendedProjections,
wt.child)
+ if (projectTransformer.doValidate().ok()) {
+ wt.withNewChildren(
+ Seq(ProjectExecTransformer(child.output ++
appendedProjections, wt.child)))
+
+ } else {
+ offload = false
+ child
+ }
+ case w: WholeStageCodegenExec =>
+ w.withNewChildren(Seq(ProjectExec(child.output ++
appendedProjections, w.child)))
+ case r: AQEShuffleReadExec if r.supportsColumnar =>
+ // when aqe is open
+ // TODO: remove this after pushdowning preprojection
+ wrapChild(r)
+ case r2c: RowToVeloxColumnarExec =>
+ wrapChild(r2c)
+ case union: ColumnarUnionExec =>
+ wrapChild(union)
+ case ordered: TakeOrderedAndProjectExecTransformer =>
+ wrapChild(ordered)
+ case a2v: ArrowColumnarToVeloxColumnarExec =>
+ wrapChild(a2v)
+ case other =>
+ offload = false
+ logWarning(
+ "Not supported operator " + other.nodeName +
+ " for BroadcastRelation and fallback to shuffle hash join")
+ child
+ }
+
+ if (offload) {
+ (
+ newChild,
+ (child.output ++ appendedProjections).map(_.toAttribute),
+ preProjectionBuildKeys)
+ } else {
+ (child, child.output, Seq.empty[Expression])
+ }
+ }
+ } else {
+ offload = false
+ (child, child.output, buildKeys)
+ }
+
val useOffheapBroadcastBuildRelation =
VeloxConfig.get.enableBroadcastBuildRelationInOffheap
- val serialized: Seq[ColumnarBatchSerializeResult] = child
+
+ val serialized: Seq[ColumnarBatchSerializeResult] = newChild
.executeColumnar()
.mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
.filter(_.numRows != 0)
@@ -694,18 +825,23 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.numRows).sum
dataSize += rawSize
+
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
UnsafeColumnarBuildSideRelation(
- child.output,
+ newOutput,
serialized.flatMap(_.offHeapData().asScala),
- mode)
+ mode,
+ newBuildKeys,
+ offload)
}
} else {
ColumnarBuildSideRelation(
- child.output,
+ newOutput,
serialized.flatMap(_.onHeapData().asScala).toArray,
- mode)
+ mode,
+ newBuildKeys,
+ offload)
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
index a40e9ca6e4..3a1d53154f 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
@@ -30,6 +30,7 @@ import org.apache.gluten.vectorized.PlanEvaluatorJniWrapper
import org.apache.spark.Partition
import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation,
PartitionDirectory}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -120,6 +121,10 @@ class VeloxTransformerApi extends TransformerApi with
Logging {
override def packPBMessage(message: Message): Any = Any.pack(message, "")
+ override def invalidateSQLExecutionResource(executionId: String): Unit = {
+ GlutenDriverEndpoint.invalidateResourceRelation(executionId)
+ }
+
override def genWriteParameters(write: WriteFilesExecTransformer): Any = {
write.fileFormat match {
case _ @(_: ParquetFileFormat | _: HiveFileFormat) =>
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
b/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
index ee0866391c..071d75d6cf 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
@@ -61,6 +61,12 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {
def enableBroadcastBuildRelationInOffheap: Boolean =
getConf(VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP)
+ def enableBroadcastBuildOncePerExecutor: Boolean =
+ getConf(VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR)
+
+ def veloxBroadcastHashTableBuildThreads: Int =
+ getConf(COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS)
+
def veloxOrcScanEnabled: Boolean =
getConf(VELOX_ORC_SCAN_ENABLED)
@@ -195,6 +201,13 @@ object VeloxConfig extends ConfigRegistry {
.intConf
.createOptional
+ val COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS =
+
buildStaticConf("spark.gluten.sql.columnar.backend.velox.broadcastHashTableBuildThreads")
+ .doc("The number of threads used to build the broadcast hash table. " +
+ "If not set or set to 0, it will use the default number of threads
(available processors).")
+ .intConf
+ .createWithDefault(1)
+
val COLUMNAR_VELOX_ASYNC_TIMEOUT =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.asyncTimeoutOnTaskStopping")
.doc(
@@ -586,6 +599,16 @@ object VeloxConfig extends ConfigRegistry {
.intConf
.createWithDefault(0)
+ val VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR =
+ buildConf("spark.gluten.velox.buildHashTableOncePerExecutor.enabled")
+ .internal()
+ .doc(
+ "When enabled, the hash table is " +
+ "constructed once per executor. If not enabled, " +
+ "the hash table is rebuilt for each task.")
+ .booleanConf
+ .createWithDefault(true)
+
val QUERY_TRACE_ENABLED =
buildConf("spark.gluten.sql.columnar.backend.velox.queryTraceEnabled")
.doc("Enable query tracing flag.")
.booleanConf
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
index e3c93848dc..d79a3cae04 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
@@ -16,11 +16,14 @@
*/
package org.apache.gluten.execution
+import org.apache.gluten.config.VeloxConfig
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.optimizer.{BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -99,6 +102,9 @@ case class BroadcastHashJoinExecTransformer(
right,
isNullAwareAntiJoin) {
+ // Unique ID for built table
+ lazy val buildBroadcastTableId: String = buildPlan.id.toString
+
override protected lazy val substraitJoinType: JoinRel.JoinType = joinType
match {
case _: InnerLike =>
JoinRel.JoinType.JOIN_TYPE_INNER
@@ -125,9 +131,52 @@ case class BroadcastHashJoinExecTransformer(
override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
+ val executionId =
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ if (executionId != null) {
+ logWarning(
+ s"Trace broadcast table data $buildBroadcastTableId" + " " +
+ "and the execution id is " + executionId)
+ GlutenDriverEndpoint.collectResources(executionId, buildBroadcastTableId)
+ } else {
+ logWarning(
+ s"Can not trace broadcast table data $buildBroadcastTableId" +
+ s" because execution id is null." +
+ s" Will clean up until expire time.")
+ }
+
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
- val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast)
+ val bloomFilterPushdownSize = if
(VeloxConfig.get.hashProbeDynamicFilterPushdownEnabled) {
+ VeloxConfig.get.hashProbeBloomFilterPushdownMaxSize
+ } else {
+ -1
+ }
+ val context =
+ BroadcastHashJoinContext(
+ buildKeyExprs,
+ substraitJoinType,
+ buildSide == BuildRight,
+ condition.isDefined,
+ joinType.isInstanceOf[ExistenceJoin],
+ buildPlan.output,
+ buildBroadcastTableId,
+ isNullAwareAntiJoin,
+ bloomFilterPushdownSize,
+ VeloxConfig.get.veloxBroadcastHashTableBuildThreads
+ )
+ val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast,
context)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}
}
+
+case class BroadcastHashJoinContext(
+ buildSideJoinKeys: Seq[Expression],
+ substraitJoinType: JoinRel.JoinType,
+ buildRight: Boolean,
+ hasMixedFiltCondition: Boolean,
+ isExistenceJoin: Boolean,
+ buildSideStructure: Seq[Attribute],
+ buildHashTableId: String,
+ isNullAwareAntiJoin: Boolean = false,
+ bloomFilterPushdownSize: Long,
+ broadcastHashTableBuildThreads: Int)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideCache.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideCache.scala
new file mode 100644
index 0000000000..2705f3b34c
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideCache.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
+import org.apache.gluten.vectorized.HashJoinBuilder
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.ColumnarBuildSideRelation
+import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
+
+import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause,
RemovalListener}
+
+import java.util.concurrent.TimeUnit
+
+case class BroadcastHashTable(pointer: Long, relation: BuildSideRelation)
+
+/**
+ * `VeloxBroadcastBuildSideCache` is used to ensure the BHJ hash table is
built only once.
+ *
+ * The complicated part is due to reuse exchange, where multiple BHJ IDs
correspond to a
+ * `BuildSideRelation`.
+ */
+object VeloxBroadcastBuildSideCache
+ extends Logging
+ with RemovalListener[String, BroadcastHashTable] {
+
+ private lazy val expiredTime = SparkEnv.get.conf.getLong(
+ VeloxBackendSettings.GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME,
+ VeloxBackendSettings.GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME_DEFAULT
+ )
+
+ // Used to ensure the BHJ hash table is built only once.
+ // key: hashtable id, value is hashtable backend pointer(long to string).
+ private val buildSideRelationCache: Cache[String, BroadcastHashTable] =
+ Caffeine.newBuilder
+ .expireAfterAccess(expiredTime, TimeUnit.SECONDS)
+ .removalListener(this)
+ .build[String, BroadcastHashTable]()
+
+ def getOrBuildBroadcastHashTable(
+ broadcast: Broadcast[BuildSideRelation],
+ broadcastContext: BroadcastHashJoinContext): BroadcastHashTable =
synchronized {
+
+ buildSideRelationCache
+ .get(
+ broadcastContext.buildHashTableId,
+ (broadcast_id: String) => {
+ val (pointer, relation) = broadcast.value match {
+ case columnar: ColumnarBuildSideRelation =>
+ columnar.buildHashTable(broadcastContext)
+ case unsafe: UnsafeColumnarBuildSideRelation =>
+ unsafe.buildHashTable(broadcastContext)
+ }
+
+ logWarning(s"Create bhj $broadcast_id = $pointer")
+ BroadcastHashTable(pointer, relation)
+ }
+ )
+ }
+
+ /** This is callback from c++ backend. */
+ def get(broadcastHashtableId: String): Long =
+ synchronized {
+ Option(buildSideRelationCache.getIfPresent(broadcastHashtableId))
+ .map(_.pointer)
+ .getOrElse(0)
+ }
+
+ def invalidateBroadcastHashtable(broadcastHashtableId: String): Unit =
synchronized {
+ // Cleanup operations on the backend are idempotent.
+ buildSideRelationCache.invalidate(broadcastHashtableId)
+ }
+
+ /** Only used in UT. */
+ def size(): Long = buildSideRelationCache.estimatedSize()
+
+ def cleanAll(): Unit = buildSideRelationCache.invalidateAll()
+
+ override def onRemoval(key: String, value: BroadcastHashTable, cause:
RemovalCause): Unit = {
+ synchronized {
+ logWarning(s"Remove bhj $key = ${value.pointer}")
+ if (value.relation != null) {
+ value.relation match {
+ case columnar: ColumnarBuildSideRelation =>
+ columnar.reset()
+ case unsafe: UnsafeColumnarBuildSideRelation =>
+ unsafe.reset()
+ }
+ }
+
+ HashJoinBuilder.clearHashTable(value.pointer)
+ }
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
index 0163178e59..2d4b157056 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
@@ -19,19 +19,36 @@ package org.apache.gluten.execution
import org.apache.gluten.iterator.Iterators
import org.apache.spark.{broadcast, SparkContext}
+import org.apache.spark.sql.execution.ColumnarBuildSideRelation
import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch
case class VeloxBroadcastBuildSideRDD(
@transient private val sc: SparkContext,
- broadcasted: broadcast.Broadcast[BuildSideRelation])
+ broadcasted: broadcast.Broadcast[BuildSideRelation],
+ broadcastContext: BroadcastHashJoinContext,
+ isBNL: Boolean = false)
extends BroadcastBuildSideRDD(sc, broadcasted) {
override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = {
- val relation = broadcasted.value.asReadOnlyCopy()
- Iterators
- .wrap(relation.deserialized)
- .recyclePayload(batch => batch.close())
- .create()
+ val offload = broadcasted.value.asReadOnlyCopy() match {
+ case columnar: ColumnarBuildSideRelation =>
+ columnar.offload
+ case unsafe: UnsafeColumnarBuildSideRelation =>
+ unsafe.isOffload
+ }
+ val output = if (isBNL || !offload) {
+ val relation = broadcasted.value.asReadOnlyCopy()
+ Iterators
+ .wrap(relation.deserialized)
+ .recyclePayload(batch => batch.close())
+ .create()
+ } else {
+ VeloxBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted,
broadcastContext)
+ Iterator.empty
+ }
+
+ output
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
index 2a920c3ab9..6e0aaa27c6 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
@@ -45,7 +45,7 @@ case class VeloxBroadcastNestedLoopJoinExecTransformer(
override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
- val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast)
+ val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast,
null, true)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/listener/VeloxGlutenSQLAppStatusListener.scala
b/backends-velox/src/main/scala/org/apache/spark/listener/VeloxGlutenSQLAppStatusListener.scala
new file mode 100644
index 0000000000..7e4ecc9a84
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/spark/listener/VeloxGlutenSQLAppStatusListener.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.listener
+
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{GlutenDriverEndpoint, RpcEndpointRef}
+import org.apache.spark.rpc.GlutenRpcMessages._
+import org.apache.spark.scheduler._
+import org.apache.spark.sql.execution.ui._
+
+/** Gluten SQL listener. Used to monitor SQL over its whole life cycle; creates
and releases resources. */
+class VeloxGlutenSQLAppStatusListener(val driverEndpointRef: RpcEndpointRef)
+ extends SparkListener
+ with Logging {
+
+  /**
+   * If an executor was removed, the driver endpoint needs to remove the executor
endpoint ref. Once the execution
+   * has ended, the executor ref can't be called again.
+   * @param executorRemoved
+   *   executor removed event
+   */
+ override def onExecutorRemoved(executorRemoved:
SparkListenerExecutorRemoved): Unit = {
+ driverEndpointRef.send(GlutenExecutorRemoved(executorRemoved.executorId))
+ logTrace(s"Execution ${executorRemoved.executorId} Removed.")
+ }
+
+ override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+ case e: SparkListenerSQLExecutionStart => onExecutionStart(e)
+ case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e)
+ case _ => // Ignore
+ }
+
+ /**
+   * When an execution starts, notify the Gluten executors to do some
preparation.
+ *
+ * @param event
+ * execution start event
+ */
+ private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
+ val executionId = event.executionId.toString
+ driverEndpointRef.send(GlutenOnExecutionStart(executionId))
+ logTrace(s"Execution $executionId start.")
+ }
+
+ /**
+   * When an execution has ended, some backends (e.g. CH) need to clean up
resources related to this
+   * execution.
+ * @param event
+ * execution end event
+ */
+ private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = {
+ // val stackTraceElements = Thread.currentThread().getStackTrace()
+
+ // for (element <- stackTraceElements) {
+ // logWarning(element.toString);
+ // }
+ val executionId = event.executionId.toString
+ driverEndpointRef.send(GlutenOnExecutionEnd(executionId))
+ logTrace(s"Execution $executionId end.")
+ }
+}
+object VeloxGlutenSQLAppStatusListener {
+ def registerListener(sc: SparkContext): Unit = {
+ sc.listenerBus.addToStatusQueue(
+ new
VeloxGlutenSQLAppStatusListener(GlutenDriverEndpoint.glutenDriverEndpointRef))
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenDriverEndpoint.scala
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenDriverEndpoint.scala
new file mode 100644
index 0000000000..af635addf3
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenDriverEndpoint.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import org.apache.gluten.config.GlutenConfig
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.GlutenRpcMessages._
+
+import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause,
RemovalListener}
+
+import java.util
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+/**
+ * The gluten driver endpoint is responsible for communicating with the
executor. Executor will
+ * register with the driver when it starts.
+ */
+class GlutenDriverEndpoint extends IsolatedRpcEndpoint with Logging {
+ override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+ protected val totalRegisteredExecutors = new AtomicInteger(0)
+
+ private val driverEndpoint: RpcEndpointRef =
+ rpcEnv.setupEndpoint(GlutenRpcConstants.GLUTEN_DRIVER_ENDPOINT_NAME, this)
+
+ // TODO(yuan): get thread cnt from spark context
+ override def threadCount(): Int = 1
+ override def receive: PartialFunction[Any, Unit] = {
+ case GlutenOnExecutionStart(executionId) =>
+ if (executionId == null) {
+ logWarning(s"Execution Id is null. Resources maybe not clean after
execution end.")
+ }
+
+ case GlutenOnExecutionEnd(executionId) =>
+ logWarning(s"Execution Id is $executionId end.")
+
+ GlutenDriverEndpoint.executionResourceRelation.invalidate(executionId)
+
+ case GlutenExecutorRemoved(executorId) =>
+ GlutenDriverEndpoint.executorDataMap.remove(executorId)
+ totalRegisteredExecutors.addAndGet(-1)
+ logTrace(s"Executor endpoint ref $executorId is removed.")
+
+ case e =>
+ logError(s"Received unexpected message. $e")
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any,
Unit] = {
+
+ case GlutenRegisterExecutor(executorId, executorRef) =>
+ if (GlutenDriverEndpoint.executorDataMap.contains(executorId)) {
+ context.sendFailure(new IllegalStateException(s"Duplicate executor ID:
$executorId"))
+ } else {
+ // If the executor's rpc env is not listening for incoming
connections, `hostPort`
+ // will be null, and the client connection should be used to contact
the executor.
+ val executorAddress = if (executorRef.address != null) {
+ executorRef.address
+ } else {
+ context.senderAddress
+ }
+ logInfo(s"Registered executor $executorRef ($executorAddress) with ID
$executorId")
+
+ totalRegisteredExecutors.addAndGet(1)
+ val data = new ExecutorData(executorRef)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ GlutenDriverEndpoint.this.synchronized {
+ GlutenDriverEndpoint.executorDataMap.put(executorId, data)
+ }
+ logTrace(s"Executor size
${GlutenDriverEndpoint.executorDataMap.size()}")
+ // Note: some tests expect the reply to come after we put the executor
in the map
+ context.reply(true)
+ }
+
+ }
+
+ override def onStart(): Unit = {
+ logInfo(s"Initialized GlutenDriverEndpoint, address:
${driverEndpoint.address.toString()}.")
+ }
+}
+
+object GlutenDriverEndpoint extends Logging with RemovalListener[String,
util.Set[String]] {
+ private lazy val executionResourceExpiredTime = SparkEnv.get.conf.getLong(
+ GlutenConfig.GLUTEN_RESOURCE_RELATION_EXPIRED_TIME.key,
+ GlutenConfig.GLUTEN_RESOURCE_RELATION_EXPIRED_TIME.defaultValue.get
+ )
+
+ var glutenDriverEndpointRef: RpcEndpointRef = _
+
+  // keep executorRef in memory
+ val executorDataMap = new ConcurrentHashMap[String, ExecutorData]
+
+ // If spark.scheduler.listenerbus.eventqueue.capacity is set too small,
+ // the listener may lose messages.
+ // We set a maximum expiration time of 1 day by default
+ // key: executionId, value: resourceIds
+ private val executionResourceRelation: Cache[String, util.Set[String]] =
+ Caffeine.newBuilder
+ .expireAfterAccess(executionResourceExpiredTime, TimeUnit.SECONDS)
+ .removalListener(this)
+ .build[String, util.Set[String]]()
+
+ def collectResources(executionId: String, resourceId: String): Unit = {
+ val resources = executionResourceRelation
+ .get(executionId, (_: String) => new util.HashSet[String]())
+ resources.add(resourceId)
+ }
+
+ def invalidateResourceRelation(executionId: String): Unit = {
+ executionResourceRelation.invalidate(executionId)
+ }
+
+ override def onRemoval(key: String, value: util.Set[String], cause:
RemovalCause): Unit = {
+ executorDataMap.forEach(
+ (_, executor) =>
executor.executorEndpointRef.send(GlutenCleanExecutionResource(key, value)))
+ }
+}
+
+class ExecutorData(val executorEndpointRef: RpcEndpointRef) {}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
new file mode 100644
index 0000000000..49ecef20b3
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import org.apache.gluten.execution.VeloxBroadcastBuildSideCache
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.rpc.GlutenRpcMessages._
+import org.apache.spark.util.ThreadUtils
+
+import scala.util.{Failure, Success}
+
+/** Gluten executor endpoint. */
+class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
+ extends IsolatedRpcEndpoint
+ with Logging {
+ override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+ private val driverHost = conf.get(config.DRIVER_HOST_ADDRESS.key,
"localhost")
+ private val driverPort = conf.getInt(config.DRIVER_PORT.key, 7077)
+ private val rpcAddress = RpcAddress(driverHost, driverPort)
+ private val driverUrl =
+ RpcEndpointAddress(rpcAddress,
GlutenRpcConstants.GLUTEN_DRIVER_ENDPOINT_NAME).toString
+
+ @volatile var driverEndpointRef: RpcEndpointRef = null
+
+ rpcEnv.setupEndpoint(GlutenRpcConstants.GLUTEN_EXECUTOR_ENDPOINT_NAME, this)
+ // TODO(yuan): get thread cnt from spark context
+ override def threadCount(): Int = 1
+ override def onStart(): Unit = {
+ rpcEnv
+ .asyncSetupEndpointRefByURI(driverUrl)
+ .flatMap {
+ ref =>
+ // This is a very fast action so we can use "ThreadUtils.sameThread"
+ driverEndpointRef = ref
+ ref.ask[Boolean](GlutenRegisterExecutor(executorId, self))
+ }(ThreadUtils.sameThread)
+ .onComplete {
+ case Success(_) => logTrace("Register GlutenExecutor listener
success.")
+ case Failure(e) => logError("Register GlutenExecutor listener error.",
e)
+ }(ThreadUtils.sameThread)
+ logInfo("Initialized GlutenExecutorEndpoint.")
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case GlutenCleanExecutionResource(executionId, hashIds) =>
+ if (executionId != null) {
+ hashIds.forEach(
+ resource_id =>
VeloxBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id))
+ }
+
+ case e =>
+ logError(s"Received unexpected message. $e")
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any,
Unit] = {
+ case e =>
+ logInfo(s"Received message. $e")
+ }
+}
+object GlutenExecutorEndpoint {
+ var executorEndpoint: GlutenExecutorEndpoint = _
+}
diff --git
a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcConstants.scala
similarity index 61%
copy from
backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
copy to
backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcConstants.scala
index 2759613793..4fbb0722a2 100644
---
a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
+++
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcConstants.scala
@@ -14,24 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.test;
+package org.apache.spark.rpc
-import org.apache.gluten.backendsapi.ListenerApi;
-import org.apache.gluten.backendsapi.velox.VeloxListenerApi;
+object GlutenRpcConstants {
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
+ val GLUTEN_DRIVER_ENDPOINT_NAME = "GlutenDriverEndpoint"
-public abstract class VeloxBackendTestBase {
- private static final ListenerApi API = new VeloxListenerApi();
-
- @BeforeClass
- public static void setup() {
- API.onExecutorStart(MockVeloxBackend.mockPluginContext());
- }
-
- @AfterClass
- public static void tearDown() {
- API.onExecutorShutdown();
- }
+ val GLUTEN_EXECUTOR_ENDPOINT_NAME = "GlutenExecutorEndpoint"
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
new file mode 100644
index 0000000000..8127c324b7
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import java.util
+
+trait GlutenRpcMessage extends Serializable
+
+object GlutenRpcMessages {
+ case class GlutenRegisterExecutor(
+ executorId: String,
+ executorRef: RpcEndpointRef
+ ) extends GlutenRpcMessage
+
+ case class GlutenOnExecutionStart(executionId: String) extends
GlutenRpcMessage
+
+ case class GlutenOnExecutionEnd(executionId: String) extends GlutenRpcMessage
+
+ case class GlutenExecutorRemoved(executorId: String) extends GlutenRpcMessage
+
+ case class GlutenCleanExecutionResource(executionId: String,
broadcastHashIds: util.Set[String])
+ extends GlutenRpcMessage
+
+ // for mergetree cache
+ case class GlutenMergeTreeCacheLoad(
+ mergeTreeTable: String,
+ columns: util.Set[String],
+ onlyMetaCache: Boolean)
+ extends GlutenRpcMessage
+
+ case class GlutenCacheLoadStatus(jobId: String)
+
+ case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
+ extends GlutenRpcMessage
+
+ case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage
+
+ case class GlutenFilesCacheLoadStatus(jobId: String)
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index ad066d47f9..cf3f9ccca4 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -108,7 +108,9 @@ object BroadcastUtils {
UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
result.offHeapData().asScala.toSeq,
- mode)
+ mode,
+ Seq.empty,
+ result.isOffHeap)
} else {
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index d542fd92b9..6429f8bb3f 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -18,13 +18,16 @@ package org.apache.spark.sql.execution
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.execution.BroadcastHashJoinContext
+import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.utils.ArrowAbiUtil
-import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper,
NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
+import org.apache.gluten.utils.{ArrowAbiUtil, SubstraitUtil}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper,
HashJoinBuilder, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq,
BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
@@ -37,14 +40,18 @@ import org.apache.spark.util.KnownSizeEstimation
import org.apache.arrow.c.ArrowSchema
+import scala.collection.JavaConverters._
import scala.collection.JavaConverters.asScalaIteratorConverter
+import scala.collection.mutable.ArrayBuffer
object ColumnarBuildSideRelation {
// Keep constructor with BroadcastMode for compatibility
def apply(
output: Seq[Attribute],
batches: Array[Array[Byte]],
- mode: BroadcastMode): ColumnarBuildSideRelation = {
+ mode: BroadcastMode,
+ newBuildKeys: Seq[Expression] = Seq.empty,
+ offload: Boolean = false): ColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
// Bind each key to the build-side output so simple cols become
BoundReference
@@ -54,15 +61,23 @@ object ColumnarBuildSideRelation {
case m =>
m // IdentityBroadcastMode, etc.
}
- new ColumnarBuildSideRelation(output, batches,
BroadcastModeUtils.toSafe(boundMode))
+ new ColumnarBuildSideRelation(
+ output,
+ batches,
+ BroadcastModeUtils.toSafe(boundMode),
+ newBuildKeys,
+ offload)
}
}
case class ColumnarBuildSideRelation(
output: Seq[Attribute],
batches: Array[Array[Byte]],
- safeBroadcastMode: SafeBroadcastMode)
+ safeBroadcastMode: SafeBroadcastMode,
+ newBuildKeys: Seq[Expression],
+ offload: Boolean)
extends BuildSideRelation
+ with Logging
with KnownSizeEstimation {
// Rebuild the real BroadcastMode on demand; never serialize it.
@@ -135,6 +150,87 @@ case class ColumnarBuildSideRelation(
override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
+ private var hashTableData: Long = 0L
+
+ def buildHashTable(
+ broadcastContext: BroadcastHashJoinContext): (Long,
ColumnarBuildSideRelation) =
+ synchronized {
+ if (hashTableData == 0) {
+ val runtime = Runtimes.contextInstance(
+ BackendsApiManager.getBackendName,
+ "ColumnarBuildSideRelation#buildHashTable")
+ val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
+ val serializeHandle: Long = {
+ val allocator = ArrowBufferAllocators.contextInstance()
+ val cSchema = ArrowSchema.allocateNew(allocator)
+ val arrowSchema = SparkArrowUtil.toArrowSchema(
+ SparkShimLoader.getSparkShims.structFromAttributes(output),
+ SQLConf.get.sessionLocalTimeZone)
+ ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
+ val handle = jniWrapper
+ .init(cSchema.memoryAddress())
+ cSchema.close()
+ handle
+ }
+
+ val batchArray = new ArrayBuffer[Long]
+
+ var batchId = 0
+ while (batchId < batches.size) {
+ batchArray.append(jniWrapper.deserialize(serializeHandle,
batches(batchId)))
+ batchId += 1
+ }
+
+ logDebug(
+ s"BHJ value size: " +
+ s"${broadcastContext.buildHashTableId} = ${batches.length}")
+
+ val (keys, newOutput) = if (newBuildKeys.isEmpty) {
+ (
+ broadcastContext.buildSideJoinKeys.asJava,
+ broadcastContext.buildSideStructure.asJava
+ )
+ } else {
+ (
+ newBuildKeys.asJava,
+ output.asJava
+ )
+ }
+
+ val joinKey = keys.asScala
+ .map {
+ key =>
+ val attr = ConverterUtils.getAttrFromExpr(key)
+ ConverterUtils.genColumnNameWithExprId(attr)
+ }
+ .mkString(",")
+
+ // Build the hash table
+ hashTableData = HashJoinBuilder
+ .nativeBuild(
+ broadcastContext.buildHashTableId,
+ batchArray.toArray,
+ joinKey,
+ broadcastContext.substraitJoinType.ordinal(),
+ broadcastContext.hasMixedFiltCondition,
+ broadcastContext.isExistenceJoin,
+ SubstraitUtil.toNameStruct(newOutput).toByteArray,
+ broadcastContext.isNullAwareAntiJoin,
+ broadcastContext.bloomFilterPushdownSize,
+ broadcastContext.broadcastHashTableBuildThreads
+ )
+
+ jniWrapper.close(serializeHandle)
+ (hashTableData, this)
+ } else {
+ (HashJoinBuilder.cloneHashTable(hashTableData), null)
+ }
+ }
+
+ def reset(): Unit = synchronized {
+ hashTableData = 0
+ }
+
/**
* Transform columnar broadcast value to Array[InternalRow] by key.
*
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/joins/SparkHashJoinUtils.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/joins/SparkHashJoinUtils.scala
new file mode 100644
index 0000000000..1e6b677253
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/joins/SparkHashJoinUtils.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BitwiseAnd,
BitwiseOr, Cast, Expression, ShiftLeft}
+import org.apache.spark.sql.types.IntegralType
+
+object SparkHashJoinUtils {
+
+  // Copied from
org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType;
+  // we must stay consistent with it to correctly identify the LongHashRelation.
+ def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+ // TODO: support BooleanType, DateType and TimestampType
+ keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+ keys.map(_.dataType.defaultSize).sum <= 8
+ }
+
+ def getOriginalKeysFromPacked(expr: Expression): Seq[Expression] = {
+
+ def unwrap(e: Expression): Expression = e match {
+ case Cast(child, _, _, _) => unwrap(child)
+ case Alias(child, _) => unwrap(child)
+ case BitwiseAnd(child, _) => unwrap(child)
+ case other => other
+ }
+
+ expr match {
+ case BitwiseOr(ShiftLeft(left, _), rightPart) =>
+ getOriginalKeysFromPacked(left) :+ unwrap(rightPart)
+ case BitwiseOr(left, rightPart) =>
+ getOriginalKeysFromPacked(left) :+ unwrap(rightPart)
+ case other =>
+ Seq(unwrap(other))
+ }
+ }
+
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index ba307415c5..fc7516c4b3 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.unsafe
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.execution.BroadcastHashJoinContext
+import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.utils.ArrowAbiUtil
-import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper,
NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
+import org.apache.gluten.utils.{ArrowAbiUtil, SubstraitUtil}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper,
HashJoinBuilder, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
@@ -44,13 +46,17 @@ import org.apache.arrow.c.ArrowSchema
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import scala.collection.JavaConverters._
import scala.collection.JavaConverters.asScalaIteratorConverter
+import scala.collection.mutable.ArrayBuffer
object UnsafeColumnarBuildSideRelation {
def apply(
output: Seq[Attribute],
batches: Seq[UnsafeByteArray],
- mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
+ mode: BroadcastMode,
+ newBuildKeys: Seq[Expression] = Seq.empty,
+ offload: Boolean = false): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
// Bind each key to the build-side output so simple cols become
BoundReference
@@ -60,7 +66,12 @@ object UnsafeColumnarBuildSideRelation {
case m =>
m // IdentityBroadcastMode, etc.
}
- new UnsafeColumnarBuildSideRelation(output, batches,
BroadcastModeUtils.toSafe(boundMode))
+ new UnsafeColumnarBuildSideRelation(
+ output,
+ batches,
+ BroadcastModeUtils.toSafe(boundMode),
+ newBuildKeys,
+ offload)
}
}
@@ -78,7 +89,9 @@ object UnsafeColumnarBuildSideRelation {
class UnsafeColumnarBuildSideRelation(
private var output: Seq[Attribute],
private var batches: Seq[UnsafeByteArray],
- private var safeBroadcastMode: SafeBroadcastMode)
+ private var safeBroadcastMode: SafeBroadcastMode,
+ private var newBuildKeys: Seq[Expression],
+ private var offload: Boolean)
extends BuildSideRelation
with Externalizable
with Logging
@@ -96,37 +109,128 @@ class UnsafeColumnarBuildSideRelation(
case _ => None
}
+ def isOffload: Boolean = offload
+
/** needed for serialization. */
def this() = {
- this(null, null, null)
+ this(null, null, null, Seq.empty, false)
}
private[unsafe] def getBatches(): Seq[UnsafeByteArray] = {
batches
}
+ private var hashTableData: Long = 0L
+
+ def buildHashTable(broadcastContext: BroadcastHashJoinContext): (Long,
BuildSideRelation) =
+ synchronized {
+ if (hashTableData == 0) {
+ val runtime = Runtimes.contextInstance(
+ BackendsApiManager.getBackendName,
+ "UnsafeColumnarBuildSideRelation#buildHashTable")
+ val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
+ val serializeHandle: Long = {
+ val allocator = ArrowBufferAllocators.contextInstance()
+ val cSchema = ArrowSchema.allocateNew(allocator)
+ val arrowSchema = SparkArrowUtil.toArrowSchema(
+ SparkShimLoader.getSparkShims.structFromAttributes(output),
+ SQLConf.get.sessionLocalTimeZone)
+ ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
+ val handle = jniWrapper
+ .init(cSchema.memoryAddress())
+ cSchema.close()
+ handle
+ }
+
+ val batchArray = new ArrayBuffer[Long]
+
+ var batchId = 0
+ while (batchId < batches.size) {
+ val (offset, length) = (batches(batchId).address(),
batches(batchId).size())
+ batchArray.append(jniWrapper.deserializeDirect(serializeHandle,
offset, length.toInt))
+ batchId += 1
+ }
+
+ logDebug(
+ s"BHJ value size: " +
+ s"${broadcastContext.buildHashTableId} = ${batches.size}")
+
+ val (keys, newOutput) = if (newBuildKeys.isEmpty) {
+ (
+ broadcastContext.buildSideJoinKeys.asJava,
+ broadcastContext.buildSideStructure.asJava
+ )
+ } else {
+ (
+ newBuildKeys.asJava,
+ output.asJava
+ )
+ }
+
+ val joinKey = keys.asScala
+ .map {
+ key =>
+ val attr = ConverterUtils.getAttrFromExpr(key)
+ ConverterUtils.genColumnNameWithExprId(attr)
+ }
+ .mkString(",")
+
+ // Build the hash table
+ hashTableData = HashJoinBuilder
+ .nativeBuild(
+ broadcastContext.buildHashTableId,
+ batchArray.toArray,
+ joinKey,
+ broadcastContext.substraitJoinType.ordinal(),
+ broadcastContext.hasMixedFiltCondition,
+ broadcastContext.isExistenceJoin,
+ SubstraitUtil.toNameStruct(newOutput).toByteArray,
+ broadcastContext.isNullAwareAntiJoin,
+ broadcastContext.bloomFilterPushdownSize,
+ broadcastContext.broadcastHashTableBuildThreads
+ )
+
+ jniWrapper.close(serializeHandle)
+ (hashTableData, this)
+ } else {
+ (HashJoinBuilder.cloneHashTable(hashTableData), null)
+ }
+ }
+
+ def reset(): Unit = synchronized {
+ hashTableData = 0
+ }
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException
{
out.writeObject(output)
out.writeObject(safeBroadcastMode)
out.writeObject(batches.toArray)
+ out.writeObject(newBuildKeys)
+ out.writeBoolean(offload)
}
override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
kryo.writeObject(out, output.toList)
kryo.writeClassAndObject(out, safeBroadcastMode)
kryo.writeClassAndObject(out, batches.toArray)
+ kryo.writeClassAndObject(out, newBuildKeys)
+ out.writeBoolean(offload)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
output = in.readObject().asInstanceOf[Seq[Attribute]]
safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
batches = in.readObject().asInstanceOf[Array[UnsafeByteArray]].toSeq
+ newBuildKeys = in.readObject().asInstanceOf[Seq[Expression]]
+ offload = in.readBoolean()
}
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
safeBroadcastMode =
kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
batches =
kryo.readClassAndObject(in).asInstanceOf[Array[UnsafeByteArray]].toSeq
+ newBuildKeys = kryo.readClassAndObject(in).asInstanceOf[Seq[Expression]]
+ offload = in.readBoolean()
}
private def transformProjection: UnsafeProjection = safeBroadcastMode match {
diff --git
a/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
b/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
index 06fe3d28ca..2c4b813f30 100644
--- a/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
+++ b/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
@@ -43,7 +43,7 @@ public final class MockVeloxBackend {
@Override
public String executorID() {
- throw new UnsupportedOperationException();
+ return "MockVeloxBackend ID";
}
@Override
diff --git
a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
b/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
index 2759613793..c66c67fe9e 100644
---
a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
+++
b/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
@@ -19,19 +19,36 @@ package org.apache.gluten.test;
import org.apache.gluten.backendsapi.ListenerApi;
import org.apache.gluten.backendsapi.velox.VeloxListenerApi;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.test.TestSparkSession;
import org.junit.AfterClass;
import org.junit.BeforeClass;
public abstract class VeloxBackendTestBase {
private static final ListenerApi API = new VeloxListenerApi();
+ private static SparkSession sparkSession = null;
@BeforeClass
public static void setup() {
+ if (sparkSession == null) {
+ sparkSession =
+ TestSparkSession.builder()
+ .appName("VeloxBackendTest")
+ .master("local[1]")
+ .config(MockVeloxBackend.mockPluginContext().conf())
+ .getOrCreate();
+ }
+
API.onExecutorStart(MockVeloxBackend.mockPluginContext());
}
@AfterClass
public static void tearDown() {
API.onExecutorShutdown();
+
+ if (sparkSession != null) {
+ sparkSession.stop();
+ sparkSession = null;
+ }
}
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
index 0afbc2fa19..ddd76f917d 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
@@ -35,6 +35,10 @@ class DynamicOffHeapSizingSuite extends
VeloxWholeStageTransformerSuite {
.set("spark.shuffle.manager",
"org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.executor.memory", "2GB")
.set("spark.memory.offHeap.enabled", "false")
+    .set(
+      "spark.gluten.velox.buildHashTableOncePerExecutor.enabled",
+      "false"
+    ) // Building the native hash table requires off-heap memory.
.set(GlutenCoreConfig.DYNAMIC_OFFHEAP_SIZING_MEMORY_FRACTION.key, "0.95")
.set(GlutenCoreConfig.DYNAMIC_OFFHEAP_SIZING_ENABLED.key, "true")
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
index 4fca03fa85..86565aa42b 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
@@ -22,8 +22,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec,
ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec}
+import org.apache.spark.sql.execution.{ColumnarSubqueryBroadcastExec,
InputIteratorTransformer}
class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
override protected val resourcePath: String = "/tpch-data-parquet"
@@ -114,85 +113,12 @@ class VeloxHashJoinSuite extends
VeloxWholeStageTransformerSuite {
}
}
- test("Reuse broadcast exchange for different build keys with same table") {
- Seq("true", "false").foreach(
- enabledOffheapBroadcast =>
- withSQLConf(
- VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key ->
enabledOffheapBroadcast) {
- withTable("t1", "t2") {
- spark.sql("""
- |CREATE TABLE t1 USING PARQUET
- |AS SELECT id as c1, id as c2 FROM range(10)
- |""".stripMargin)
-
- spark.sql("""
- |CREATE TABLE t2 USING PARQUET
- |AS SELECT id as c1, id as c2 FROM range(3)
- |""".stripMargin)
-
- val df = spark.sql("""
- |SELECT * FROM t1
- |JOIN t2 as tmp1 ON t1.c1 = tmp1.c1 and
tmp1.c1 = tmp1.c2
- |JOIN t2 as tmp2 on t1.c2 = tmp2.c2 and
tmp2.c1 = tmp2.c2
- |""".stripMargin)
-
- assert(collect(df.queryExecution.executedPlan) {
- case b: BroadcastExchangeExec => b
- }.size == 2)
-
- checkAnswer(
- df,
- Row(2, 2, 2, 2, 2, 2) :: Row(1, 1, 1, 1, 1, 1) :: Row(0, 0, 0,
0, 0, 0) :: Nil)
-
- assert(collect(df.queryExecution.executedPlan) {
- case b: ColumnarBroadcastExchangeExec => b
- }.size == 1)
- assert(collect(df.queryExecution.executedPlan) {
- case r @ ReusedExchangeExec(_, _: ColumnarBroadcastExchangeExec)
=> r
- }.size == 1)
- }
- })
- }
-
- test("ColumnarBuildSideRelation with small columnar to row memory") {
- Seq("true", "false").foreach(
- enabledOffheapBroadcast =>
- withSQLConf(
- GlutenConfig.GLUTEN_COLUMNAR_TO_ROW_MEM_THRESHOLD.key -> "16",
- VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key ->
enabledOffheapBroadcast) {
- withTable("t1", "t2") {
- spark.sql("""
- |CREATE TABLE t1 USING PARQUET
- |AS SELECT id as c1, id as c2 FROM range(10)
- |""".stripMargin)
-
- spark.sql("""
- |CREATE TABLE t2 USING PARQUET PARTITIONED BY (c1)
- |AS SELECT id as c1, id as c2 FROM range(30)
- |""".stripMargin)
-
- val df = spark.sql("""
- |SELECT t1.c2
- |FROM t1, t2
- |WHERE t1.c1 = t2.c1
- |AND t1.c2 < 4
- |""".stripMargin)
-
- checkAnswer(df, Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil)
-
- val subqueryBroadcastExecs =
collectWithSubqueries(df.queryExecution.executedPlan) {
- case subqueryBroadcast: ColumnarSubqueryBroadcastExec =>
subqueryBroadcast
- }
- assert(subqueryBroadcastExecs.size == 1)
- }
- })
- }
-
test("ColumnarBuildSideRelation transform support multiple key columns") {
Seq("true", "false").foreach(
enabledOffheapBroadcast =>
withSQLConf(
- VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key ->
enabledOffheapBroadcast) {
+ VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key ->
+ enabledOffheapBroadcast) {
withTable("t1", "t2") {
val df1 =
(0 until 50)
@@ -317,4 +243,83 @@ class VeloxHashJoinSuite extends
VeloxWholeStageTransformerSuite {
}
}
}
+
+ test("Broadcast join preserves original cast expression in join keys") {
+ withSQLConf(
+ ("spark.sql.autoBroadcastJoinThreshold", "10MB"),
+ ("spark.sql.adaptive.enabled", "false")
+ ) {
+ withTable("t1_int", "t2_long") {
+ // Create table with INT column
+ spark
+ .range(100)
+ .selectExpr("cast(id as int) as key", "id as value")
+ .write
+ .saveAsTable("t1_int")
+
+ // Create table with LONG column
+ spark.range(50).selectExpr("id as key", "id * 2 as
value").write.saveAsTable("t2_long")
+
+ // Join INT with LONG - Spark will insert cast(int to long) in join
keys
+ val query = """
+ SELECT t1.key, t1.value, t2.value as value2
+ FROM t1_int t1
+ JOIN t2_long t2 ON t1.key = t2.key
+ ORDER BY t1.key
+ """
+
+ runQueryAndCompare(query) {
+ df =>
+ // Check that broadcast join is used in Gluten execution
+ val plan = df.queryExecution.executedPlan
+ val broadcastJoins = plan.collect { case bhj:
BroadcastHashJoinExecTransformer => bhj }
+ assert(broadcastJoins.nonEmpty, "Should use broadcast hash join")
+ }
+ }
+ }
+ }
+
+ test("Broadcast join with multiple cast expressions in join keys") {
+ withSQLConf(
+ ("spark.sql.autoBroadcastJoinThreshold", "10MB"),
+ ("spark.sql.adaptive.enabled", "false")
+ ) {
+ withTable("t1_mixed", "t2_mixed") {
+ // Create table with mixed types
+ spark
+ .range(100)
+ .selectExpr("cast(id as int) as key1", "cast(id as short) as key2",
"id as value")
+ .write
+ .saveAsTable("t1_mixed")
+
+ // Create table with different types requiring casts
+ spark
+ .range(50)
+ .selectExpr("id as key1", "cast(id as int) as key2", "id * 2 as
value")
+ .write
+ .saveAsTable("t2_mixed")
+
+ // Join with multiple keys requiring casts
+ // key1: cast(int to long), key2: cast(short to int)
+ val query = """
+ SELECT t1.key1, t1.key2, t1.value, t2.value as value2
+ FROM t1_mixed t1
+ JOIN t2_mixed t2 ON t1.key1 = t2.key1 AND t1.key2 = t2.key2
+ ORDER BY t1.key1, t1.key2
+ """
+
+ runQueryAndCompare(query) {
+ df =>
+ // Check that broadcast join is used in Gluten execution
+ val plan = df.queryExecution.executedPlan
+ val broadcastJoins = plan.collect { case bhj:
BroadcastHashJoinExecTransformer => bhj }
+ assert(broadcastJoins.nonEmpty, "Should use broadcast hash join")
+
+ // Verify multiple join keys are handled correctly
+ assert(broadcastJoins.head.leftKeys.length == 2)
+ assert(broadcastJoins.head.rightKeys.length == 2)
+ }
+ }
+ }
+ }
}
diff --git
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/VeloxBroadcastBuildOnceBenchmark.scala
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/VeloxBroadcastBuildOnceBenchmark.scala
new file mode 100644
index 0000000000..6e06cc35a7
--- /dev/null
+++
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/VeloxBroadcastBuildOnceBenchmark.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.gluten.config.VeloxConfig
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf
+
+/** Benchmark to measure performance for BHJ build once per executor. */
+object VeloxBroadcastBuildOnceBenchmark extends SqlBasedBenchmark {
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ val numRows = 5 * 1000 * 1000
+ val broadcastRows = 1000 * 1000
+
+ withTempPath {
+ f =>
+ val path = f.getCanonicalPath
+ val probePath = s"$path/probe"
+ val buildPath = s"$path/build"
+
+ // Generate probe table with many partitions to simulate many tasks
+ spark
+ .range(numRows)
+ .repartition(100)
+ .selectExpr("id as k1", "id as v1")
+ .write
+ .parquet(probePath)
+
+ // Generate build table
+ spark
+ .range(broadcastRows)
+ .selectExpr("id as k2", "id as v2")
+ .write
+ .parquet(buildPath)
+
+ spark.read.parquet(probePath).createOrReplaceTempView("probe")
+ spark.read.parquet(buildPath).createOrReplaceTempView("build")
+
+ val query = "SELECT /*+ BROADCAST(build) */ count(*) FROM probe JOIN
build ON k1 = k2"
+
+ val benchmark = new Benchmark("BHJ Build Once Benchmark", numRows,
output = output)
+
+ // Warm up
+ spark.sql(query).collect()
+
+ benchmark.addCase("Build once per executor enabled=false", 3) {
+ _ =>
+ withSQLConf(
+
VeloxConfig.VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR.key -> "false",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200MB"
+ ) {
+ spark.sql(query).collect()
+ }
+ }
+
+ benchmark.addCase("Build once per executor enabled=true", 3) {
+ _ =>
+ withSQLConf(
+
VeloxConfig.VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200MB"
+ ) {
+ spark.sql(query).collect()
+ }
+ }
+
+ benchmark.run()
+ }
+ }
+}
diff --git
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
index 41400f613f..c881d77ed1 100644
---
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
+++
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -188,4 +188,30 @@ class UnsafeColumnarBuildSideRelationTest extends
SharedSparkSession {
newUnsafeRelationWithHashMode(ByteUnit.MiB.toKiB(50).toInt)
}
}
+
+ test("Verify offload field serialization") {
+ val relation = UnsafeColumnarBuildSideRelation(
+ output,
+ Seq(sampleUnsafeByteArrayInKb(1)),
+ IdentityBroadcastMode,
+ Seq.empty,
+ offload = true
+ )
+
+ // Java Serialization
+ val javaSerializer = new JavaSerializer(SparkEnv.get.conf).newInstance()
+ val javaBuffer = javaSerializer.serialize(relation)
+ val javaObj =
javaSerializer.deserialize[UnsafeColumnarBuildSideRelation](javaBuffer)
+ assert(javaObj.isOffload, "Java deserialization failed to restore
offload=true")
+
+ // Kryo Serialization
+ val kryoSerializer = new KryoSerializer(SparkEnv.get.conf).newInstance()
+ val kryoBuffer = kryoSerializer.serialize(relation)
+ val kryoObj =
kryoSerializer.deserialize[UnsafeColumnarBuildSideRelation](kryoBuffer)
+ assert(kryoObj.isOffload, "Kryo deserialization failed to restore
offload=true")
+
+    // Creating another relation with offload=false to compare serialized byte
sizes
+    // would not be conclusive: a boolean takes only 1 byte, which is hard to
distinguish from metadata noise.
+    // Instead, rely on the deserialization assertions above.
+ }
}
diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt
index be31f18206..fc6391b7f3 100644
--- a/cpp/velox/CMakeLists.txt
+++ b/cpp/velox/CMakeLists.txt
@@ -157,6 +157,7 @@ set(VELOX_SRCS
jni/JniFileSystem.cc
jni/JniUdf.cc
jni/VeloxJniWrapper.cc
+ jni/JniHashTable.cc
memory/BufferOutputStream.cc
memory/VeloxColumnarBatch.cc
memory/VeloxMemoryManager.cc
@@ -164,6 +165,7 @@ set(VELOX_SRCS
operators/functions/RowConstructorWithNull.cc
operators/functions/SparkExprToSubfieldFilterParser.cc
operators/plannodes/RowVectorStream.cc
+ operators/hashjoin/HashTableBuilder.cc
operators/reader/FileReaderIterator.cc
operators/reader/ParquetReaderIterator.cc
operators/serializer/VeloxColumnarBatchSerializer.cc
diff --git a/cpp/velox/compute/VeloxBackend.cc
b/cpp/velox/compute/VeloxBackend.cc
index de9e9385f8..0232da48da 100644
--- a/cpp/velox/compute/VeloxBackend.cc
+++ b/cpp/velox/compute/VeloxBackend.cc
@@ -362,6 +362,7 @@ void VeloxBackend::tearDown() {
filesystem->close();
}
#endif
+ gluten::hashTableObjStore.reset();
// Destruct IOThreadPoolExecutor will join all threads.
// On threads exit, thread local variables can be constructed with
referencing global variables.
diff --git a/cpp/velox/compute/VeloxBackend.h b/cpp/velox/compute/VeloxBackend.h
index 94e7ec93fb..d73787063f 100644
--- a/cpp/velox/compute/VeloxBackend.h
+++ b/cpp/velox/compute/VeloxBackend.h
@@ -28,6 +28,7 @@
#include "velox/common/config/Config.h"
#include "velox/common/memory/MmapAllocator.h"
+#include "jni/JniHashTable.h"
#include "memory/VeloxMemoryManager.h"
namespace gluten {
@@ -56,6 +57,10 @@ class VeloxBackend {
return globalMemoryManager_.get();
}
+ folly::Executor* executor() const {
+ return ioExecutor_.get();
+ }
+
void tearDown();
private:
diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc
new file mode 100644
index 0000000000..77cd78ff6a
--- /dev/null
+++ b/cpp/velox/jni/JniHashTable.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <arrow/c/abi.h>
+
+#include <jni/JniCommon.h>
+#include "JniHashTable.h"
+#include "folly/String.h"
+#include "memory/ColumnarBatch.h"
+#include "memory/VeloxColumnarBatch.h"
+#include "substrait/algebra.pb.h"
+#include "substrait/type.pb.h"
+#include "velox/core/PlanNode.h"
+#include "velox/type/Type.h"
+
+namespace gluten {
+
+static jclass jniVeloxBroadcastBuildSideCache = nullptr;
+static jmethodID jniGet = nullptr;
+
+jlong callJavaGet(const std::string& id) {
+ JNIEnv* env;
+ if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
+ throw gluten::GlutenException("JNIEnv was not attached to current thread");
+ }
+
+ const jstring s = env->NewStringUTF(id.c_str());
+
+ auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache,
jniGet, s);
+ return result;
+}
+
+// Build and return the Velox hash table builder.
+std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
+ const std::string& joinKeys,
+ std::vector<std::string> names,
+ std::vector<facebook::velox::TypePtr> veloxTypeList,
+ int joinType,
+ bool hasMixedJoinCondition,
+ bool isExistenceJoin,
+ bool isNullAwareAntiJoin,
+ int64_t bloomFilterPushdownSize,
+ std::vector<std::shared_ptr<ColumnarBatch>>& batches,
+ std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool) {
+ auto rowType = std::make_shared<facebook::velox::RowType>(std::move(names),
std::move(veloxTypeList));
+
+ auto sJoin = static_cast<substrait::JoinRel_JoinType>(joinType);
+ facebook::velox::core::JoinType vJoin;
+ switch (sJoin) {
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
+ vJoin = facebook::velox::core::JoinType::kInner;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER:
+ vJoin = facebook::velox::core::JoinType::kFull;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
+ vJoin = facebook::velox::core::JoinType::kLeft;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
+ vJoin = facebook::velox::core::JoinType::kRight;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
+ // Determine the semi join type based on extracted information.
+ if (isExistenceJoin) {
+ vJoin = facebook::velox::core::JoinType::kLeftSemiProject;
+ } else {
+ vJoin = facebook::velox::core::JoinType::kLeftSemiFilter;
+ }
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
+ // Determine the semi join type based on extracted information.
+ if (isExistenceJoin) {
+ vJoin = facebook::velox::core::JoinType::kRightSemiProject;
+ } else {
+ vJoin = facebook::velox::core::JoinType::kRightSemiFilter;
+ }
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI: {
+ // Determine the anti join type based on extracted information.
+ vJoin = facebook::velox::core::JoinType::kAnti;
+ break;
+ }
+ default:
+ VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
+ }
+
+ std::vector<std::string> joinKeyNames;
+ folly::split(',', joinKeys, joinKeyNames);
+
+ std::vector<std::shared_ptr<const
facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
+ joinKeyTypes.reserve(joinKeyNames.size());
+ for (const auto& name : joinKeyNames) {
+ joinKeyTypes.emplace_back(
+
std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name),
name));
+ }
+
+ auto hashTableBuilder = std::make_shared<HashTableBuilder>(
+ vJoin,
+ isNullAwareAntiJoin,
+ hasMixedJoinCondition,
+ bloomFilterPushdownSize,
+ joinKeyTypes,
+ rowType,
+ memoryPool.get());
+
+ for (auto i = 0; i < batches.size(); i++) {
+ auto rowVector = VeloxColumnarBatch::from(memoryPool.get(),
batches[i])->getRowVector();
+ hashTableBuilder->addInput(rowVector);
+ }
+
+ return hashTableBuilder;
+}
+
+long getJoin(std::string hashTableId) {
+ return callJavaGet(hashTableId);
+}
+
+void initVeloxJniHashTable(JNIEnv* env) {
+ if (env->GetJavaVM(&vm) != JNI_OK) {
+ throw gluten::GlutenException("Unable to get JavaVM instance");
+ }
+ const char* classSig =
"Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
+ jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env,
classSig);
+ jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get",
"(Ljava/lang/String;)J");
+}
+
+void finalizeVeloxJniHashTable(JNIEnv* env) {
+ env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache);
+}
+
+} // namespace gluten
diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h
new file mode 100644
index 0000000000..c0d9227840
--- /dev/null
+++ b/cpp/velox/jni/JniHashTable.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <jni.h>
+#include "memory/ColumnarBatch.h"
+#include "memory/VeloxMemoryManager.h"
+#include "operators/hashjoin/HashTableBuilder.h"
+#include "utils/ObjectStore.h"
+#include "velox/exec/HashTable.h"
+
+namespace gluten {
+
+inline static JavaVM* vm = nullptr;
+
+inline static std::unique_ptr<ObjectStore> hashTableObjStore =
ObjectStore::create();
+
+// Build and return the hash table builder for the given build-side batches.
+std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
+ const std::string& joinKeys,
+ std::vector<std::string> names,
+ std::vector<facebook::velox::TypePtr> veloxTypeList,
+ int joinType,
+ bool hasMixedJoinCondition,
+ bool isExistenceJoin,
+ bool isNullAwareAntiJoin,
+ int64_t bloomFilterPushdownSize,
+ std::vector<std::shared_ptr<ColumnarBatch>>& batches,
+ std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool);
+
+long getJoin(std::string hashTableId);
+
+void initVeloxJniHashTable(JNIEnv* env);
+
+void finalizeVeloxJniHashTable(JNIEnv* env);
+
+jlong callJavaGet(const std::string& id);
+
+} // namespace gluten
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index ad6f8947eb..e488274e97 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -30,14 +30,18 @@
#include "config/GlutenConfig.h"
#include "jni/JniError.h"
#include "jni/JniFileSystem.h"
+#include "jni/JniHashTable.h"
+#include "memory/AllocationListener.h"
#include "memory/VeloxColumnarBatch.h"
#include "memory/VeloxMemoryManager.h"
+#include "operators/hashjoin/HashTableBuilder.h"
#include "shuffle/rss/RssPartitionWriter.h"
#include "substrait/SubstraitToVeloxPlanValidator.h"
#include "utils/ObjectStore.h"
#include "utils/VeloxBatchResizer.h"
#include "velox/common/base/BloomFilter.h"
#include "velox/common/file/FileSystems.h"
+#include "velox/exec/HashTable.h"
#ifdef GLUTEN_ENABLE_GPU
#include "cudf/CudfPlanValidator.h"
@@ -76,6 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
getJniErrorState()->ensureInitialized(env);
initVeloxJniFileSystem(env);
initVeloxJniUDF(env);
+ initVeloxJniHashTable(env);
infoCls = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/validate/NativePlanValidationInfo;");
infoClsInitMethod = getMethodIdOrError(env, infoCls, "<init>",
"(ILjava/lang/String;)V");
@@ -84,12 +89,13 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
createGlobalClassReferenceOrError(env,
"Lorg/apache/spark/sql/execution/datasources/BlockStripes;");
blockStripesConstructor = getMethodIdOrError(env, blockStripesClass,
"<init>", "(J[J[II[[B)V");
- batchWriteMetricsClass =
- createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/metrics/BatchWriteMetrics;");
+ batchWriteMetricsClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/metrics/BatchWriteMetrics;");
batchWriteMetricsConstructor = getMethodIdOrError(env,
batchWriteMetricsClass, "<init>", "(JIJJ)V");
DLOG(INFO) << "Loaded Velox backend.";
+ gluten::vm = vm;
+
return jniVersion;
}
@@ -183,8 +189,7 @@
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail
JNI_METHOD_END(nullptr)
}
-JNIEXPORT jboolean JNICALL
-Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateExpression(
// NOLINT
+JNIEXPORT jboolean JNICALL
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateExpression(
// NOLINT
JNIEnv* env,
jobject wrapper,
jbyteArray exprArray,
@@ -439,8 +444,8 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_utils_VeloxBatchResizerJniWrapper
auto ctx = getRuntime(env, wrapper);
auto pool =
dynamic_cast<VeloxMemoryManager*>(ctx->memoryManager())->getLeafMemoryPool();
auto iter = makeJniColumnarBatchIterator(env, jIter, ctx);
- auto appender = std::make_shared<ResultIterator>(
- std::make_unique<VeloxBatchResizer>(pool.get(), minOutputBatchSize,
maxOutputBatchSize, preferredBatchBytes, std::move(iter)));
+ auto appender =
std::make_shared<ResultIterator>(std::make_unique<VeloxBatchResizer>(
+ pool.get(), minOutputBatchSize, maxOutputBatchSize, preferredBatchBytes,
std::move(iter)));
return ctx->saveObject(appender);
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -583,12 +588,15 @@
Java_org_apache_gluten_datasource_VeloxDataSourceJniWrapper_splitBlockByPartitio
const auto numRows = inputRowVector->size();
connector::hive::PartitionIdGenerator idGen(
- asRowType(inputRowVector->type()), partitionColIndicesVec, 65536,
pool.get()
+ asRowType(inputRowVector->type()),
+ partitionColIndicesVec,
+ 65536,
+ pool.get()
#ifdef GLUTEN_ENABLE_ENHANCED_FEATURES
- ,
+ ,
true
-#endif
- );
+#endif
+ );
raw_vector<uint64_t> partitionIds{};
idGen.run(inputRowVector, partitionIds);
GLUTEN_CHECK(partitionIds.size() == numRows, "Mismatched number of partition
ids");
@@ -914,18 +922,181 @@ JNIEXPORT jobject JNICALL
Java_org_apache_gluten_execution_IcebergWriteJniWrappe
auto writer = ObjectStore::retrieve<IcebergWriter>(writerHandle);
auto writeStats = writer->writeStats();
jobject writeMetrics = env->NewObject(
- batchWriteMetricsClass,
- batchWriteMetricsConstructor,
- writeStats.numWrittenBytes,
- writeStats.numWrittenFiles,
- writeStats.writeIOTimeNs,
- writeStats.writeWallNs);
+ batchWriteMetricsClass,
+ batchWriteMetricsConstructor,
+ writeStats.numWrittenBytes,
+ writeStats.numWrittenFiles,
+ writeStats.writeIOTimeNs,
+ writeStats.writeWallNs);
return writeMetrics;
JNI_METHOD_END(nullptr)
}
#endif
+JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_HashJoinBuilder_nativeBuild( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jstring tableId,
+ jlongArray batchHandles,
+ jstring joinKey,
+ jint joinType,
+ jboolean hasMixedJoinCondition,
+ jboolean isExistenceJoin,
+ jbyteArray namedStruct,
+ jboolean isNullAwareAntiJoin,
+ jlong bloomFilterPushdownSize,
+ jint broadcastHashTableBuildThreads) {
+ JNI_METHOD_START
+ const auto hashTableId = jStringToCString(env, tableId);
+ const auto hashJoinKey = jStringToCString(env, joinKey);
+ const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
+ std::string structString{
+ reinterpret_cast<const char*>(inputType.elems()),
static_cast<std::string::size_type>(inputType.length())};
+
+ substrait::NamedStruct substraitStruct;
+ substraitStruct.ParseFromString(structString);
+
+ std::vector<facebook::velox::TypePtr> veloxTypeList;
+ veloxTypeList = SubstraitParser::parseNamedStruct(substraitStruct);
+
+ const auto& substraitNames = substraitStruct.names();
+
+ std::vector<std::string> names;
+ names.reserve(substraitNames.size());
+ for (const auto& name : substraitNames) {
+ names.emplace_back(name);
+ }
+
+ std::vector<std::shared_ptr<ColumnarBatch>> cb;
+ int handleCount = env->GetArrayLength(batchHandles);
+ auto safeArray = getLongArrayElementsSafe(env, batchHandles);
+ for (int i = 0; i < handleCount; ++i) {
+ int64_t handle = safeArray.elems()[i];
+ cb.push_back(ObjectStore::retrieve<ColumnarBatch>(handle));
+ }
+
+ size_t maxThreads = broadcastHashTableBuildThreads > 0
+ ? std::min((size_t)broadcastHashTableBuildThreads, (size_t)32)
+ : std::min((size_t)std::thread::hardware_concurrency(), (size_t)32);
+
+ // Heuristic: Each thread should process at least a certain number of
batches to justify parallelism overhead.
+ // 32 batches is roughly 128k rows, which is a reasonable granularity for a
single thread.
+ constexpr size_t kMinBatchesPerThread = 32;
+ size_t numThreads = std::min(maxThreads, (handleCount + kMinBatchesPerThread
- 1) / kMinBatchesPerThread);
+ numThreads = std::max((size_t)1, numThreads);
+
+ if (numThreads <= 1) {
+ auto builder = nativeHashTableBuild(
+ hashJoinKey,
+ names,
+ veloxTypeList,
+ joinType,
+ hasMixedJoinCondition,
+ isExistenceJoin,
+ isNullAwareAntiJoin,
+ bloomFilterPushdownSize,
+ cb,
+ defaultLeafVeloxMemoryPool());
+
+ auto mainTable = builder->uniqueTable();
+ mainTable->prepareJoinTable(
+ {},
+ facebook::velox::exec::BaseHashTable::kNoSpillInputStartPartitionBit,
+ 1'000'000,
+ builder->dropDuplicates(),
+ nullptr);
+ builder->setHashTable(std::move(mainTable));
+
+ return gluten::hashTableObjStore->save(builder);
+ }
+
+ std::vector<std::thread> threads;
+
+ std::vector<std::shared_ptr<gluten::HashTableBuilder>>
hashTableBuilders(numThreads);
+ std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>>
otherTables(numThreads);
+
+ for (size_t t = 0; t < numThreads; ++t) {
+ size_t start = (handleCount * t) / numThreads;
+ size_t end = (handleCount * (t + 1)) / numThreads;
+
+ threads.emplace_back([&, t, start, end]() {
+ std::vector<std::shared_ptr<gluten::ColumnarBatch>> threadBatches;
+ for (size_t i = start; i < end; ++i) {
+ threadBatches.push_back(cb[i]);
+ }
+
+ auto builder = nativeHashTableBuild(
+ hashJoinKey,
+ names,
+ veloxTypeList,
+ joinType,
+ hasMixedJoinCondition,
+ isExistenceJoin,
+ isNullAwareAntiJoin,
+ bloomFilterPushdownSize,
+ threadBatches,
+ defaultLeafVeloxMemoryPool());
+
+ hashTableBuilders[t] = std::move(builder);
+ otherTables[t] = std::move(hashTableBuilders[t]->uniqueTable());
+ });
+ }
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ auto mainTable = std::move(otherTables[0]);
+ std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> tables;
+ for (int i = 1; i < numThreads; ++i) {
+ tables.push_back(std::move(otherTables[i]));
+ }
+
+ // TODO: Get accurate signal if parallel join build is going to be applied
+ // from hash table. Currently there is still a chance inside hash table that
+ // it might decide it is not going to trigger parallel join build.
+ const bool allowParallelJoinBuild = !tables.empty();
+
+ mainTable->prepareJoinTable(
+ std::move(tables),
+ facebook::velox::exec::BaseHashTable::kNoSpillInputStartPartitionBit,
+ 1'000'000,
+ hashTableBuilders[0]->dropDuplicates(),
+ allowParallelJoinBuild ? VeloxBackend::get()->executor() : nullptr);
+
+ for (int i = 1; i < numThreads; ++i) {
+ if (hashTableBuilders[i]->joinHasNullKeys()) {
+ hashTableBuilders[0]->setJoinHasNullKeys(true);
+ break;
+ }
+ }
+
+ hashTableBuilders[0]->setHashTable(std::move(mainTable));
+ return gluten::hashTableObjStore->save(hashTableBuilders[0]);
+ JNI_METHOD_END(kInvalidObjectHandle)
+}
+
+JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneHashTable( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jlong tableHandler) {
+ JNI_METHOD_START
+ auto hashTableHandler =
ObjectStore::retrieve<gluten::HashTableBuilder>(tableHandler);
+ return gluten::hashTableObjStore->save(hashTableHandler);
+ JNI_METHOD_END(kInvalidObjectHandle)
+}
+
+JNIEXPORT void JNICALL
Java_org_apache_gluten_vectorized_HashJoinBuilder_clearHashTable( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jlong tableHandler) {
+ JNI_METHOD_START
+ auto hashTableHandler =
ObjectStore::retrieve<gluten::HashTableBuilder>(tableHandler);
+ hashTableHandler->hashTable()->clear(true);
+ ObjectStore::release(tableHandler);
+ JNI_METHOD_END()
+}
#ifdef __cplusplus
}
#endif
diff --git a/cpp/velox/operators/hashjoin/HashTableBuilder.cc
b/cpp/velox/operators/hashjoin/HashTableBuilder.cc
new file mode 100644
index 0000000000..7c42cf5b49
--- /dev/null
+++ b/cpp/velox/operators/hashjoin/HashTableBuilder.cc
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operators/hashjoin/HashTableBuilder.h"
+#include "velox/exec/OperatorUtils.h"
+
+namespace gluten {
+namespace {
+
+// Derives the RowType used as the hash table layout: the join-key columns
+// come first (in join-key declaration order), followed by every remaining
+// (dependent) input column in its original order.
+facebook::velox::RowTypePtr hashJoinTableType(
+    const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
+    const facebook::velox::RowTypePtr& inputType) {
+  std::vector<std::string> names;
+  std::vector<facebook::velox::TypePtr> types;
+  names.reserve(inputType->size());
+  types.reserve(inputType->size());
+
+  std::unordered_set<uint32_t> keyChannels;
+  keyChannels.reserve(inputType->size());
+
+  // Key columns first.
+  for (const auto& key : joinKeys) {
+    const auto channel = facebook::velox::exec::exprToChannel(key.get(), inputType);
+    keyChannels.insert(channel);
+    names.emplace_back(inputType->nameOf(channel));
+    types.emplace_back(inputType->childAt(channel));
+  }
+
+  // Then the non-key columns, preserving their input order.
+  for (uint32_t i = 0; i < inputType->size(); ++i) {
+    if (keyChannels.count(i) == 0) {
+      names.emplace_back(inputType->nameOf(i));
+      types.emplace_back(inputType->childAt(i));
+    }
+  }
+
+  return ROW(std::move(names), std::move(types));
+}
+
+// True for a null-aware anti / left-semi join that also carries a filter;
+// such joins must retain build rows with null keys.
+bool isLeftNullAwareJoinWithFilter(
+    facebook::velox::core::JoinType joinType,
+    bool nullAware,
+    bool withFilter) {
+  return (isAntiJoin(joinType) || isLeftSemiProjectJoin(joinType) ||
+          isLeftSemiFilterJoin(joinType)) &&
+      nullAware && withFilter;
+}
+} // namespace
+
+// Constructs a builder for a broadcast build-side hash table: resolves the
+// key channels, prepares one DecodedVector per non-key (dependent) column,
+// derives the table layout (keys first) and instantiates the hash table via
+// setupTable().
+HashTableBuilder::HashTableBuilder(
+    facebook::velox::core::JoinType joinType,
+    bool nullAware,
+    bool withFilter,
+    int64_t bloomFilterPushdownSize,
+    const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>&
joinKeys,
+    const facebook::velox::RowTypePtr& inputType,
+    facebook::velox::memory::MemoryPool* pool)
+    : joinType_{joinType},
+      nullAware_{nullAware},
+      withFilter_(withFilter),
+      keyChannelMap_(joinKeys.size()),
+      inputType_(inputType),
+      bloomFilterPushdownSize_(bloomFilterPushdownSize),
+      pool_(pool) {
+  const auto numKeys = joinKeys.size();
+  keyChannels_.reserve(numKeys);
+
+  // Map each key expression to its input channel. If the same build column
+  // backs several keys, keyChannelMap_ keeps the last key index for it.
+  for (int i = 0; i < numKeys; ++i) {
+    auto& key = joinKeys[i];
+    auto channel = facebook::velox::exec::exprToChannel(key.get(), inputType_);
+    keyChannelMap_[channel] = i;
+    keyChannels_.emplace_back(channel);
+  }
+
+  // Identify the non-key build side columns and make a decoder for each.
+  const int32_t numDependents = inputType_->size() - numKeys;
+  if (numDependents > 0) {
+    // The same build-side column may serve several join keys, e.g.
+    // SELECT * FROM t LEFT JOIN u ON t.k1 = u.k AND t.k2 = u.k. Then numKeys
+    // exceeds the number of distinct channels and numDependents computed
+    // above is negative, so it must not be passed to 'reserve' — hence the
+    // guard.
+    dependentChannels_.reserve(numDependents);
+    decoders_.reserve(numDependents);
+  }
+  for (auto i = 0; i < inputType->size(); ++i) {
+    if (keyChannelMap_.find(i) == keyChannelMap_.end()) {
+      dependentChannels_.emplace_back(i);
+
decoders_.emplace_back(std::make_unique<facebook::velox::DecodedVector>());
+    }
+  }
+
+  tableType_ = hashJoinTableType(joinKeys, inputType);
+  setupTable();
+}
+
+// Invoked once from the constructor to instantiate the hash table to build.
+// Chooses between HashTable<false> (keep null keys) and HashTable<true>
+// (ignore null keys) depending on the join type, mirroring Velox's
+// HashBuild operator but running outside a Task — hence the hard-coded
+// 1'000 in place of queryConfig().minTableRowsForParallelJoinBuild().
+void HashTableBuilder::setupTable() {
+  VELOX_CHECK_NULL(uniqueTable_);
+
+  const auto numKeys = keyChannels_.size();
+  std::vector<std::unique_ptr<facebook::velox::exec::VectorHasher>> keyHashers;
+  keyHashers.reserve(numKeys);
+  // tableType_ lists key columns first, so childAt(i) is the i-th key type;
+  // the hasher still reads from the original input channel keyChannels_[i].
+  for (vector_size_t i = 0; i < numKeys; ++i) {
+
keyHashers.emplace_back(facebook::velox::exec::VectorHasher::create(tableType_->childAt(i),
keyChannels_[i]));
+  }
+
+  const auto numDependents = tableType_->size() - numKeys;
+  std::vector<facebook::velox::TypePtr> dependentTypes;
+  dependentTypes.reserve(numDependents);
+  for (int i = numKeys; i < tableType_->size(); ++i) {
+    dependentTypes.emplace_back(tableType_->childAt(i));
+  }
+  if (isRightJoin(joinType_) || isFullJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
+    // Do not ignore null keys.
+    uniqueTable_ = facebook::velox::exec::HashTable<false>::createForJoin(
+        std::move(keyHashers),
+        dependentTypes,
+        true, // allowDuplicates
+        true, // hasProbedFlag
+        1'000, //
operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild()
+        pool_,
+        true);
+  } else {
+    // (Left) semi and anti join with no extra filter only needs to know
whether
+    // there is a match. Hence, no need to store entries with duplicate keys.
+    dropDuplicates_ =
+        !withFilter_ && (isLeftSemiFilterJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_) || isAntiJoin(joinType_));
+    // Right semi join needs to tag build rows that were probed.
+    const bool needProbedFlag = isRightSemiFilterJoin(joinType_);
+    if (isLeftNullAwareJoinWithFilter(joinType_, nullAware_, withFilter_)) {
+      // We need to check null key rows in build side in case of null-aware
anti
+      // or left semi project join with filter set.
+      uniqueTable_ = facebook::velox::exec::HashTable<false>::createForJoin(
+          std::move(keyHashers),
+          dependentTypes,
+          !dropDuplicates_, // allowDuplicates
+          needProbedFlag, // hasProbedFlag
+          1'000, //
operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild()
+          pool_,
+          true);
+    } else {
+      // Ignore null keys
+      uniqueTable_ = facebook::velox::exec::HashTable<true>::createForJoin(
+          std::move(keyHashers),
+          dependentTypes,
+          !dropDuplicates_, // allowDuplicates
+          needProbedFlag, // hasProbedFlag
+          1'000, //
operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild()
+          pool_,
+          bloomFilterPushdownSize_);
+    }
+  }
+  // While the table is in array or normalized-key mode we keep feeding keys
+  // through the VectorHashers (see addInput); pure kHash mode needs no
+  // further analysis.
+  analyzeKeys_ = uniqueTable_->hashMode() !=
facebook::velox::exec::BaseHashTable::HashMode::kHash;
+}
+
+// Accumulates one input batch into the hash table's RowContainer.
+//
+// Decodes the key columns, drops (or keeps, for right/full/null-aware-filter
+// joins) rows with null keys while recording whether any null key was seen,
+// decodes the dependent columns, optionally keeps analyzing key values while
+// the table may still use array/normalized-key mode, and finally copies each
+// active row into the container.
+void HashTableBuilder::addInput(facebook::velox::RowVectorPtr input) {
+  activeRows_.resize(input->size());
+  activeRows_.setAll();
+
+  auto& hashers = uniqueTable_->hashers();
+
+  for (auto i = 0; i < hashers.size(); ++i) {
+    auto key = input->childAt(hashers[i]->channel())->loadedVector();
+    hashers[i]->decode(*key, activeRows_);
+  }
+
+  // NOTE: an earlier draft ran deselectRowsWithNulls() followed immediately
+  // by setAll() here, which had no net effect; the null handling below is
+  // the only one needed.
+  if (!isRightJoin(joinType_) && !isFullJoin(joinType_) &&
+      !isRightSemiProjectJoin(joinType_) &&
+      !isLeftNullAwareJoinWithFilter(joinType_, nullAware_, withFilter_)) {
+    // Rows with a null join key can never match: drop them from the build
+    // side, remembering that nulls existed for null-aware joins.
+    deselectRowsWithNulls(hashers, activeRows_);
+    if (nullAware_ && !joinHasNullKeys_ &&
+        activeRows_.countSelected() < input->size()) {
+      joinHasNullKeys_ = true;
+    }
+  } else if (nullAware_ && !joinHasNullKeys_) {
+    // Null-keyed rows are kept, but their existence must still be recorded.
+    for (auto& hasher : hashers) {
+      auto& decoded = hasher->decodedVector();
+      if (decoded.mayHaveNulls()) {
+        auto* nulls = decoded.nulls(&activeRows_);
+        if (nulls &&
+            facebook::velox::bits::countNulls(nulls, 0, activeRows_.end()) > 0) {
+          joinHasNullKeys_ = true;
+          break;
+        }
+      }
+    }
+  }
+
+  for (auto i = 0; i < dependentChannels_.size(); ++i) {
+    decoders_[i]->decode(
+        *input->childAt(dependentChannels_[i])->loadedVector(), activeRows_);
+  }
+
+  if (!activeRows_.hasSelections()) {
+    return;
+  }
+
+  if (analyzeKeys_ && hashes_.size() < activeRows_.end()) {
+    hashes_.resize(activeRows_.end());
+  }
+
+  // As long as analyzeKeys is true, we keep running the keys through
+  // the Vectorhashers so that we get a possible mapping of the keys
+  // to small ints for array or normalized key. When mayUseValueIds is
+  // false for the first time we stop. We do not retain the value ids
+  // since the final ones will only be known after all data is
+  // received.
+  for (auto& hasher : hashers) {
+    // TODO: Load only for active rows, except if right/full outer join.
+    if (analyzeKeys_) {
+      hasher->computeValueIds(activeRows_, hashes_);
+      analyzeKeys_ = hasher->mayUseValueIds();
+    }
+  }
+  auto rows = uniqueTable_->rows();
+  auto nextOffset = rows->nextOffset();
+
+  activeRows_.applyToSelected([&](auto rowIndex) {
+    char* newRow = rows->newRow();
+    if (nextOffset) {
+      *reinterpret_cast<char**>(newRow + nextOffset) = nullptr;
+    }
+    // Store the columns for each row in sequence. At probe time
+    // strings of the row will probably be in consecutive places, so
+    // reading one will prime the cache for the next.
+    for (auto i = 0; i < hashers.size(); ++i) {
+      rows->store(hashers[i]->decodedVector(), rowIndex, newRow, i);
+    }
+    for (auto i = 0; i < dependentChannels_.size(); ++i) {
+      rows->store(*decoders_[i], rowIndex, newRow, i + hashers.size());
+    }
+  });
+}
+
+} // namespace gluten
diff --git a/cpp/velox/operators/hashjoin/HashTableBuilder.h
b/cpp/velox/operators/hashjoin/HashTableBuilder.h
new file mode 100644
index 0000000000..83c90b4110
--- /dev/null
+++ b/cpp/velox/operators/hashjoin/HashTableBuilder.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <thread>
+#include "velox/exec/HashJoinBridge.h"
+#include "velox/exec/HashTable.h"
+#include "velox/exec/RowContainer.h"
+#include "velox/exec/VectorHasher.h"
+
+namespace gluten {
+using column_index_t = uint32_t;
+using vector_size_t = int32_t;
+
+// Builds a Velox BaseHashTable for a broadcast hash join outside of a Velox
+// Task, so the table can be constructed once per executor and shared across
+// join operators.
+//
+// Usage: feed batches via addInput(); take ownership of the per-thread table
+// with uniqueTable(); after merging, install the final table with
+// setHashTable() so hashTable() can hand out shared references.
+class HashTableBuilder {
+ public:
+  HashTableBuilder(
+      facebook::velox::core::JoinType joinType,
+      bool nullAware,
+      bool withFilter,
+      int64_t bloomFilterPushdownSize,
+      const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
+      const facebook::velox::RowTypePtr& inputType,
+      facebook::velox::memory::MemoryPool* pool);
+
+  // Accumulates one batch of build-side rows into the table's RowContainer.
+  void addInput(facebook::velox::RowVectorPtr input);
+
+  // Installs the fully-built (possibly merged) table for shared use.
+  void setHashTable(std::unique_ptr<facebook::velox::exec::BaseHashTable> uniqueHashTable) {
+    table_ = std::move(uniqueHashTable);
+  }
+
+  // Transfers ownership of the table built so far; leaves uniqueTable_ null.
+  std::unique_ptr<facebook::velox::exec::BaseHashTable> uniqueTable() {
+    return std::move(uniqueTable_);
+  }
+
+  // Shared view of the installed table; null until setHashTable() was called.
+  std::shared_ptr<facebook::velox::exec::BaseHashTable> hashTable() {
+    return table_;
+  }
+
+  void setJoinHasNullKeys(bool joinHasNullKeys) {
+    joinHasNullKeys_ = joinHasNullKeys;
+  }
+
+  // True if at least one build row had a null join key (relevant for
+  // null-aware anti / left semi project joins).
+  bool joinHasNullKeys() {
+    return joinHasNullKeys_;
+  }
+
+  // True when entries with duplicate keys are not stored (semi/anti joins
+  // without an extra filter).
+  bool dropDuplicates() {
+    return dropDuplicates_;
+  }
+
+ private:
+  // Invoked once from the constructor to set up the hash table to build.
+  void setupTable();
+
+  const facebook::velox::core::JoinType joinType_;
+
+  const bool nullAware_;
+  const bool withFilter_;
+
+  // The row type used for hash table build and disk spilling.
+  facebook::velox::RowTypePtr tableType_;
+
+  // Installed (shared) table; set by setHashTable().
+  std::shared_ptr<facebook::velox::exec::BaseHashTable> table_;
+
+  // Table under construction; moved out by uniqueTable().
+  std::unique_ptr<facebook::velox::exec::BaseHashTable> uniqueTable_;
+
+  // Key channels in 'input_'
+  std::vector<column_index_t> keyChannels_;
+
+  // Non-key channels in 'input_'.
+  std::vector<column_index_t> dependentChannels_;
+
+  // Corresponds 1:1 to 'dependentChannels_'.
+  std::vector<std::unique_ptr<facebook::velox::DecodedVector>> decoders_;
+
+  // True if we are considering use of normalized keys or array hash tables.
+  // Set to false when the dataset is no longer suitable.
+  bool analyzeKeys_;
+
+  // Temporary space for hash numbers.
+  facebook::velox::raw_vector<uint64_t> hashes_;
+
+  // Set of active rows during addInput().
+  facebook::velox::SelectivityVector activeRows_;
+
+  // True if this is a build side of an anti or left semi project join and has
+  // at least one entry with null join keys.
+  bool joinHasNullKeys_{false};
+
+  // Indices of key columns used by the filter in build side table.
+  // NOTE(review): not referenced in this class yet; presumably reserved for
+  // filter pushdown — confirm before removing.
+  std::vector<column_index_t> keyFilterChannels_;
+  // Indices of dependent columns used by the filter in 'decoders_'.
+  std::vector<column_index_t> dependentFilterChannels_;
+
+  // Maps key channel in 'input_' to channel in key.
+  folly::F14FastMap<column_index_t, column_index_t> keyChannelMap_;
+
+  // Held by value (a cheap shared_ptr copy) rather than by const reference:
+  // a reference member bound to the constructor's parameter would dangle if
+  // the caller's RowTypePtr goes out of scope before this builder does.
+  const facebook::velox::RowTypePtr inputType_;
+
+  const int64_t bloomFilterPushdownSize_;
+
+  facebook::velox::memory::MemoryPool* const pool_;
+
+  bool dropDuplicates_{false};
+};
+
+} // namespace gluten
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index d71ab12528..834127e20c 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -19,12 +19,15 @@
#include "TypeUtils.h"
#include "VariantToVectorConverter.h"
+#include "jni/JniHashTable.h"
+#include "operators/hashjoin/HashTableBuilder.h"
#include "operators/plannodes/RowVectorStream.h"
#include "velox/connectors/hive/HiveDataSink.h"
#include "velox/exec/TableWriter.h"
#include "velox/type/Type.h"
#include "utils/ConfigExtractor.h"
+#include "utils/ObjectStore.h"
#include "utils/VeloxWriterUtils.h"
#include "config.pb.h"
@@ -393,6 +396,43 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
rightNode,
getJoinOutputType(leftNode, rightNode, joinType));
+ } else if (
+ sJoin.has_advanced_extension() &&
+ SubstraitParser::configSetInOptimization(sJoin.advanced_extension(),
"isBHJ=")) {
+ std::string hashTableId = sJoin.hashtableid();
+
+ std::shared_ptr<core::OpaqueHashTable> opaqueSharedHashTable = nullptr;
+ bool joinHasNullKeys = false;
+
+ try {
+ auto hashTableBuilder =
ObjectStore::retrieve<gluten::HashTableBuilder>(getJoin(hashTableId));
+ joinHasNullKeys = hashTableBuilder->joinHasNullKeys();
+ auto originalShared = hashTableBuilder->hashTable();
+ opaqueSharedHashTable = std::shared_ptr<core::OpaqueHashTable>(
+ originalShared,
reinterpret_cast<core::OpaqueHashTable*>(originalShared.get()));
+
+ LOG(INFO) << "Successfully retrieved and aliased HashTable for reuse.
ID: " << hashTableId;
+ } catch (const std::exception& e) {
+ LOG(WARNING)
+ << "Error retrieving HashTable from ObjectStore: " << e.what()
+ << ". Falling back to building new table. To ensure correct results,
please verify that spark.gluten.velox.buildHashTableOncePerExecutor.enabled is
set to false.";
+ opaqueSharedHashTable = nullptr;
+ }
+
+ // Create HashJoinNode node
+ return std::make_shared<core::HashJoinNode>(
+ nextPlanNodeId(),
+ joinType,
+ isNullAwareAntiJoin,
+ leftKeys,
+ rightKeys,
+ filter,
+ leftNode,
+ rightNode,
+ getJoinOutputType(leftNode, rightNode, joinType),
+ false,
+ joinHasNullKeys,
+ opaqueSharedHashTable);
} else {
// Create HashJoinNode node
return std::make_shared<core::HashJoinNode>(
diff --git a/docs/Configuration.md b/docs/Configuration.md
index 1372d98243..066d664436 100644
--- a/docs/Configuration.md
+++ b/docs/Configuration.md
@@ -79,6 +79,7 @@ nav_order: 15
| spark.gluten.sql.columnar.partial.generate | true
| Evaluates the non-offload-able HiveUDTF using vanilla Spark
generator
|
| spark.gluten.sql.columnar.partial.project | true
| Break up one project node into 2 phases when some of the
expressions are non offload-able. Phase one is a regular offloaded project
transformer that evaluates the offload-able expressions in native, phase two
preserves the output from phase one and evaluates the remaining
non-offload-able expressions using vanilla Spark projections
|
| spark.gluten.sql.columnar.physicalJoinOptimizationLevel | 12
| Fallback to row operators if there are several continuous joins.
|
+| spark.gluten.sql.columnar.physicalJoinOptimizationOutputSize | 52
| Fall back to row-based operators when there are several consecutive joins
and the plan's output size matches this value.
|
| spark.gluten.sql.columnar.physicalJoinOptimizeEnable | false
| Enable or disable columnar physicalJoinOptimize.
|
| spark.gluten.sql.columnar.preferStreamingAggregate | true
| Velox backend supports `StreamingAggregate`. `StreamingAggregate`
uses the less memory as it does not need to hold all groups in memory, so it
could avoid spill. When true and the child output ordering satisfies the
grouping key then Gluten will choose `StreamingAggregate` as the native
operator. |
| spark.gluten.sql.columnar.project | true
| Enable or disable columnar project.
|
diff --git a/docs/velox-configuration.md b/docs/velox-configuration.md
index f4a79c4652..1a4a1fb7e6 100644
--- a/docs/velox-configuration.md
+++ b/docs/velox-configuration.md
@@ -19,6 +19,7 @@ nav_order: 16
| spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems
| 1000000 | The default number of expected items for the velox
bloomfilter: 'spark.bloom_filter.expected_num_items'
[...]
| spark.gluten.sql.columnar.backend.velox.bloomFilter.maxNumBits
| 4194304 | The max number of bits to use for the velox bloom
filter: 'spark.bloom_filter.max_num_bits'
[...]
| spark.gluten.sql.columnar.backend.velox.bloomFilter.numBits
| 8388608 | The default number of bits to use for the velox bloom
filter: 'spark.bloom_filter.num_bits'
[...]
+| spark.gluten.sql.columnar.backend.velox.broadcastHashTableBuildThreads
| 1 | The number of threads used to build the broadcast
hash table. If not set or set to 0, it will use the default number of threads
(available processors).
[...]
| spark.gluten.sql.columnar.backend.velox.cacheEnabled
| false | Enable Velox cache, default off. It's recommended to
enablesoft-affinity as well when enable velox cache.
[...]
| spark.gluten.sql.columnar.backend.velox.cachePrefetchMinPct
| 0 | Set prefetch cache min pct for velox file scan
[...]
| spark.gluten.sql.columnar.backend.velox.checkUsageLeak
| true | Enable check memory usage leak.
[...]
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenJoinKeysCapture.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenJoinKeysCapture.scala
new file mode 100644
index 0000000000..5d1cb8d90a
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenJoinKeysCapture.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys,
ExtractSingleColumnNullAwareAntiJoin}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+
+/**
+ * Strategy to capture join keys from logical plan before Spark's
JoinSelection transforms them.
+ * This strategy runs early in the planning phase to preserve the original
join keys before any
+ * transformations like rewriteKeyExpr.
+ */
+case class GlutenJoinKeysCapture() extends SparkStrategy {
+
+  /**
+   * Tags the children of equi joins with their original join keys and always
+   * returns Nil, so planning falls through to the regular strategies.
+   */
+  def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+    plan match {
+      case join: Join =>
+        join match {
+          case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, left, right, _) =>
+            if (leftKeys.nonEmpty) {
+              left.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, leftKeys)
+            }
+            if (rightKeys.nonEmpty) {
+              right.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, rightKeys)
+            }
+          case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys) =>
+            if (leftKeys.nonEmpty) {
+              j.left.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, leftKeys)
+            }
+            if (rightKeys.nonEmpty) {
+              j.right.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, rightKeys)
+            }
+          case _ => // Not an extractable equi join: nothing to capture.
+        }
+      case _ => // Only Join nodes carry join keys worth capturing.
+    }
+    // This strategy never produces a physical plan; it only tags children.
+    Nil
+  }
+}
diff --git
a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
b/gluten-core/src/main/scala/org/apache/gluten/extension/JoinKeysTag.scala
similarity index 61%
copy from
backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
copy to gluten-core/src/main/scala/org/apache/gluten/extension/JoinKeysTag.scala
index 2759613793..646b0df7d0 100644
---
a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/JoinKeysTag.scala
@@ -14,24 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.test;
+package org.apache.gluten.extension
-import org.apache.gluten.backendsapi.ListenerApi;
-import org.apache.gluten.backendsapi.velox.VeloxListenerApi;
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
+/** TreeNodeTag for storing original join keys before Spark's transformations.
*/
+object JoinKeysTag {
-public abstract class VeloxBackendTestBase {
- private static final ListenerApi API = new VeloxListenerApi();
-
- @BeforeClass
- public static void setup() {
- API.onExecutorStart(MockVeloxBackend.mockPluginContext());
- }
-
- @AfterClass
- public static void tearDown() {
- API.onExecutorShutdown();
- }
+ /** Tag to store original join keys on logical plan nodes. */
+ val ORIGINAL_JOIN_KEYS: TreeNodeTag[Seq[Expression]] =
+ TreeNodeTag[Seq[Expression]]("gluten.originalJoinKeys")
}
diff --git
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
index 714340cdf6..2bd98500fe 100644
---
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
+++
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
@@ -32,6 +32,7 @@ public class JoinRelNode implements RelNode, Serializable {
private final ExpressionNode expression;
private final ExpressionNode postJoinFilter;
private final AdvancedExtensionNode extensionNode;
+ private final String hashTableId;
JoinRelNode(
RelNode left,
@@ -39,12 +40,14 @@ public class JoinRelNode implements RelNode, Serializable {
JoinRel.JoinType joinType,
ExpressionNode expression,
ExpressionNode postJoinFilter,
+ String hashTableId,
AdvancedExtensionNode extensionNode) {
this.left = left;
this.right = right;
this.joinType = joinType;
this.expression = expression;
this.postJoinFilter = postJoinFilter;
+ this.hashTableId = hashTableId;
this.extensionNode = extensionNode;
}
@@ -72,6 +75,8 @@ public class JoinRelNode implements RelNode, Serializable {
joinBuilder.setAdvancedExtension(extensionNode.toProtobuf());
}
+ joinBuilder.setHashTableId(hashTableId);
+
return Rel.newBuilder().setJoin(joinBuilder.build()).build();
}
}
diff --git
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
index 20ca9d36f1..4072394624 100644
---
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
+++
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
@@ -184,11 +184,12 @@ public class RelBuilder {
JoinRel.JoinType joinType,
ExpressionNode expression,
ExpressionNode postJoinFilter,
+ String hashTableId,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return makeJoinRel(
- left, right, joinType, expression, postJoinFilter, null, context,
operatorId);
+ left, right, joinType, expression, postJoinFilter, null, hashTableId,
context, operatorId);
}
public static RelNode makeJoinRel(
@@ -198,10 +199,12 @@ public class RelBuilder {
ExpressionNode expression,
ExpressionNode postJoinFilter,
AdvancedExtensionNode extensionNode,
+ String hashTableId,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
- return new JoinRelNode(left, right, joinType, expression, postJoinFilter,
extensionNode);
+ return new JoinRelNode(
+ left, right, joinType, expression, postJoinFilter, hashTableId,
extensionNode);
}
public static RelNode makeCrossRel(
diff --git
a/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
b/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
index 7d72332baa..2bfb68e097 100644
---
a/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
+++
b/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
@@ -258,6 +258,8 @@ message JoinRel {
JoinType type = 6;
+ string hashTableId = 7;
+
enum JoinType {
JOIN_TYPE_UNSPECIFIED = 0;
JOIN_TYPE_INNER = 1;
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index 671a29709e..8dd3156099 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -85,6 +85,8 @@ trait BackendSettingsApi {
def enableJoinKeysRewrite(): Boolean = true
+ def enableHashTableBuildOncePerExecutor(): Boolean = true
+
def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
case _: InnerLike | RightOuter | FullOuter => true
case _ => false
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index ed2d549366..8d71e15964 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -202,6 +202,9 @@ class GlutenConfig(conf: SQLConf) extends
GlutenCoreConfig(conf) {
def physicalJoinOptimizationThrottle: Integer =
getConf(COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_THROTTLE)
+ def physicalJoinOptimizationOutputSize: Integer =
+ getConf(COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_OUTPUT_SIZE)
+
def enablePhysicalJoinOptimize: Boolean =
getConf(COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_ENABLED)
@@ -998,6 +1001,13 @@ object GlutenConfig extends ConfigRegistry {
.intConf
.createWithDefault(12)
+ val COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_OUTPUT_SIZE =
+ buildConf("spark.gluten.sql.columnar.physicalJoinOptimizationOutputSize")
+ .doc(
+ "Fallback to row operators if there are several continuous joins and
matched output size.")
+ .intConf
+ .createWithDefault(52)
+
val COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_ENABLED =
buildConf("spark.gluten.sql.columnar.physicalJoinOptimizeEnable")
.doc("Enable or disable columnar physicalJoinOptimize.")
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
index e5db338515..b4fa188f44 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
@@ -186,9 +186,14 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec
with TransformSupport {
// https://issues.apache.org/jira/browse/SPARK-31869
private def expandPartitioning(partitioning: Partitioning): Partitioning = {
val expandLimit = conf.broadcastHashJoinOutputPartitioningExpandLimit
+ val (buildKeys, streamedKeys) = if (needSwitchChildren) {
+ (leftKeys, rightKeys)
+ } else {
+ (rightKeys, leftKeys)
+ }
joinType match {
case _: InnerLike if expandLimit > 0 =>
- new ExpandOutputPartitioningShim(streamedKeyExprs, buildKeyExprs,
expandLimit)
+ new ExpandOutputPartitioningShim(streamedKeys, buildKeys, expandLimit)
.expandPartitioning(partitioning)
case _ => partitioning
}
@@ -262,7 +267,8 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with
TransformSupport {
inputStreamedOutput,
inputBuildOutput,
context,
- operatorId
+ operatorId,
+ buildPlan.id.toString
)
context.registerJoinParam(operatorId, joinParams)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index a7a31cf471..eeb6069890 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -184,6 +184,7 @@ object JoinUtils {
inputBuildOutput: Seq[Attribute],
substraitContext: SubstraitContext,
operatorId: java.lang.Long,
+ hashTableId: String = "",
validation: Boolean = false): RelNode = {
// scalastyle:on argcount
// Create pre-projection for build/streamed plan. Append projected keys to
each side.
@@ -233,6 +234,7 @@ object JoinUtils {
joinExpressionNode,
postJoinFilter.orNull,
createJoinExtensionNode(joinParameters, streamedOutput ++ buildOutput),
+ hashTableId,
substraitContext,
operatorId
)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index 926708ee33..5e6c777922 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -38,17 +38,32 @@ case class FallbackMultiCodegens(session: SparkSession)
extends Rule[SparkPlan]
lazy val glutenConf: GlutenConfig = GlutenConfig.get
lazy val physicalJoinOptimize = glutenConf.enablePhysicalJoinOptimize
lazy val optimizeLevel: Integer = glutenConf.physicalJoinOptimizationThrottle
+ lazy val outputSize: Integer = glutenConf.physicalJoinOptimizationOutputSize
def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean =
plan match {
case plan: CodegenSupport if plan.supportCodegen =>
- if ((count + 1) >= optimizeLevel) return true
+ if (
+ (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum == outputSize
+ ) {
+ return true
+ }
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: ShuffledHashJoinExec =>
- if ((count + 1) >= optimizeLevel) return true
+ if (
+ (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum == outputSize
+ ) {
+ return true
+ }
+
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin =>
- if ((count + 1) >= optimizeLevel) return true
+ if (
+ (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum == outputSize
+ ) {
+ return true
+ }
+
plan.children.exists(existsMultiCodegens(_, count + 1))
case _ => false
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
index 1de490ad61..371f9948b7 100644
---
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
+++
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
@@ -131,9 +131,7 @@ case class ColumnarBroadcastExchangeExec(mode:
BroadcastMode, child: SparkPlan)
override def rowType0(): Convention.RowType = Convention.RowType.None
override def doCanonicalize(): SparkPlan = {
- val canonicalized =
-
BackendsApiManager.getSparkPlanExecApiInstance.doCanonicalizeForBroadcastMode(mode)
- ColumnarBroadcastExchangeExec(canonicalized, child.canonicalized)
+ ColumnarBroadcastExchangeExec(mode.canonicalized, child.canonicalized)
}
override def doPrepare(): Unit = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]