This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 60fe431 feat: Add CometRowToColumnar operator (#206)
60fe431 is described below
commit 60fe4315e793d7efb6a8de9246257c1b3c623766
Author: advancedxy <[email protected]>
AuthorDate: Wed Apr 10 11:59:58 2024 +0800
feat: Add CometRowToColumnar operator (#206)
---
.../main/scala/org/apache/comet/CometConf.scala | 20 +
.../scala/org/apache/comet/vector/NativeUtil.scala | 17 +
.../org/apache/comet/vector/StreamReader.scala | 13 +-
.../sql/comet/execution/arrow/ArrowWriters.scala | 472 +++++++++++++++++++++
.../execution/arrow/CometArrowConverters.scala | 131 ++++++
.../org/apache/spark/sql/comet/util/Utils.scala | 15 +
.../apache/comet/CometSparkSessionExtensions.scala | 56 ++-
.../org/apache/comet/serde/QueryPlanSerde.scala | 3 +-
.../spark/sql/comet/CometRowToColumnarExec.scala | 84 ++++
.../org/apache/spark/sql/comet/operators.scala | 5 +-
.../org/apache/comet/exec/CometExecSuite.scala | 54 ++-
.../scala/org/apache/spark/sql/CometTestBase.scala | 15 +-
12 files changed, 856 insertions(+), 29 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala
b/common/src/main/scala/org/apache/comet/CometConf.scala
index bd2e04d..341ec98 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -337,6 +337,26 @@ object CometConf {
"enabled when reading from Iceberg tables.")
.booleanConf
.createWithDefault(false)
+
+ val COMET_ROW_TO_COLUMNAR_ENABLED: ConfigEntry[Boolean] =
+ conf("spark.comet.rowToColumnar.enabled")
+ .internal()
+ .doc("""
+ |Whether to enable row to columnar conversion in Comet. When this is
turned on, Comet will
+ |convert row-based operators in
`spark.comet.rowToColumnar.supportedOperatorList` into
+ |columnar based before processing.""".stripMargin)
+ .booleanConf
+ .createWithDefault(false)
+
+ val COMET_ROW_TO_COLUMNAR_SUPPORTED_OPERATOR_LIST: ConfigEntry[Seq[String]] =
+ conf("spark.comet.rowToColumnar.supportedOperatorList")
+ .doc(
+ "A comma-separated list of row-based operators that will be converted
to columnar " +
+ "format when 'spark.comet.rowToColumnar.enabled' is true")
+ .stringConf
+ .toSequence
+ .createWithDefault(Seq("Range,InMemoryTableScan"))
+
}
object ConfigHelpers {
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 3756da9..763ccff 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -23,6 +23,8 @@ import scala.collection.mutable
import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema,
CDataDictionaryProvider, Data}
import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -132,3 +134,18 @@ class NativeUtil {
new ColumnarBatch(arrayVectors.toArray, maxNumRows)
}
}
+
+object NativeUtil {
+ def rootAsBatch(arrowRoot: VectorSchemaRoot): ColumnarBatch = {
+ rootAsBatch(arrowRoot, null)
+ }
+
+ def rootAsBatch(arrowRoot: VectorSchemaRoot, provider: DictionaryProvider):
ColumnarBatch = {
+ val vectors = (0 until arrowRoot.getFieldVectors.size()).map { i =>
+ val vector = arrowRoot.getFieldVectors.get(i)
+ // Native shuffle always uses decimal128.
+ CometVector.getVector(vector, true, provider)
+ }
+ new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount)
+ }
+}
diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
index da72383..61d800b 100644
--- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
+++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
@@ -21,13 +21,11 @@ package org.apache.comet.vector
import java.nio.channels.ReadableByteChannel
-import scala.collection.JavaConverters.collectionAsScalaIterableConverter
-
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel}
import org.apache.arrow.vector.ipc.message.MessageChannelReader
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* A reader that consumes Arrow data from an input channel, and produces Comet
batches.
@@ -47,14 +45,7 @@ case class StreamReader(channel: ReadableByteChannel)
extends AutoCloseable {
}
private def rootAsBatch(root: VectorSchemaRoot): ColumnarBatch = {
- val columns = root.getFieldVectors.asScala.map { vector =>
- // Native shuffle always uses decimal128.
- CometVector.getVector(vector, true,
arrowReader).asInstanceOf[ColumnVector]
- }.toArray
-
- val batch = new ColumnarBatch(columns)
- batch.setNumRows(root.getRowCount)
- batch
+ NativeUtil.rootAsBatch(root, arrowReader)
}
override def close(): Unit = {
diff --git
a/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
new file mode 100644
index 0000000..8d9f373
--- /dev/null
+++
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+/**
+ * This file is mostly copied from Spark SQL's
+ * org.apache.spark.sql.execution.arrow.ArrowWriter.scala. Comet shadows Arrow
classes to avoid
+ * potential conflicts with Spark's Arrow dependencies, hence we cannot reuse
Spark's ArrowWriter
+ * directly.
+ */
+private[arrow] object ArrowWriter {
+ def create(root: VectorSchemaRoot): ArrowWriter = {
+ val children = root.getFieldVectors().asScala.map { vector =>
+ vector.allocateNew()
+ createFieldWriter(vector)
+ }
+ new ArrowWriter(root, children.toArray)
+ }
+
+ private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+ val field = vector.getField()
+ (Utils.fromArrowField(field), vector) match {
+ case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
+ case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
+ case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
+ case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
+ case (LongType, vector: BigIntVector) => new LongWriter(vector)
+ case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
+ case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+ case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+ new DecimalWriter(vector, precision, scale)
+ case (StringType, vector: VarCharVector) => new StringWriter(vector)
+ case (StringType, vector: LargeVarCharVector) => new
LargeStringWriter(vector)
+ case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+ case (BinaryType, vector: LargeVarBinaryVector) => new
LargeBinaryWriter(vector)
+ case (DateType, vector: DateDayVector) => new DateWriter(vector)
+ case (TimestampType, vector: TimeStampMicroTZVector) => new
TimestampWriter(vector)
+ case (TimestampNTZType, vector: TimeStampMicroVector) => new
TimestampNTZWriter(vector)
+ case (ArrayType(_, _), vector: ListVector) =>
+ val elementVector = createFieldWriter(vector.getDataVector())
+ new ArrayWriter(vector, elementVector)
+ case (MapType(_, _, _), vector: MapVector) =>
+ val structVector = vector.getDataVector.asInstanceOf[StructVector]
+ val keyWriter =
createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+ val valueWriter =
createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+ new MapWriter(vector, structVector, keyWriter, valueWriter)
+ case (StructType(_), vector: StructVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new StructWriter(vector, children.toArray)
+ case (NullType, vector: NullVector) => new NullWriter(vector)
+ case (_: YearMonthIntervalType, vector: IntervalYearVector) =>
+ new IntervalYearWriter(vector)
+ case (_: DayTimeIntervalType, vector: DurationVector) => new
DurationWriter(vector)
+// case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
+// new IntervalMonthDayNanoWriter(vector)
+ case (dt, _) =>
+ throw QueryExecutionErrors.notSupportTypeError(dt)
+ }
+ }
+}
+
+class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter])
{
+
+ def schema: StructType = Utils.fromArrowSchema(root.getSchema())
+
+ private var count: Int = 0
+
+ def write(row: InternalRow): Unit = {
+ var i = 0
+ while (i < fields.length) {
+ fields(i).write(row, i)
+ i += 1
+ }
+ count += 1
+ }
+
+ def finish(): Unit = {
+ root.setRowCount(count)
+ fields.foreach(_.finish())
+ }
+
+ def reset(): Unit = {
+ root.setRowCount(0)
+ count = 0
+ fields.foreach(_.reset())
+ }
+}
+
+private[arrow] abstract class ArrowFieldWriter {
+
+ def valueVector: ValueVector
+
+ def name: String = valueVector.getField().getName()
+ def dataType: DataType = Utils.fromArrowField(valueVector.getField())
+ def nullable: Boolean = valueVector.getField().isNullable()
+
+ def setNull(): Unit
+ def setValue(input: SpecializedGetters, ordinal: Int): Unit
+
+ private[arrow] var count: Int = 0
+
+ def write(input: SpecializedGetters, ordinal: Int): Unit = {
+ if (input.isNullAt(ordinal)) {
+ setNull()
+ } else {
+ setValue(input, ordinal)
+ }
+ count += 1
+ }
+
+ def finish(): Unit = {
+ valueVector.setValueCount(count)
+ }
+
+ def reset(): Unit = {
+ valueVector.reset()
+ count = 0
+ }
+}
+
+private[arrow] class BooleanWriter(val valueVector: BitVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
+ }
+}
+
+private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getByte(ordinal))
+ }
+}
+
+private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getShort(ordinal))
+ }
+}
+
+private[arrow] class IntegerWriter(val valueVector: IntVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getInt(ordinal))
+ }
+}
+
+private[arrow] class LongWriter(val valueVector: BigIntVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getLong(ordinal))
+ }
+}
+
+private[arrow] class FloatWriter(val valueVector: Float4Vector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getFloat(ordinal))
+ }
+}
+
+private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getDouble(ordinal))
+ }
+}
+
+private[arrow] class DecimalWriter(val valueVector: DecimalVector, precision:
Int, scale: Int)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val decimal = input.getDecimal(ordinal, precision, scale)
+ if (decimal.changePrecision(precision, scale)) {
+ valueVector.setSafe(count, decimal.toJavaBigDecimal)
+ } else {
+ setNull()
+ }
+ }
+}
+
+private[arrow] class StringWriter(val valueVector: VarCharVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val utf8 = input.getUTF8String(ordinal)
+ val utf8ByteBuffer = utf8.getByteBuffer
+ // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+ valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(),
utf8.numBytes())
+ }
+}
+
+private[arrow] class LargeStringWriter(val valueVector: LargeVarCharVector)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val utf8 = input.getUTF8String(ordinal)
+ val utf8ByteBuffer = utf8.getByteBuffer
+ // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+ valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(),
utf8.numBytes())
+ }
+}
+
+private[arrow] class BinaryWriter(val valueVector: VarBinaryVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val bytes = input.getBinary(ordinal)
+ valueVector.setSafe(count, bytes, 0, bytes.length)
+ }
+}
+
+private[arrow] class LargeBinaryWriter(val valueVector: LargeVarBinaryVector)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val bytes = input.getBinary(ordinal)
+ valueVector.setSafe(count, bytes, 0, bytes.length)
+ }
+}
+
+private[arrow] class DateWriter(val valueVector: DateDayVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getInt(ordinal))
+ }
+}
+
+private[arrow] class TimestampWriter(val valueVector: TimeStampMicroTZVector)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getLong(ordinal))
+ }
+}
+
+private[arrow] class TimestampNTZWriter(val valueVector: TimeStampMicroVector)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getLong(ordinal))
+ }
+}
+
+private[arrow] class ArrayWriter(val valueVector: ListVector, val
elementWriter: ArrowFieldWriter)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {}
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val array = input.getArray(ordinal)
+ var i = 0
+ valueVector.startNewValue(count)
+ while (i < array.numElements()) {
+ elementWriter.write(array, i)
+ i += 1
+ }
+ valueVector.endValue(count, array.numElements())
+ }
+
+ override def finish(): Unit = {
+ super.finish()
+ elementWriter.finish()
+ }
+
+ override def reset(): Unit = {
+ super.reset()
+ elementWriter.reset()
+ }
+}
+
+private[arrow] class StructWriter(
+ val valueVector: StructVector,
+ children: Array[ArrowFieldWriter])
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ var i = 0
+ while (i < children.length) {
+ children(i).setNull()
+ children(i).count += 1
+ i += 1
+ }
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val struct = input.getStruct(ordinal, children.length)
+ var i = 0
+ valueVector.setIndexDefined(count)
+ while (i < struct.numFields) {
+ children(i).write(struct, i)
+ i += 1
+ }
+ }
+
+ override def finish(): Unit = {
+ super.finish()
+ children.foreach(_.finish())
+ }
+
+ override def reset(): Unit = {
+ super.reset()
+ children.foreach(_.reset())
+ }
+}
+
+private[arrow] class MapWriter(
+ val valueVector: MapVector,
+ val structVector: StructVector,
+ val keyWriter: ArrowFieldWriter,
+ val valueWriter: ArrowFieldWriter)
+ extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {}
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val map = input.getMap(ordinal)
+ valueVector.startNewValue(count)
+ val keys = map.keyArray()
+ val values = map.valueArray()
+ var i = 0
+ while (i < map.numElements()) {
+ structVector.setIndexDefined(keyWriter.count)
+ keyWriter.write(keys, i)
+ valueWriter.write(values, i)
+ i += 1
+ }
+
+ valueVector.endValue(count, map.numElements())
+ }
+
+ override def finish(): Unit = {
+ super.finish()
+ keyWriter.finish()
+ valueWriter.finish()
+ }
+
+ override def reset(): Unit = {
+ super.reset()
+ keyWriter.reset()
+ valueWriter.reset()
+ }
+}
+
+private[arrow] class NullWriter(val valueVector: NullVector) extends
ArrowFieldWriter {
+
+ override def setNull(): Unit = {}
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {}
+}
+
+private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector)
+ extends ArrowFieldWriter {
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getInt(ordinal));
+ }
+}
+
+private[arrow] class DurationWriter(val valueVector: DurationVector) extends
ArrowFieldWriter {
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getLong(ordinal))
+ }
+}
+
+private[arrow] class IntervalMonthDayNanoWriter(val valueVector:
IntervalMonthDayNanoVector)
+ extends ArrowFieldWriter {
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val ci = input.getInterval(ordinal)
+ valueVector.setSafe(count, ci.months, ci.days,
Math.multiplyExact(ci.microseconds, 1000L))
+ }
+}
diff --git
a/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala
new file mode 100644
index 0000000..9dbd8dc
--- /dev/null
+++
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.vector.NativeUtil
+
+object CometArrowConverters extends Logging {
+ // TODO: we should reuse the same root allocator in the comet code base?
+ val rootAllocator: BufferAllocator = new RootAllocator(Long.MaxValue)
+
+ // This is similar to how Spark converts internal row to Arrow format except
that it is transforming
+ // the result batch to Comet's ColumnarBatch instead of serialized bytes.
+ // There's another big difference that Comet may consume the ColumnarBatch
by exporting it to
+ // the native side. Hence, we need to:
+ // 1. reset the Arrow writer after the ColumnarBatch is consumed
+ // 2. close the allocator when the task is finished but not when the
iterator is all consumed
+ // The reason for the second point is that when a ColumnarBatch is exported to
the native side, the
+ // export process increases the reference count of the Arrow vectors. The
reference count is
+ // only decreased when the native plan is done with the vectors, which is
usually later than
+ // the point at which all the ColumnarBatches are consumed.
+ private[sql] class ArrowBatchIterator(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Long,
+ timeZoneId: String,
+ context: TaskContext)
+ extends Iterator[ColumnarBatch]
+ with AutoCloseable {
+
+ private val arrowSchema = Utils.toArrowSchema(schema, timeZoneId)
+ // Reuse the same root allocator here.
+ private val allocator =
+ rootAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0,
Long.MaxValue)
+ private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ private val arrowWriter = ArrowWriter.create(root)
+
+ private var currentBatch: ColumnarBatch = null
+ private var closed: Boolean = false
+
+ Option(context).foreach {
+ _.addTaskCompletionListener[Unit] { _ =>
+ close(true)
+ }
+ }
+
+ override def hasNext: Boolean = rowIter.hasNext || {
+ close(false)
+ false
+ }
+
+ override def next(): ColumnarBatch = {
+ currentBatch = nextBatch()
+ currentBatch
+ }
+
+ override def close(): Unit = {
+ close(false)
+ }
+
+ private def nextBatch(): ColumnarBatch = {
+ if (rowIter.hasNext) {
+ // the arrow writer shall be reset before writing the next batch
+ arrowWriter.reset()
+ var rowCount = 0L
+ while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount <
maxRecordsPerBatch)) {
+ val row = rowIter.next()
+ arrowWriter.write(row)
+ rowCount += 1
+ }
+ arrowWriter.finish()
+ NativeUtil.rootAsBatch(root)
+ } else {
+ null
+ }
+ }
+
+ private def close(closeAllocator: Boolean): Unit = {
+ try {
+ if (!closed) {
+ if (currentBatch != null) {
+ arrowWriter.reset()
+ currentBatch.close()
+ currentBatch = null
+ }
+ root.close()
+ closed = true
+ }
+ } finally {
+ // the allocator shall be closed when the task is finished
+ if (closeAllocator) {
+ allocator.close()
+ }
+ }
+ }
+ }
+
+ def toArrowBatchIterator(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Long,
+ timeZoneId: String,
+ context: TaskContext): Iterator[ColumnarBatch] = {
+ new ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId,
context)
+ }
+}
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 684d778..7d920e1 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -54,6 +54,11 @@ object Utils {
str.split(",").map(_.trim()).filter(_.nonEmpty)
}
+ /** Bridges the function call to Spark's Utils.getSimpleName. */
+ def getSimpleName(cls: Class[_]): String = {
+ org.apache.spark.util.Utils.getSimpleName(cls)
+ }
+
def fromArrowField(field: Field): DataType = {
field.getType match {
case _: ArrowType.Map =>
@@ -90,6 +95,9 @@ object Utils {
case _: ArrowType.FixedSizeBinary => BinaryType
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+ case ts: ArrowType.Timestamp
+ if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
+ TimestampNTZType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND =>
TimestampType
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
@@ -98,6 +106,13 @@ object Utils {
case _ => throw new UnsupportedOperationException(s"Unsupported data type:
${dt.toString}")
}
+ def fromArrowSchema(schema: Schema): StructType = {
+ StructType(schema.getFields.asScala.map { field =>
+ val dt = fromArrowField(field)
+ StructField(field.getName, dt, field.isNullable)
+ }.toArray)
+ }
+
/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for
TimestampTypes */
def toArrowType(dt: DataType, timeZoneId: String): ArrowType =
dt match {
diff --git
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 7795194..a10ac57 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
+import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec,
ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -43,7 +44,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled,
isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled,
isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled,
isCometShuffleEnabled, isSchemaSupported}
+import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled,
isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled,
isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled,
isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
@@ -68,7 +69,7 @@ class CometSparkSessionExtensions
override def preColumnarTransitions: Rule[SparkPlan] =
CometExecRule(session)
override def postColumnarTransitions: Rule[SparkPlan] =
- EliminateRedundantColumnarToRow(session)
+ EliminateRedundantTransitions(session)
}
case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
@@ -238,6 +239,11 @@ class CometSparkSessionExtensions
val nativeOp = QueryPlanSerde.operator2Proto(op).get
CometScanWrapper(nativeOp, op)
+ case op if shouldApplyRowToColumnar(conf, op) =>
+ val cometOp = CometRowToColumnarExec(op)
+ val nativeOp = QueryPlanSerde.operator2Proto(cometOp).get
+ CometScanWrapper(nativeOp, cometOp)
+
case op: ProjectExec =>
val newOp = transform1(op)
newOp match {
@@ -592,18 +598,32 @@ class CometSparkSessionExtensions
}
}
- // CometExec already wraps a `ColumnarToRowExec` for row-based operators.
Therefore,
- // `ColumnarToRowExec` is redundant and can be eliminated.
+ // This rule is responsible for eliminating redundant transitions between
row-based and
+ // columnar-based operators for Comet. Currently, two potential redundant
transitions are:
+ // 1. `ColumnarToRowExec` on top of an ending `CometCollectLimitExec`
operator, which is
+ // redundant as `CometCollectLimitExec` already wraps a
`ColumnarToRowExec` for row-based
+ // output.
+ // 2. Consecutive operators of `CometRowToColumnarExec` and
`ColumnarToRowExec`.
+ //
+ // Note about the first case: The `ColumnarToRowExec` was added during
+ // ApplyColumnarRulesAndInsertTransitions' insertTransitions phase when
Spark requests row-based
+ // output such as a `collect` call. It's correct to add a redundant
`ColumnarToRowExec` for
+ // `CometExec`. However, for certain operators such as
`CometCollectLimitExec` which overrides
+ // `executeCollect`, the redundant `ColumnarToRowExec` makes the override
ineffective.
//
- // It was added during ApplyColumnarRulesAndInsertTransitions'
insertTransitions phase when Spark
- // requests row-based output such as `collect` call. It's correct to add a
redundant
- // `ColumnarToRowExec` for `CometExec`. However, for certain operators such
as
- // `CometCollectLimitExec` which overrides `executeCollect`, the redundant
`ColumnarToRowExec`
- // makes the override ineffective. The purpose of this rule is to eliminate
the redundant
- // `ColumnarToRowExec` for such operators.
- case class EliminateRedundantColumnarToRow(session: SparkSession) extends
Rule[SparkPlan] {
+ // Note about the second case: When `spark.comet.rowToColumnar.enabled` is
set, Comet will add
+ // `CometRowToColumnarExec` on top of row-based operators first, but the
downstream operator
+ // only takes row-based input as it's a vanilla Spark operator (as Comet
cannot convert it for
+ // various reasons) or Spark requests row-based output such as a `collect`
call. Spark will then add
+ // another `ColumnarToRowExec` on top of `CometRowToColumnarExec`. In this
case, the pair could
+ // be removed.
+ case class EliminateRedundantTransitions(session: SparkSession) extends
Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
- plan match {
+ val eliminatedPlan = plan transformUp {
+ case ColumnarToRowExec(rowToColumnar: CometRowToColumnarExec) =>
rowToColumnar.child
+ }
+
+ eliminatedPlan match {
case ColumnarToRowExec(child: CometCollectLimitExec) =>
child
case other =>
@@ -716,6 +736,18 @@ object CometSparkSessionExtensions extends Logging {
op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
}
+ private def shouldApplyRowToColumnar(conf: SQLConf, op: SparkPlan): Boolean
= {
+ // Only consider converting leaf nodes to columnar currently, so that all
the following
+ // operators can have a chance to be converted to columnar.
+ // TODO: consider converting other intermediate operators to columnar.
+ op.isInstanceOf[LeafExecNode] && !op.supportsColumnar &&
isSchemaSupported(op.schema) &&
+ COMET_ROW_TO_COLUMNAR_ENABLED.get(conf) && {
+ val simpleClassName = Utils.getSimpleName(op.getClass)
+ val nodeName = simpleClassName.replaceAll("Exec$", "")
+
COMET_ROW_TO_COLUMNAR_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
+ }
+ }
+
/** Used for operations that weren't available in Spark 3.2 */
def isSpark32: Boolean = {
org.apache.spark.SPARK_VERSION.matches("3\\.2\\..*")
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b98c438..26fc708 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight,
NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
@@ -2064,6 +2064,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde {
private def isCometSink(op: SparkPlan): Boolean = {
op match {
case s if isCometScan(s) => true
+ case _: CometRowToColumnarExec => true
case _: CometSinkPlaceHolder => true
case _: CoalesceExec => true
case _: CollectLimitExec => true
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala
new file mode 100644
index 0000000..5679e86
--- /dev/null
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
+import org.apache.spark.sql.execution.{RowToColumnarTransition, SparkPlan}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Row-to-columnar transition for Comet: takes the rows produced by `child` and re-emits them as
+ * `ColumnarBatch`es built by `CometArrowConverters.toArrowBatchIterator`.
+ *
+ * Output attributes, partitioning and ordering are those of `child`. Row-based and broadcast
+ * execution are delegated to `child` unchanged.
+ *
+ * @param child
+ *   the row-based plan whose output is converted
+ */
+case class CometRowToColumnarExec(child: SparkPlan)
+    extends RowToColumnarTransition
+    with CometPlan {
+
+  /** Counters updated once per converted batch in `doExecuteColumnar`. */
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
+    "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"))
+
+  override def supportsColumnar: Boolean = true
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  /** Row-based execution simply forwards to the child. */
+  override protected def doExecute(): RDD[InternalRow] = child.execute()
+
+  /** Broadcast execution simply forwards to the child. */
+  override def doExecuteBroadcast[T](): Broadcast[T] = child.executeBroadcast()
+
+  /**
+   * Converts each partition of the child's rows into columnar batches and records, per batch,
+   * the number of rows it holds and the fact that one more batch was produced.
+   */
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val inputRowsMetric = longMetric("numInputRows")
+    val outputBatchesMetric = longMetric("numOutputBatches")
+    // Captured as local values before the RDD closures below are built.
+    val batchRowLimit = conf.arrowMaxRecordsPerBatch
+    val sessionTimeZone = conf.sessionLocalTimeZone
+    val rowSchema = child.schema
+
+    val arrowBatches = child.execute().mapPartitionsInternal { rows =>
+      val taskContext = TaskContext.get()
+      CometArrowConverters.toArrowBatchIterator(
+        rows,
+        rowSchema,
+        batchRowLimit,
+        sessionTimeZone,
+        taskContext)
+    }
+
+    arrowBatches.map { batch =>
+      inputRowsMetric += batch.numRows()
+      outputBatchesMetric += 1
+      batch
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CometRowToColumnarExec =
+    copy(child = newChild)
+
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 520f239..8545eee 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -270,7 +270,7 @@ abstract class CometNativeExec extends CometExec {
}
if (inputs.isEmpty) {
- throw new CometRuntimeException(s"No input for CometNativeExec:
$this")
+ throw new CometRuntimeException(s"No input for CometNativeExec:\n
$this")
}
ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
@@ -300,7 +300,8 @@ abstract class CometNativeExec extends CometExec {
case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec
|
_: AQEShuffleReadExec | _: CometShuffleExchangeExec | _:
CometUnionExec |
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _:
ReusedExchangeExec |
- _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
+ _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec |
+ _: CometRowToColumnarExec =>
func(plan)
case _: CometPlan =>
// Other Comet operators, continue to traverse the tree.
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index b2c4fd6..0bb21ab 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics,
CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometCollectLimitExec, CometFilterExec, CometHashAggregateExec,
CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometCollectLimitExec, CometFilterExec, CometHashAggregateExec,
CometProjectExec, CometRowToColumnarExec, CometScanExec,
CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec,
SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -1118,6 +1118,58 @@ class CometExecSuite extends CometTestBase {
}
})
}
+
+  // Runs an aggregation on top of `spark.range` for every combination of AQE on/off and two
+  // Arrow batch sizes, comparing the result and operators against Spark.
+  test("RowToColumnar over RangeExec") {
+    for {
+      aqe <- Seq("true", "false")
+      batchSize <- Seq(500, 900)
+    } {
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
+        SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> batchSize.toString) {
+        val nonEmpty = spark.range(1000).selectExpr("id", "id % 8 as k").groupBy("k").sum("id")
+        checkSparkAnswerAndOperator(nonEmpty)
+        // empty record batch should also be handled
+        val empty = spark.range(0).selectExpr("id", "id % 8 as k").groupBy("k").sum("id")
+        checkSparkAnswerAndOperator(empty, includeClasses = Seq(classOf[CometRowToColumnarExec]))
+      }
+    }
+  }
+
+  // A bare `spark.range` feeds row-based output directly, so the executed plan must not contain
+  // any `CometRowToColumnarExec`.
+  test("RowToColumnar over RangeExec directly is eliminated for row output") {
+    for {
+      aqe <- Seq("true", "false")
+      batchSize <- Seq(500, 900)
+    } {
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
+        SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> batchSize.toString) {
+        val executedPlan = spark.range(1000).queryExecution.executedPlan
+        val leftover = executedPlan.collectFirst { case r: CometRowToColumnarExec => r }
+        if (leftover.isDefined) {
+          fail("CometRowToColumnarExec should be eliminated")
+        }
+      }
+    }
+  }
+
+  // Caches a temp view with the vectorized cache reader disabled and runs an aggregation on top
+  // of it, expecting a `CometRowToColumnarExec` in the plan. The view is dropped afterwards so
+  // that neither the view nor its cached data leaks into other tests using the same session.
+  test("RowToColumnar over InMemoryTableScanExec") {
+    Seq("true", "false").foreach { aqe =>
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "false") {
+        spark
+          .range(1000)
+          .selectExpr("id as key", "id % 8 as value")
+          .toDF("key", "value")
+          .selectExpr("key", "value", "key+1")
+          .createOrReplaceTempView("abc")
+        try {
+          spark.catalog.cacheTable("abc")
+          val df = spark.sql("SELECT * FROM abc").groupBy("key").count()
+          checkSparkAnswerAndOperator(df, includeClasses = Seq(classOf[CometRowToColumnarExec]))
+        } finally {
+          // Per the Spark Catalog API, dropping a temp view also uncaches it if it was cached.
+          spark.catalog.dropTempView("abc")
+        }
+      }
+    }
+  }
}
case class BucketedTableTestSpec(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 6fb81bc..de58665 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -34,7 +34,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter
import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark._
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED,
MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet.{CometBatchScanExec,
CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper,
CometSinkPlaceHolder}
+import org.apache.spark.sql.comet.{CometBatchScanExec,
CometBroadcastExchangeExec, CometExec, CometRowToColumnarExec, CometScanExec,
CometScanWrapper, CometSinkPlaceHolder}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter,
SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -75,6 +75,7 @@ abstract class CometTestBase
conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
+ conf.set(CometConf.COMET_ROW_TO_COLUMNAR_ENABLED.key, "true")
conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
conf
}
@@ -155,9 +156,11 @@ abstract class CometTestBase
}
protected def checkCometOperators(plan: SparkPlan, excludedClasses:
Class[_]*): Unit = {
- plan.foreach {
+ val wrapped = wrapCometRowToColumnar(plan)
+ wrapped.foreach {
case _: CometScanExec | _: CometBatchScanExec => true
case _: CometSinkPlaceHolder | _: CometScanWrapper => false
+ case _: CometRowToColumnarExec => false
case _: CometExec | _: CometShuffleExchangeExec => true
case _: CometBroadcastExchangeExec => true
case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter
=> true
@@ -184,6 +187,14 @@ abstract class CometTestBase
}
}
+  /**
+   * Replaces every `CometRowToColumnarExec` in `plan` with a `CometScanWrapper` around it, so
+   * the operators underneath it will not be checked.
+   */
+  private def wrapCometRowToColumnar(plan: SparkPlan): SparkPlan =
+    plan transformDown {
+      // The native operator argument is not needed here, hence null.
+      case rowToColumnar: CometRowToColumnarExec => CometScanWrapper(null, rowToColumnar)
+    }
+
/**
* Check the answer of a Comet SQL query with Spark result using absolute
tolerance.
*/