This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 639fa2fb minor: refactor decodeBatches to make private in broadcast exchange (#1195)
639fa2fb is described below
commit 639fa2fb450a760728bc6921ca58da64df0c64a7
Author: Andy Grove <[email protected]>
AuthorDate: Sun Dec 22 12:25:21 2024 -0700
minor: refactor decodeBatches to make private in broadcast exchange (#1195)
---
.../sql/comet/CometBroadcastExchangeExec.scala | 24 +++++++++++++--
.../org/apache/spark/sql/comet/operators.scala | 35 ++--------------------
.../org/apache/comet/exec/CometExecSuite.scala | 33 --------------------
3 files changed, 25 insertions(+), 67 deletions(-)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index ccf218cf..6bc519ab 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -19,6 +19,8 @@
package org.apache.spark.sql.comet
+import java.io.DataInputStream
+import java.nio.channels.Channels
import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}
@@ -26,13 +28,15 @@ import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -299,7 +303,23 @@ class CometBatchRDD(
override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
partition.value.value.toIterator
- .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
+ .flatMap(decodeBatches(_, this.getClass.getSimpleName))
+ }
+
+ /**
+ * Decodes the byte arrays back to ColumnarBatches and returns them as an iterator.
+ */
+ private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+ if (bytes.size == 0) {
+ return Iterator.empty
+ }
+
+ // use Spark's compression codec (LZ4 by default) and not Comet's compression
+ val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+ val cbbis = bytes.toInputStream()
+ val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+ // batches are in Arrow IPC format
+ new ArrowReaderIterator(Channels.newChannel(ins), source)
}
}
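
[Context, not part of the commit] The new private helper decodes a broadcast payload that was compressed with Spark's codec and contains batches in Arrow IPC format. Below is a minimal, hedged sketch of reading such a payload with stock Spark and Arrow APIs; it assumes an active SparkEnv, and the method name readIpcBatches is hypothetical.

import java.io.ByteArrayInputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec

// Illustrative sketch only: decompress with Spark's codec (LZ4 unless
// spark.io.compression.codec says otherwise), then read the decompressed
// bytes as an Arrow IPC stream.
def readIpcBatches(compressed: Array[Byte]): Unit = {
  val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
  val in = codec.compressedInputStream(new ByteArrayInputStream(compressed))
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(in, allocator)
  try {
    while (reader.loadNextBatch()) {
      val root = reader.getVectorSchemaRoot
      println(s"decoded Arrow record batch with ${root.getRowCount} rows")
    }
  } finally {
    reader.close()
    allocator.close()
  }
}
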
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 77188312..c70f7464 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,14 +19,12 @@
package org.apache.spark.sql.comet
-import java.io.{ByteArrayOutputStream, DataInputStream}
-import java.nio.channels.Channels
+import java.io.ByteArrayOutputStream
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.io.CompressionCodec
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
@@ -34,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
-import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
@@ -78,18 +76,6 @@ abstract class CometExec extends CometPlan {
// outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
- /**
- * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
- */
- def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
- val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
- val total = countsAndBytes.map(_._1).sum
- val rows = countsAndBytes.iterator
- .flatMap(countAndBytes =>
- CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
- (total, rows)
- }
-
protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
sparkPlan.children.foreach(setSubqueries(planId, _))
@@ -161,21 +147,6 @@ object CometExec {
Utils.serializeBatches(iter)
}
}
-
- /**
- * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
- */
- def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
- if (bytes.size == 0) {
- return Iterator.empty
- }
-
- val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
- val cbbis = bytes.toInputStream()
- val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
- new ArrowReaderIterator(Channels.newChannel(ins), source)
- }
}
/**
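
[Context, not part of the commit] With the public CometExec.decodeBatches and executeColumnarCollectIterator removed, callers collect results through the regular Spark API rather than decoding serialized batches by hand. A hedged sketch, assuming an active SparkSession named spark and the tbl temp view used by the removed test below:

// Illustrative sketch only; `spark` and the `tbl` view are assumptions
// borrowed from the removed test in CometExecSuite.
val df = spark.sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
// Row-based collection; Spark converts the columnar Comet output via ColumnarToRowExec.
val rows = df.collect()
assert(rows.length == 46) // rows with _1 in 4..49, matching the removed test
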
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 10276953..90c3221e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -22,8 +22,6 @@ package org.apache.comet.exec
import java.sql.Date
import java.time.{Duration, Period}
-import scala.collection.JavaConverters._
-import scala.collection.mutable
import scala.util.Random
import org.scalactic.source.Position
@@ -462,37 +460,6 @@ class CometExecSuite extends CometTestBase {
}
}
- test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch
results") {
- assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
- withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
- withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
- val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
-
- val nativeProject = find(df.queryExecution.executedPlan) {
- case _: CometProjectExec => true
- case _ => false
- }.get.asInstanceOf[CometProjectExec]
-
- val (rows, batches) = nativeProject.executeColumnarCollectIterator()
- assert(rows == 46)
-
- val column1 = mutable.ArrayBuffer.empty[Int]
- val column2 = mutable.ArrayBuffer.empty[Int]
-
- batches.foreach(batch => {
- batch.rowIterator().asScala.foreach { row =>
- assert(row.numFields == 2)
- column1 += row.getInt(0)
- column2 += row.getInt(1)
- }
- })
-
- assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
- assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
- }
- }
- }
-
test("scalar subquery") {
val dataTypes =
Seq(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]