This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 639fa2fb minor: refactor decodeBatches to make private in broadcast exchange (#1195)
639fa2fb is described below

commit 639fa2fb450a760728bc6921ca58da64df0c64a7
Author: Andy Grove <[email protected]>
AuthorDate: Sun Dec 22 12:25:21 2024 -0700

    minor: refactor decodeBatches to make private in broadcast exchange (#1195)
---
 .../sql/comet/CometBroadcastExchangeExec.scala     | 24 +++++++++++++--
 .../org/apache/spark/sql/comet/operators.scala     | 35 ++--------------------
 .../org/apache/comet/exec/CometExecSuite.scala     | 33 --------------------
 3 files changed, 25 insertions(+), 67 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index ccf218cf..6bc519ab 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -19,6 +19,8 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.DataInputStream
+import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.{Future, TimeoutException, TimeUnit}
 
@@ -26,13 +28,15 @@ import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -299,7 +303,23 @@ class CometBatchRDD(
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
     partition.value.value.toIterator
-      .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
+      .flatMap(decodeBatches(_, this.getClass.getSimpleName))
+  }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatches and returns an iterator over them.
+   */
+  private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    // use Spark's compression codec (LZ4 by default) and not Comet's compression
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+    // batches are in Arrow IPC format
+    new ArrowReaderIterator(Channels.newChannel(ins), source)
   }
 }
 
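The new private decodeBatches above layers two formats: Spark's block compression codec on the outside and Arrow IPC stream framing on the inside. ArrowReaderIterator is Comet's own reader over that inner layer; purely as an illustrative sketch (countRows is a hypothetical helper, not part of this change), the same layering read with Arrow's stock Java IPC reader looks like this:

import java.io.DataInputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.io.ChunkedByteBuffer

// Hypothetical helper: count the rows in one encoded broadcast payload.
def countRows(bytes: ChunkedByteBuffer): Long = {
  if (bytes.size == 0) return 0L
  // Outer layer: Spark's codec (LZ4 by default), same call as decodeBatches.
  val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
  val ins = new DataInputStream(codec.compressedInputStream(bytes.toInputStream()))
  // Inner layer: Arrow IPC stream format over the decompressed bytes.
  val allocator = new RootAllocator()
  val reader = new ArrowStreamReader(ins, allocator)
  try {
    var rows = 0L
    while (reader.loadNextBatch()) {
      rows += reader.getVectorSchemaRoot.getRowCount
    }
    rows
  } finally {
    reader.close()
    allocator.close()
  }
}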
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 77188312..c70f7464 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,14 +19,12 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.{ByteArrayOutputStream, DataInputStream}
-import java.nio.channels.Channels
+import java.io.ByteArrayOutputStream
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.io.CompressionCodec
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
@@ -34,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
-import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
@@ -78,18 +76,6 @@ abstract class CometExec extends CometPlan {
   // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
   override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
 
-  /**
-   * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
-   */
-  def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
-    val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
-    val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator
-      .flatMap(countAndBytes =>
-        CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
-    (total, rows)
-  }
-
   protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
     sparkPlan.children.foreach(setSubqueries(planId, _))
 
@@ -161,21 +147,6 @@ object CometExec {
       Utils.serializeBatches(iter)
     }
   }
-
-  /**
-   * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
-   */
-  def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
-    new ArrowReaderIterator(Channels.newChannel(ins), source)
-  }
 }
 
 /**
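For reference, the encode side of this round trip stays in Utils.serializeBatches (called from getByteArrayRdd above). A hedged sketch of the symmetric write path, using Arrow's stock Java writer rather than Comet's actual implementation (encodeBatch is a hypothetical name), mirrors the two layers in reverse:

import java.io.ByteArrayOutputStream

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec

// Hypothetical helper: write one batch so the decode path can read it back.
def encodeBatch(root: VectorSchemaRoot): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  // Outer layer: Spark's codec, mirroring compressedInputStream on read.
  val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
  val out = codec.compressedOutputStream(bos)
  // Inner layer: Arrow IPC stream framing (no dictionary provider needed).
  val writer = new ArrowStreamWriter(root, null, out)
  writer.start()
  writer.writeBatch()
  writer.end()
  writer.close()
  out.close()
  bos.toByteArray
}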
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 10276953..90c3221e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -22,8 +22,6 @@ package org.apache.comet.exec
 import java.sql.Date
 import java.time.{Duration, Period}
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.util.Random
 
 import org.scalactic.source.Position
@@ -462,37 +460,6 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
-  test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch 
results") {
-    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 
3.4+")
-    withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
-      withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
-        val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
-
-        val nativeProject = find(df.queryExecution.executedPlan) {
-          case _: CometProjectExec => true
-          case _ => false
-        }.get.asInstanceOf[CometProjectExec]
-
-        val (rows, batches) = nativeProject.executeColumnarCollectIterator()
-        assert(rows == 46)
-
-        val column1 = mutable.ArrayBuffer.empty[Int]
-        val column2 = mutable.ArrayBuffer.empty[Int]
-
-        batches.foreach(batch => {
-          batch.rowIterator().asScala.foreach { row =>
-            assert(row.numFields == 2)
-            column1 += row.getInt(0)
-            column2 += row.getInt(1)
-          }
-        })
-
-        assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
-        assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
-      }
-    }
-  }
-
   test("scalar subquery") {
     val dataTypes =
       Seq(

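The deleted test was the only in-tree caller of executeColumnarCollectIterator, which is why both could be removed together. Comparable coverage through the public DataFrame API, reusing the helpers visible in the removed test, might look roughly like this sketch (not part of this commit):

test("project and filter results can be collected") {
  withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
    withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
      val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
      // Collect through the public API instead of the removed helper.
      val rows = df.collect()
      assert(rows.length == 46)
      assert(rows.map(_.getInt(0)).sorted === (4 until 50).map(_ + 1).toArray)
      assert(rows.map(_.getInt(1)).sorted === (5 until 51).map(_ + 2).toArray)
    }
  }
}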
