This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 76ea2ddef perf: Coalesce broadcast exchange batches before 
broadcasting (#3703)
76ea2ddef is described below

commit 76ea2ddef267ebc7c40bde02a7fa0d146f94ff6b
Author: Matt Butrovich <[email protected]>
AuthorDate: Mon Mar 16 18:40:12 2026 -0400

    perf: Coalesce broadcast exchange batches before broadcasting (#3703)
---
 .../org/apache/spark/sql/comet/util/Utils.scala    | 111 ++++++++++++++++++++-
 .../sql/comet/CometBroadcastExchangeExec.scala     |  17 +++-
 .../org/apache/comet/exec/CometExecSuite.scala     |  12 ++-
 .../org/apache/comet/exec/CometJoinSuite.scala     |  84 +++++++++++++---
 4 files changed, 205 insertions(+), 19 deletions(-)

diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala 
b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 6eaa9cad4..78f2e81c7 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -26,13 +26,15 @@ import java.nio.channels.Channels
 import scala.jdk.CollectionConverters._
 
 import org.apache.arrow.c.CDataDictionaryProvider
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, 
DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, 
IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, 
TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, 
VarCharVector, VectorSchemaRoot}
+import org.apache.arrow.vector._
 import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
 import org.apache.arrow.vector.dictionary.DictionaryProvider
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
 import org.apache.arrow.vector.types._
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.util.VectorSchemaRootAppender
 import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
 import org.apache.spark.sql.types._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
 import org.apache.comet.shims.CometTypeShim
 import org.apache.comet.vector.CometVector
 
-object Utils extends CometTypeShim {
+object Utils extends CometTypeShim with Logging {
   def getConfPath(confFileName: String): String = {
     sys.env
       .get(COMET_CONF_DIR_ENV)
@@ -232,6 +234,7 @@ object Utils extends CometTypeShim {
 
   /**
   * Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
+   *
    * @param bytes
    *   the serialized batches
    * @param source
@@ -252,6 +255,108 @@ object Utils extends CometTypeShim {
     new ArrowReaderIterator(Channels.newChannel(ins), source)
   }
 
+  /**
+   * Coalesces many small Arrow IPC batches into a single batch for 
broadcasting.
+   *
+   * Why this is necessary: The broadcast exchange collects shuffle output by 
calling
+   * getByteArrayRdd, which serializes each ColumnarBatch independently into 
its own
+   * ChunkedByteBuffer. The shuffle reader (CometBlockStoreShuffleReader) 
produces one
+   * ColumnarBatch per shuffle block, and there is one block per writer task 
per output partition.
+   * So with W writer tasks and P output partitions, the broadcast collects up 
to W * P tiny
+   * batches. For example, with 400 writer tasks and 500 partitions, 1M rows 
would arrive as ~200K
+   * batches of ~5 rows each.
+   *
+   * Without coalescing, every consumer task in the broadcast join would 
independently deserialize
+   * all of these tiny Arrow IPC streams, paying per-stream overhead (schema 
parsing, buffer
+   * allocation) for each one. With coalescing, we decode and append all 
batches into one
+   * VectorSchemaRoot on the driver, then re-serialize once. Each consumer 
task then deserializes
+   * a single Arrow IPC stream.
+   */
+  def coalesceBroadcastBatches(
+      input: Iterator[ChunkedByteBuffer]): (Array[ChunkedByteBuffer], Long, 
Long) = {
+    val buffers = input.filterNot(_.size == 0).toArray
+    if (buffers.isEmpty) {
+      return (Array.empty, 0L, 0L)
+    }
+
+    val allocator = org.apache.comet.CometArrowAllocator
+      .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
+    try {
+      var targetRoot: VectorSchemaRoot = null
+      var totalRows = 0L
+      var batchCount = 0
+
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      try {
+        for (bytes <- buffers) {
+          val compressedInputStream =
+            new 
DataInputStream(codec.compressedInputStream(bytes.toInputStream()))
+          val reader =
+            new ArrowStreamReader(Channels.newChannel(compressedInputStream), 
allocator)
+          try {
+            // Comet decodes dictionaries during execution, so this shouldn't 
happen.
+            // If it does, fall back to the original uncoalesced buffers 
because each
+            // partition can have a different dictionary, and appending index 
vectors
+            // would silently mix indices from incompatible dictionaries.
+            if (!reader.getDictionaryVectors.isEmpty) {
+              logWarning(
+                "Unexpected dictionary-encoded column during BroadcastExchange 
coalescing; " +
+                  "skipping coalesce")
+              reader.close()
+              if (targetRoot != null) {
+                targetRoot.close()
+                targetRoot = null
+              }
+              return (buffers, 0L, 0L)
+            }
+            while (reader.loadNextBatch()) {
+              val sourceRoot = reader.getVectorSchemaRoot
+              if (targetRoot == null) {
+                targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema, 
allocator)
+                targetRoot.allocateNew()
+              }
+              VectorSchemaRootAppender.append(targetRoot, sourceRoot)
+              totalRows += sourceRoot.getRowCount
+              batchCount += 1
+            }
+          } finally {
+            reader.close()
+          }
+        }
+
+        if (targetRoot == null) {
+          return (Array.empty, 0L, 0L)
+        }
+
+        assert(
+          targetRoot.getRowCount.toLong == totalRows,
+          s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != 
$totalRows")
+
+        logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows 
rows)")
+
+        val outputStream = new ChunkedByteBufferOutputStream(1024 * 1024, 
ByteBuffer.allocate)
+        val compressedOutputStream =
+          new DataOutputStream(codec.compressedOutputStream(outputStream))
+        val writer =
+          new ArrowStreamWriter(targetRoot, null, 
Channels.newChannel(compressedOutputStream))
+        try {
+          writer.start()
+          writer.writeBatch()
+        } finally {
+          writer.close()
+        }
+
+        (Array(outputStream.toChunkedByteBuffer), batchCount.toLong, totalRows)
+      } finally {
+        if (targetRoot != null) {
+          targetRoot.close()
+        }
+      }
+    } finally {
+      allocator.close()
+    }
+  }
+
   def getBatchFieldVectors(
       batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
     var provider: Option[DictionaryProvider] = None
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index f40e05ea0..4a323e575 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -77,7 +77,13 @@ case class CometBroadcastExchangeExec(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
collect"),
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
build"),
-    "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
broadcast"))
+    "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
broadcast"),
+    "numCoalescedBatches" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of coalesced batches for broadcast"),
+    "numCoalescedRows" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of coalesced rows for broadcast"))
 
   override def doCanonicalize(): SparkPlan = {
     CometBroadcastExchangeExec(null, null, mode, child.canonicalized)
@@ -155,7 +161,14 @@ case class CometBroadcastExchangeExec(
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - 
beforeCollect)
 
-        val batches = input.toArray
+        // Coalesce the many small per-shuffle-block buffers into a single 
buffer.
+        // Without this, each consumer task deserializes one Arrow IPC stream 
per
+        // shuffle block (one per writer task per partition), which is very 
expensive
+        // when there are hundreds of writer tasks and partitions. See the 
scaladoc
+        // on coalesceBroadcastBatches for details.
+        val (batches, coalescedBatches, coalescedRows) = 
Utils.coalesceBroadcastBatches(input)
+        longMetric("numCoalescedBatches") += coalescedBatches
+        longMetric("numCoalescedRows") += coalescedRows
 
         val dataSize = batches.map(_.size).sum
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0bf9bbc95..aff181626 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, 
ExpressionInfo, He
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, 
BloomFilterAggregate}
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
SparkPlan, SQLExecution, UnionExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
BroadcastQueryStageExec}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
@@ -474,6 +474,10 @@ class CometExecSuite extends CometTestBase {
           val expected = (0 until numParts).flatMap(_ => (0 until 5).map(i => 
i + 1)).sorted
 
           assert(rowContents === expected)
+
+          val metrics = nativeBroadcast.metrics
+          assert(metrics("numCoalescedBatches").value == 5L)
+          assert(metrics("numCoalescedRows").value == 5L)
         }
       }
     }
@@ -493,6 +497,10 @@ class CometExecSuite extends CometTestBase {
           }.get.asInstanceOf[CometBroadcastExchangeExec]
           val rows = nativeBroadcast.executeCollect()
           assert(rows.isEmpty)
+
+          val metrics = nativeBroadcast.metrics
+          assert(metrics("numCoalescedBatches").value == 0L)
+          assert(metrics("numCoalescedRows").value == 0L)
         }
       }
     }
