This is an automated email from the ASF dual-hosted git repository.

felixybw pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 91c52e15f1 [GLUTEN-9671][VL] Fix broadcast exchange stackoverflow due 
to Kryo serialization (#10541)
91c52e15f1 is described below

commit 91c52e15f16593747e918145258ebe1408cb8ea2
Author: Felix Loesing <[email protected]>
AuthorDate: Thu Aug 28 18:47:04 2025 -0700

    [GLUTEN-9671][VL] Fix broadcast exchange stackoverflow due to Kryo 
serialization (#10541)
    
    This pull request introduces a safer and more robust approach for handling 
Spark's BroadcastMode during serialization. The main improvement is the 
introduction of a new SafeBroadcastMode abstraction and related utilities, 
which help avoid the serialization issue that caused a StackOverflowError 
during broadcast exchanges. The issue we observed was caused by an earlier PR 
that introduced serialization of BroadcastMode. HashedRelationBroadcastMode 
embeds Catalyst expression trees, which are not safe to Kryo [...]
    
    With this change, the broadcast payload now contains only primitives and 
byte arrays (no Catalyst trees). For bound keys, we serialize just column 
ordinals (+ null-aware flag) and for computed keys (e.g., upper(col)), we 
serialize the key expressions once as Java bytes and deserialize only where 
needed to build projections.
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |   2 +-
 .../spark/sql/execution/BroadcastModeUtils.scala   | 134 +++++++++++++++++++++
 .../spark/sql/execution/BroadcastUtils.scala       |   4 +-
 .../sql/execution/ColumnarBuildSideRelation.scala  |  57 +++++++--
 .../unsafe/UnsafeColumnarBuildSideRelation.scala   |  91 +++++++++++---
 5 files changed, 261 insertions(+), 27 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index c3ac63f767..af46ff2673 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -678,7 +678,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     dataSize += rawSize
     if (useOffheapBroadcastBuildRelation) {
       TaskResources.runUnsafe {
-        new UnsafeColumnarBuildSideRelation(child.output, 
serialized.flatMap(_.getSerialized), mode)
+        UnsafeColumnarBuildSideRelation(child.output, 
serialized.flatMap(_.getSerialized), mode)
       }
     } else {
       ColumnarBuildSideRelation(child.output, 
serialized.flatMap(_.getSerialized), mode)
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala
new file mode 100644
index 0000000000..d0ae9a6832
--- /dev/null
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
IdentityBroadcastMode}
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, 
ObjectInputStream, ObjectOutputStream}
+
+/**
+ * Provides serialization-safe representations of BroadcastMode to avoid 
issues with circular
+ * references in complex expression trees during Kryo serialization.
+ */
+sealed trait SafeBroadcastMode extends Serializable
+
+/** Safe representation of IdentityBroadcastMode */
+case object IdentitySafeBroadcastMode extends SafeBroadcastMode
+
+/**
+ * Safe wrapper for HashedRelationBroadcastMode. Stores only column ordinals 
instead of full
+ * BoundReference expressions.
+ */
+final case class HashSafeBroadcastMode(ordinals: Array[Int], isNullAware: 
Boolean)
+  extends SafeBroadcastMode
+
+/**
+ * Safe wrapper for HashedRelationBroadcastMode when keys are not simple 
BoundReferences. Stores key
+ * expressions as serialized Java bytes.
+ */
+final case class HashExprSafeBroadcastMode(exprBytes: Array[Byte], 
isNullAware: Boolean)
+  extends SafeBroadcastMode
+
+object BroadcastModeUtils extends Logging {
+
+  /**
+   * Converts a BroadcastMode to its SafeBroadcastMode equivalent. Uses 
ordinals for simple
+   * BoundReferences, otherwise serializes the expressions.
+   */
+  private[execution] def toSafe(mode: BroadcastMode): SafeBroadcastMode = mode 
match {
+    case IdentityBroadcastMode =>
+      IdentitySafeBroadcastMode
+    case HashedRelationBroadcastMode(keys, isNullAware) =>
+      // Fast path: all keys are already BoundReference(i, ..,..).
+      val ords = keys.collect { case BoundReference(ord, _, _) => ord }
+      if (ords.size == keys.size) {
+        HashSafeBroadcastMode(ords.toArray, isNullAware)
+      } else {
+        // Fallback: store the key expressions as Java-serialized bytes.
+        HashExprSafeBroadcastMode(serializeExpressions(keys), isNullAware)
+      }
+
+    case other =>
+      throw new IllegalArgumentException(s"Unsupported BroadcastMode: $other")
+  }
+
+  /** Converts a SafeBroadcastMode to its BroadcastMode equivalent. */
+  private[execution] def fromSafe(safe: SafeBroadcastMode, output: 
Seq[Attribute]): BroadcastMode =
+    safe match {
+      case IdentitySafeBroadcastMode =>
+        IdentityBroadcastMode
+
+      case HashSafeBroadcastMode(ords, isNullAware) =>
+        val bound = ords.map(i => BoundReference(i, output(i).dataType, 
output(i).nullable)).toSeq
+        HashedRelationBroadcastMode(bound, isNullAware)
+
+      case HashExprSafeBroadcastMode(bytes, isNullAware) =>
+        HashedRelationBroadcastMode(deserializeExpressions(bytes), isNullAware)
+    }
+
+  // Helpers for expression serialization (used in HashExprSafeBroadcastMode)
+  private[execution] def serializeExpressions(keys: Seq[Expression]): 
Array[Byte] = {
+    val bos = new ByteArrayOutputStream()
+    var oos: ObjectOutputStream = null
+    try {
+      oos = new ObjectOutputStream(bos)
+      oos.writeObject(keys)
+      oos.flush()
+      bos.toByteArray
+    } catch {
+      case e @ (_: IOException | _: ClassNotFoundException | _: 
ClassCastException) =>
+        logError(
+          s"Failed to serialize expressions for BroadcastMode. Expression 
count: ${keys.length}",
+          e)
+        throw new RuntimeException("Failed to serialize expressions for 
BroadcastMode", e)
+      case e: Exception =>
+        logError(
+          s"Unexpected error during expression serialization. Expression 
count: ${keys.length}",
+          e)
+        throw e
+    } finally {
+      if (oos != null) oos.close()
+      bos.close()
+    }
+  }
+
+  private[execution] def deserializeExpressions(bytes: Array[Byte]): 
Seq[Expression] = {
+    val bis = new ByteArrayInputStream(bytes)
+    var ois: ObjectInputStream = null
+    try {
+      ois = new ObjectInputStream(bis)
+      ois.readObject().asInstanceOf[Seq[Expression]]
+    } catch {
+      case e @ (_: IOException | _: ClassNotFoundException | _: 
ClassCastException) =>
+        logError(
+          s"Failed to deserialize expressions for BroadcastMode. Data size: 
${bytes.length} bytes",
+          e)
+        throw new RuntimeException("Failed to deserialize expressions for 
BroadcastMode", e)
+      case e: Exception =>
+        logError(
+          s"Unexpected error during expression deserialization. Data size: 
${bytes.length} bytes",
+          e)
+        throw e
+    } finally {
+      if (ois != null) ois.close()
+      bis.close()
+    }
+  }
+}
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 155227db41..342c9694f0 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -108,7 +108,7 @@ object BroadcastUtils {
               result.getSerialized
           }
           if (useOffheapBuildRelation) {
-            new UnsafeColumnarBuildSideRelation(
+            UnsafeColumnarBuildSideRelation(
               SparkShimLoader.getSparkShims.attributesFromStruct(schema),
               serialized,
               mode)
@@ -134,7 +134,7 @@ object BroadcastUtils {
               result.getSerialized
           }
           if (useOffheapBuildRelation) {
-            new UnsafeColumnarBuildSideRelation(
+            UnsafeColumnarBuildSideRelation(
               SparkShimLoader.getSparkShims.attributesFromStruct(schema),
               serialized,
               mode)
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index cd49ed30ea..59a9cb2b00 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -26,11 +26,9 @@ import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, 
NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, 
BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
-import org.apache.spark.sql.execution.joins.BuildSideRelation
-import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.execution.joins.{BuildSideRelation, 
HashedRelationBroadcastMode}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -40,17 +38,56 @@ import org.apache.arrow.c.ArrowSchema
 
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
+object ColumnarBuildSideRelation {
+  // Keep constructor with BroadcastMode for compatibility
+  def apply(
+      output: Seq[Attribute],
+      batches: Array[Array[Byte]],
+      mode: BroadcastMode): ColumnarBuildSideRelation = {
+    val boundMode = mode match {
+      case HashedRelationBroadcastMode(keys, isNullAware) =>
+        // Bind each key to the build-side output so simple cols become 
BoundReference
+        val boundKeys: Seq[Expression] =
+          keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
+        HashedRelationBroadcastMode(boundKeys, isNullAware)
+      case m =>
+        m // IdentityBroadcastMode, etc.
+    }
+    new ColumnarBuildSideRelation(output, batches, 
BroadcastModeUtils.toSafe(boundMode))
+  }
+}
+
 case class ColumnarBuildSideRelation(
     output: Seq[Attribute],
     batches: Array[Array[Byte]],
-    mode: BroadcastMode)
+    safeBroadcastMode: SafeBroadcastMode)
   extends BuildSideRelation {
 
-  private def transformProjection: UnsafeProjection = {
-    mode match {
-      case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
-      case IdentityBroadcastMode => UnsafeProjection.create(output, output)
-    }
+  // Rebuild the real BroadcastMode on demand; never serialize it.
+  @transient override lazy val mode: BroadcastMode =
+    BroadcastModeUtils.fromSafe(safeBroadcastMode, output)
+
+  // If we stored expression bytes, deserialize once and cache locally (not 
serialized).
+  @transient private lazy val exprKeysFromBytes: Option[Seq[Expression]] = 
safeBroadcastMode match {
+    case HashExprSafeBroadcastMode(bytes, _) =>
+      Some(BroadcastModeUtils.deserializeExpressions(bytes))
+    case _ => None
+  }
+
+  private def transformProjection: UnsafeProjection = safeBroadcastMode match {
+    case IdentitySafeBroadcastMode =>
+      UnsafeProjection.create(output, output)
+    case HashSafeBroadcastMode(ords, _) =>
+      val bound = ords.map(i => BoundReference(i, output(i).dataType, 
output(i).nullable))
+      UnsafeProjection.create(bound)
+    case HashExprSafeBroadcastMode(_, _) =>
+      exprKeysFromBytes match {
+        case Some(keys) => UnsafeProjection.create(keys)
+        case None =>
+          throw new IllegalStateException(
+            "Failed to deserialize expressions for HashExprSafeBroadcastMode"
+          )
+      }
   }
 
   override def deserialized: Iterator[ColumnarBatch] = {
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index c0ef884e73..80e92a1537 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -28,8 +28,9 @@ import 
org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeCo
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
IdentityBroadcastMode}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, 
BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.execution.{BroadcastModeUtils, 
HashExprSafeBroadcastMode, HashSafeBroadcastMode, IdentitySafeBroadcastMode, 
SafeBroadcastMode}
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, 
HashedRelationBroadcastMode}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
@@ -45,6 +46,44 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
+object UnsafeColumnarBuildSideRelation {
+  // Keep constructors with BroadcastMode for compatibility
+  def apply(
+      output: Seq[Attribute],
+      batches: UnsafeBytesBufferArray,
+      mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
+    val boundMode = mode match {
+      case HashedRelationBroadcastMode(keys, isNullAware) =>
+        // Bind each key to the build-side output so simple cols become 
BoundReference
+        val boundKeys: Seq[Expression] =
+          keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
+        HashedRelationBroadcastMode(boundKeys, isNullAware)
+      case m =>
+        m // IdentityBroadcastMode, etc.
+    }
+    new UnsafeColumnarBuildSideRelation(output, batches, 
BroadcastModeUtils.toSafe(boundMode))
+  }
+  def apply(
+      output: Seq[Attribute],
+      bytesBufferArray: Array[Array[Byte]],
+      mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
+    val boundMode = mode match {
+      case HashedRelationBroadcastMode(keys, isNullAware) =>
+        // Bind each key to the build-side output so simple cols become 
BoundReference
+        val boundKeys: Seq[Expression] =
+          keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
+        HashedRelationBroadcastMode(boundKeys, isNullAware)
+      case m =>
+        m // IdentityBroadcastMode, etc.
+    }
+    new UnsafeColumnarBuildSideRelation(
+      output,
+      bytesBufferArray,
+      BroadcastModeUtils.toSafe(boundMode)
+    )
+  }
+}
+
 /**
  * A broadcast relation that is built using off-heap memory. It will avoid the 
on-heap memory OOM.
  *
@@ -59,18 +98,33 @@ import 
scala.collection.JavaConverters.asScalaIteratorConverter
 case class UnsafeColumnarBuildSideRelation(
     private var output: Seq[Attribute],
     private var batches: UnsafeBytesBufferArray,
-    var mode: BroadcastMode)
+    var safeBroadcastMode: SafeBroadcastMode)
   extends BuildSideRelation
   with Externalizable
   with Logging
   with KryoSerializable {
 
+  // Rebuild the real BroadcastMode on demand; never serialize it.
+  @transient override lazy val mode: BroadcastMode =
+    BroadcastModeUtils.fromSafe(safeBroadcastMode, output)
+
+  // If we stored expression bytes, deserialize once and cache locally (not 
serialized).
+  @transient private lazy val exprKeysFromBytes: Option[Seq[Expression]] = 
safeBroadcastMode match {
+    case HashExprSafeBroadcastMode(bytes, _) =>
+      Some(BroadcastModeUtils.deserializeExpressions(bytes))
+    case _ => None
+  }
+
   /** needed for serialization. */
   def this() = {
     this(null, null.asInstanceOf[UnsafeBytesBufferArray], null)
   }
 
-  def this(output: Seq[Attribute], bytesBufferArray: Array[Array[Byte]], mode: 
BroadcastMode) = {
+  def this(
+      output: Seq[Attribute],
+      bytesBufferArray: Array[Array[Byte]],
+      safeMode: SafeBroadcastMode
+  ) = {
     this(
       output,
       UnsafeBytesBufferArray(
@@ -78,7 +132,7 @@ case class UnsafeColumnarBuildSideRelation(
         bytesBufferArray.map(_.length),
         bytesBufferArray.map(_.length.toLong).sum
       ),
-      mode
+      safeMode
     )
     val batchesSize = bytesBufferArray.length
     for (i <- 0 until batchesSize) {
@@ -89,7 +143,7 @@ case class UnsafeColumnarBuildSideRelation(
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException 
{
     out.writeObject(output)
-    out.writeObject(mode)
+    out.writeObject(safeBroadcastMode)
     out.writeInt(batches.arraySize)
     out.writeObject(batches.bytesBufferLengths)
     out.writeLong(batches.totalBytes)
@@ -101,7 +155,7 @@ case class UnsafeColumnarBuildSideRelation(
 
   override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
     kryo.writeObject(out, output.toList)
-    kryo.writeClassAndObject(out, mode)
+    kryo.writeClassAndObject(out, safeBroadcastMode)
     out.writeInt(batches.arraySize)
     kryo.writeObject(out, batches.bytesBufferLengths)
     out.writeLong(batches.totalBytes)
@@ -113,7 +167,7 @@ case class UnsafeColumnarBuildSideRelation(
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     output = in.readObject().asInstanceOf[Seq[Attribute]]
-    mode = in.readObject().asInstanceOf[BroadcastMode]
+    safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
     val totalArraySize = in.readInt()
     val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]]
     val totalBytes = in.readLong()
@@ -137,7 +191,7 @@ case class UnsafeColumnarBuildSideRelation(
 
   override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
     output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
-    mode = kryo.readClassAndObject(in).asInstanceOf[BroadcastMode]
+    safeBroadcastMode = 
kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
     val totalArraySize = in.readInt()
     val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
     val totalBytes = in.readLong()
@@ -152,11 +206,20 @@ case class UnsafeColumnarBuildSideRelation(
     }
   }
 
-  private def transformProjection: UnsafeProjection = {
-    mode match {
-      case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
-      case IdentityBroadcastMode => UnsafeProjection.create(output, output)
-    }
+  private def transformProjection: UnsafeProjection = safeBroadcastMode match {
+    case IdentitySafeBroadcastMode =>
+      UnsafeProjection.create(output, output)
+    case HashSafeBroadcastMode(ords, _) =>
+      val bound = ords.map(i => BoundReference(i, output(i).dataType, 
output(i).nullable))
+      UnsafeProjection.create(bound)
+    case HashExprSafeBroadcastMode(_, _) =>
+      exprKeysFromBytes match {
+        case Some(keys) => UnsafeProjection.create(keys)
+        case None =>
+          throw new IllegalStateException(
+            "Failed to deserialize expressions for HashExprSafeBroadcastMode"
+          )
+      }
   }
 
   override def deserialized: Iterator[ColumnarBatch] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to