This is an automated email from the ASF dual-hosted git repository.
felixybw pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 91c52e15f1 [GLUTEN-9671][VL] Fix broadcast exchange stackoverflow due
to Kryo serialization (#10541)
91c52e15f1 is described below
commit 91c52e15f16593747e918145258ebe1408cb8ea2
Author: Felix Loesing <[email protected]>
AuthorDate: Thu Aug 28 18:47:04 2025 -0700
[GLUTEN-9671][VL] Fix broadcast exchange stackoverflow due to Kryo
serialization (#10541)
This pull request introduces a safer and more robust approach for handling
Spark's BroadcastMode during serialization. The main improvement is the
introduction of a new SafeBroadcastMode abstraction and related utilities,
which help avoid serialization issues that caused a StackOverflowError
during broadcast exchanges. The BroadcastMode handling that caused the
observed issue was introduced in an earlier PR. HashedRelationBroadcastMode
embeds Catalyst expression trees, which are not safe to Kryo [...]
With this change, the broadcast payload now contains only primitives and
byte arrays (no Catalyst trees). For bound keys, we serialize just column
ordinals (+ null-aware flag) and for computed keys (e.g., upper(col)), we
serialize the key expressions once as Java bytes and deserialize only where
needed to build projections.
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 2 +-
.../spark/sql/execution/BroadcastModeUtils.scala | 134 +++++++++++++++++++++
.../spark/sql/execution/BroadcastUtils.scala | 4 +-
.../sql/execution/ColumnarBuildSideRelation.scala | 57 +++++++--
.../unsafe/UnsafeColumnarBuildSideRelation.scala | 91 +++++++++++---
5 files changed, 261 insertions(+), 27 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index c3ac63f767..af46ff2673 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -678,7 +678,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
dataSize += rawSize
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
- new UnsafeColumnarBuildSideRelation(child.output,
serialized.flatMap(_.getSerialized), mode)
+ UnsafeColumnarBuildSideRelation(child.output,
serialized.flatMap(_.getSerialized), mode)
}
} else {
ColumnarBuildSideRelation(child.output,
serialized.flatMap(_.getSerialized), mode)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala
new file mode 100644
index 0000000000..d0ae9a6832
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference,
Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode,
IdentityBroadcastMode}
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException,
ObjectInputStream, ObjectOutputStream}
+
+/**
+ * Provides serialization-safe representations of BroadcastMode to avoid
issues with circular
+ * references in complex expression trees during Kryo serialization.
+ */
+sealed trait SafeBroadcastMode extends Serializable
+
+/** Safe representation of IdentityBroadcastMode */
+case object IdentitySafeBroadcastMode extends SafeBroadcastMode
+
+/**
+ * Safe wrapper for HashedRelationBroadcastMode. Stores only column ordinals
instead of full
+ * BoundReference expressions.
+ */
+final case class HashSafeBroadcastMode(ordinals: Array[Int], isNullAware:
Boolean)
+ extends SafeBroadcastMode
+
+/**
+ * Safe wrapper for HashedRelationBroadcastMode when keys are not simple
BoundReferences. Stores key
+ * expressions as serialized Java bytes.
+ */
+final case class HashExprSafeBroadcastMode(exprBytes: Array[Byte],
isNullAware: Boolean)
+ extends SafeBroadcastMode
+
+object BroadcastModeUtils extends Logging {
+
+ /**
+ * Converts a BroadcastMode to its SafeBroadcastMode equivalent. Uses
ordinals for simple
+ * BoundReferences, otherwise serializes the expressions.
+ */
+ private[execution] def toSafe(mode: BroadcastMode): SafeBroadcastMode = mode
match {
+ case IdentityBroadcastMode =>
+ IdentitySafeBroadcastMode
+ case HashedRelationBroadcastMode(keys, isNullAware) =>
+ // Fast path: all keys are already BoundReference(i, ..,..).
+ val ords = keys.collect { case BoundReference(ord, _, _) => ord }
+ if (ords.size == keys.size) {
+ HashSafeBroadcastMode(ords.toArray, isNullAware)
+ } else {
+ // Fallback: store the key expressions as Java-serialized bytes.
+ HashExprSafeBroadcastMode(serializeExpressions(keys), isNullAware)
+ }
+
+ case other =>
+ throw new IllegalArgumentException(s"Unsupported BroadcastMode: $other")
+ }
+
+ /** Converts a SafeBroadcastMode to its BroadcastMode equivalent. */
+ private[execution] def fromSafe(safe: SafeBroadcastMode, output:
Seq[Attribute]): BroadcastMode =
+ safe match {
+ case IdentitySafeBroadcastMode =>
+ IdentityBroadcastMode
+
+ case HashSafeBroadcastMode(ords, isNullAware) =>
+ val bound = ords.map(i => BoundReference(i, output(i).dataType,
output(i).nullable)).toSeq
+ HashedRelationBroadcastMode(bound, isNullAware)
+
+ case HashExprSafeBroadcastMode(bytes, isNullAware) =>
+ HashedRelationBroadcastMode(deserializeExpressions(bytes), isNullAware)
+ }
+
+ // Helpers for expression serialization (used in HashExprSafeBroadcastMode)
+ private[execution] def serializeExpressions(keys: Seq[Expression]):
Array[Byte] = {
+ val bos = new ByteArrayOutputStream()
+ var oos: ObjectOutputStream = null
+ try {
+ oos = new ObjectOutputStream(bos)
+ oos.writeObject(keys)
+ oos.flush()
+ bos.toByteArray
+ } catch {
+ case e @ (_: IOException | _: ClassNotFoundException | _:
ClassCastException) =>
+ logError(
+ s"Failed to serialize expressions for BroadcastMode. Expression
count: ${keys.length}",
+ e)
+ throw new RuntimeException("Failed to serialize expressions for
BroadcastMode", e)
+ case e: Exception =>
+ logError(
+ s"Unexpected error during expression serialization. Expression
count: ${keys.length}",
+ e)
+ throw e
+ } finally {
+ if (oos != null) oos.close()
+ bos.close()
+ }
+ }
+
+ private[execution] def deserializeExpressions(bytes: Array[Byte]):
Seq[Expression] = {
+ val bis = new ByteArrayInputStream(bytes)
+ var ois: ObjectInputStream = null
+ try {
+ ois = new ObjectInputStream(bis)
+ ois.readObject().asInstanceOf[Seq[Expression]]
+ } catch {
+ case e @ (_: IOException | _: ClassNotFoundException | _:
ClassCastException) =>
+ logError(
+ s"Failed to deserialize expressions for BroadcastMode. Data size:
${bytes.length} bytes",
+ e)
+ throw new RuntimeException("Failed to deserialize expressions for
BroadcastMode", e)
+ case e: Exception =>
+ logError(
+ s"Unexpected error during expression deserialization. Data size:
${bytes.length} bytes",
+ e)
+ throw e
+ } finally {
+ if (ois != null) ois.close()
+ bis.close()
+ }
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 155227db41..342c9694f0 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -108,7 +108,7 @@ object BroadcastUtils {
result.getSerialized
}
if (useOffheapBuildRelation) {
- new UnsafeColumnarBuildSideRelation(
+ UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
@@ -134,7 +134,7 @@ object BroadcastUtils {
result.getSerialized
}
if (useOffheapBuildRelation) {
- new UnsafeColumnarBuildSideRelation(
+ UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index cd49ed30ea..59a9cb2b00 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -26,11 +26,9 @@ import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper,
NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq,
BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
-import org.apache.spark.sql.execution.joins.BuildSideRelation
-import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.execution.joins.{BuildSideRelation,
HashedRelationBroadcastMode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -40,17 +38,56 @@ import org.apache.arrow.c.ArrowSchema
import scala.collection.JavaConverters.asScalaIteratorConverter
+object ColumnarBuildSideRelation {
+ // Keep constructor with BroadcastMode for compatibility
+ def apply(
+ output: Seq[Attribute],
+ batches: Array[Array[Byte]],
+ mode: BroadcastMode): ColumnarBuildSideRelation = {
+ val boundMode = mode match {
+ case HashedRelationBroadcastMode(keys, isNullAware) =>
+ // Bind each key to the build-side output so simple cols become
BoundReference
+ val boundKeys: Seq[Expression] =
+ keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
+ HashedRelationBroadcastMode(boundKeys, isNullAware)
+ case m =>
+ m // IdentityBroadcastMode, etc.
+ }
+ new ColumnarBuildSideRelation(output, batches,
BroadcastModeUtils.toSafe(boundMode))
+ }
+}
+
case class ColumnarBuildSideRelation(
output: Seq[Attribute],
batches: Array[Array[Byte]],
- mode: BroadcastMode)
+ safeBroadcastMode: SafeBroadcastMode)
extends BuildSideRelation {
- private def transformProjection: UnsafeProjection = {
- mode match {
- case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
- case IdentityBroadcastMode => UnsafeProjection.create(output, output)
- }
+ // Rebuild the real BroadcastMode on demand; never serialize it.
+ @transient override lazy val mode: BroadcastMode =
+ BroadcastModeUtils.fromSafe(safeBroadcastMode, output)
+
+ // If we stored expression bytes, deserialize once and cache locally (not
serialized).
+ @transient private lazy val exprKeysFromBytes: Option[Seq[Expression]] =
safeBroadcastMode match {
+ case HashExprSafeBroadcastMode(bytes, _) =>
+ Some(BroadcastModeUtils.deserializeExpressions(bytes))
+ case _ => None
+ }
+
+ private def transformProjection: UnsafeProjection = safeBroadcastMode match {
+ case IdentitySafeBroadcastMode =>
+ UnsafeProjection.create(output, output)
+ case HashSafeBroadcastMode(ords, _) =>
+ val bound = ords.map(i => BoundReference(i, output(i).dataType,
output(i).nullable))
+ UnsafeProjection.create(bound)
+ case HashExprSafeBroadcastMode(_, _) =>
+ exprKeysFromBytes match {
+ case Some(keys) => UnsafeProjection.create(keys)
+ case None =>
+ throw new IllegalStateException(
+ "Failed to deserialize expressions for HashExprSafeBroadcastMode"
+ )
+ }
}
override def deserialized: Iterator[ColumnarBatch] = {
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index c0ef884e73..80e92a1537 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -28,8 +28,9 @@ import
org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeCo
import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode,
IdentityBroadcastMode}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq,
BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.execution.{BroadcastModeUtils,
HashExprSafeBroadcastMode, HashSafeBroadcastMode, IdentitySafeBroadcastMode,
SafeBroadcastMode}
import org.apache.spark.sql.execution.joins.{BuildSideRelation,
HashedRelationBroadcastMode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.SparkArrowUtil
@@ -45,6 +46,44 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import scala.collection.JavaConverters.asScalaIteratorConverter
+object UnsafeColumnarBuildSideRelation {
+ // Keep constructors with BroadcastMode for compatibility
+ def apply(
+ output: Seq[Attribute],
+ batches: UnsafeBytesBufferArray,
+ mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
+ val boundMode = mode match {
+ case HashedRelationBroadcastMode(keys, isNullAware) =>
+ // Bind each key to the build-side output so simple cols become
BoundReference
+ val boundKeys: Seq[Expression] =
+ keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
+ HashedRelationBroadcastMode(boundKeys, isNullAware)
+ case m =>
+ m // IdentityBroadcastMode, etc.
+ }
+ new UnsafeColumnarBuildSideRelation(output, batches,
BroadcastModeUtils.toSafe(boundMode))
+ }
+ def apply(
+ output: Seq[Attribute],
+ bytesBufferArray: Array[Array[Byte]],
+ mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
+ val boundMode = mode match {
+ case HashedRelationBroadcastMode(keys, isNullAware) =>
+ // Bind each key to the build-side output so simple cols become
BoundReference
+ val boundKeys: Seq[Expression] =
+ keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
+ HashedRelationBroadcastMode(boundKeys, isNullAware)
+ case m =>
+ m // IdentityBroadcastMode, etc.
+ }
+ new UnsafeColumnarBuildSideRelation(
+ output,
+ bytesBufferArray,
+ BroadcastModeUtils.toSafe(boundMode)
+ )
+ }
+}
+
/**
* A broadcast relation that is built using off-heap memory. It will avoid the
on-heap memory OOM.
*
@@ -59,18 +98,33 @@ import
scala.collection.JavaConverters.asScalaIteratorConverter
case class UnsafeColumnarBuildSideRelation(
private var output: Seq[Attribute],
private var batches: UnsafeBytesBufferArray,
- var mode: BroadcastMode)
+ var safeBroadcastMode: SafeBroadcastMode)
extends BuildSideRelation
with Externalizable
with Logging
with KryoSerializable {
+ // Rebuild the real BroadcastMode on demand; never serialize it.
+ @transient override lazy val mode: BroadcastMode =
+ BroadcastModeUtils.fromSafe(safeBroadcastMode, output)
+
+ // If we stored expression bytes, deserialize once and cache locally (not
serialized).
+ @transient private lazy val exprKeysFromBytes: Option[Seq[Expression]] =
safeBroadcastMode match {
+ case HashExprSafeBroadcastMode(bytes, _) =>
+ Some(BroadcastModeUtils.deserializeExpressions(bytes))
+ case _ => None
+ }
+
/** needed for serialization. */
def this() = {
this(null, null.asInstanceOf[UnsafeBytesBufferArray], null)
}
- def this(output: Seq[Attribute], bytesBufferArray: Array[Array[Byte]], mode:
BroadcastMode) = {
+ def this(
+ output: Seq[Attribute],
+ bytesBufferArray: Array[Array[Byte]],
+ safeMode: SafeBroadcastMode
+ ) = {
this(
output,
UnsafeBytesBufferArray(
@@ -78,7 +132,7 @@ case class UnsafeColumnarBuildSideRelation(
bytesBufferArray.map(_.length),
bytesBufferArray.map(_.length.toLong).sum
),
- mode
+ safeMode
)
val batchesSize = bytesBufferArray.length
for (i <- 0 until batchesSize) {
@@ -89,7 +143,7 @@ case class UnsafeColumnarBuildSideRelation(
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException
{
out.writeObject(output)
- out.writeObject(mode)
+ out.writeObject(safeBroadcastMode)
out.writeInt(batches.arraySize)
out.writeObject(batches.bytesBufferLengths)
out.writeLong(batches.totalBytes)
@@ -101,7 +155,7 @@ case class UnsafeColumnarBuildSideRelation(
override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
kryo.writeObject(out, output.toList)
- kryo.writeClassAndObject(out, mode)
+ kryo.writeClassAndObject(out, safeBroadcastMode)
out.writeInt(batches.arraySize)
kryo.writeObject(out, batches.bytesBufferLengths)
out.writeLong(batches.totalBytes)
@@ -113,7 +167,7 @@ case class UnsafeColumnarBuildSideRelation(
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
output = in.readObject().asInstanceOf[Seq[Attribute]]
- mode = in.readObject().asInstanceOf[BroadcastMode]
+ safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
val totalArraySize = in.readInt()
val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]]
val totalBytes = in.readLong()
@@ -137,7 +191,7 @@ case class UnsafeColumnarBuildSideRelation(
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
- mode = kryo.readClassAndObject(in).asInstanceOf[BroadcastMode]
+ safeBroadcastMode =
kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
val totalArraySize = in.readInt()
val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
val totalBytes = in.readLong()
@@ -152,11 +206,20 @@ case class UnsafeColumnarBuildSideRelation(
}
}
- private def transformProjection: UnsafeProjection = {
- mode match {
- case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
- case IdentityBroadcastMode => UnsafeProjection.create(output, output)
- }
+ private def transformProjection: UnsafeProjection = safeBroadcastMode match {
+ case IdentitySafeBroadcastMode =>
+ UnsafeProjection.create(output, output)
+ case HashSafeBroadcastMode(ords, _) =>
+ val bound = ords.map(i => BoundReference(i, output(i).dataType,
output(i).nullable))
+ UnsafeProjection.create(bound)
+ case HashExprSafeBroadcastMode(_, _) =>
+ exprKeysFromBytes match {
+ case Some(keys) => UnsafeProjection.create(keys)
+ case None =>
+ throw new IllegalStateException(
+ "Failed to deserialize expressions for HashExprSafeBroadcastMode"
+ )
+ }
}
override def deserialized: Iterator[ColumnarBatch] = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]