@@ -712,7 +720,7 @@ class CometExecSuite extends CometTestBase {
         assert(metrics.contains("build_time"))
         assert(metrics("build_time").value > 1L)
         assert(metrics.contains("build_input_batches"))
-        assert(metrics("build_input_batches").value == 25L)
+        assert(metrics("build_input_batches").value == 5L)
         assert(metrics.contains("build_mem_used"))
         assert(metrics("build_mem_used").value > 1L)
         assert(metrics.contains("build_input_rows"))
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index d5a8387be..49fbe10c3 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.comet.CometConf
 
 class CometJoinSuite extends CometTestBase {
+
   import testImplicits._
 
   override protected def test(testName: String, testTags: Tag*)(testFun: => 
Any)(implicit
@@ -359,28 +360,87 @@ class CometJoinSuite extends CometTestBase {
       checkSparkAnswer(left.join(right, ($"left.N" === $"right.N") && 
($"right.N" =!= 3), "full"))
 
       checkSparkAnswer(sql("""
-            |SELECT l.a, count(*)
-            |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
-            |GROUP BY l.a
+          |SELECT l.a, count(*)
+          |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
+          |GROUP BY l.a
         """.stripMargin))
 
       checkSparkAnswer(sql("""
-            |SELECT r.N, count(*)
-            |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
-            |GROUP BY r.N
+          |SELECT r.N, count(*)
+          |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
+          |GROUP BY r.N
           """.stripMargin))
 
       checkSparkAnswer(sql("""
-            |SELECT l.N, count(*)
-            |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
-            |GROUP BY l.N
+          |SELECT l.N, count(*)
+          |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
+          |GROUP BY l.N
           """.stripMargin))
 
       checkSparkAnswer(sql("""
-            |SELECT r.a, count(*)
-            |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
-            |GROUP BY r.a
+          |SELECT r.a, count(*)
+          |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
+          |GROUP BY r.a
         """.stripMargin))
     }
   }
+
+  test("Broadcast hash join build-side batch coalescing") {
+    // Use many shuffle partitions to produce many small broadcast batches,
+    // then verify that coalescing reduces the build-side batch count to 1 per 
task.
+    val numPartitions = 512
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
+      withParquetTable((0 until 10000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Force a shuffle on tbl_a before broadcast so the broadcast source 
has
+          // numPartitions partitions, not just the number of parquet files.
+          val query =
+            s"""SELECT /*+ BROADCAST(a) */ *
+               |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM tbl_a) a
+               |JOIN tbl_b ON a._2 = tbl_b._1""".stripMargin
+
+          val (_, cometPlan) = checkSparkAnswerAndOperator(
+            sql(query),
+            Seq(classOf[CometBroadcastExchangeExec], 
classOf[CometBroadcastHashJoinExec]))
+
+          val joins = collect(cometPlan) { case j: CometBroadcastHashJoinExec 
=>
+            j
+          }
+          assert(joins.nonEmpty, "Expected CometBroadcastHashJoinExec in plan")
+
+          val join = joins.head
+          val buildBatches = join.metrics("build_input_batches").value
+
+          // Without coalescing, build_input_batches would be ~numPartitions 
per task,
+          // totaling ~numPartitions * numPartitions across all tasks.
+          // With coalescing, each task gets 1 batch, so total ≈ numPartitions.
+          assert(
+            buildBatches <= numPartitions,
+            s"Expected at most $numPartitions build batches (1 per task), got 
$buildBatches. " +
+              "Broadcast batch coalescing may not be working.")
+
+          val broadcasts = collect(cometPlan) { case b: 
CometBroadcastExchangeExec =>
+            b
+          }
+          assert(broadcasts.nonEmpty, "Expected CometBroadcastExchangeExec in 
plan")
+
+          val broadcast = broadcasts.head
+          val coalescedBatches = broadcast.metrics("numCoalescedBatches").value
+          val coalescedRows = broadcast.metrics("numCoalescedRows").value
+
+          assert(
+            coalescedBatches >= numPartitions,
+            s"Expected at least $numPartitions coalesced batches, got 
$coalescedBatches")
+          assert(coalescedRows == 10000, s"Expected 10000 coalesced rows, got 
$coalescedRows")
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to