viirya commented on code in PR #56334: URL: https://github.com/apache/spark/pull/56334#discussion_r3365618783
########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import 
org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. + * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + // Check if all data types in the schema are supported by Arrow + schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType)) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + 
val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + val baseIter = new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + 
timeZoneId) + if (prefetchEnabled) { + new ArrowPrefetchColumnarBatchIterator(baseIter) + } else { + baseIter + } + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. 
+ convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + case "zstd" => + val factory = CompressionCodec.Factory.INSTANCE + val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType() + factory.createCodec(codecType) + case "lz4" => + val factory = CompressionCodec.Factory.INSTANCE + val codecType = new Lz4CompressionCodec().getCodecType() + factory.createCodec(codecType) + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. 
Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + def createColumnStats(dataType: DataType): ColumnStats = { + dataType match { + case BooleanType => new BooleanColumnStats + case ByteType => new ByteColumnStats + case ShortType => new ShortColumnStats + case IntegerType => new IntColumnStats + case DateType => new IntColumnStats // Date is stored as Int + case LongType => new LongColumnStats + case TimestampType => new LongColumnStats // Timestamp is stored as Long + case TimestampNTZType => new LongColumnStats // TimestampNTZ is stored as Long + case FloatType => new FloatColumnStats + case DoubleType => new DoubleColumnStats + case st: StringType => new StringColumnStats(st) + case BinaryType => new BinaryColumnStats + case dt: DecimalType => new DecimalColumnStats(dt) + case CalendarIntervalType => new IntervalColumnStats + case _: YearMonthIntervalType => new IntColumnStats // stored as Int + case _: DayTimeIntervalType => new LongColumnStats // stored as Long + case _: TimeType => new LongColumnStats // Time is stored as Long (nanoseconds) + case VariantType => new VariantColumnStats + case _ => new ObjectColumnStats(dataType) + } + } + + def buildStatisticsFromCollectors( + collectors: Array[ColumnStats], + schema: Seq[Attribute]): InternalRow = { + val stats = collectors.flatMap { collector => + val collected = collector.collectedStatistics + // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes] + Seq(collected(0), collected(1), collected(2), collected(3), collected(4)) + } + InternalRow.fromSeq(stats.toSeq) + } + + def collectStatistics( + root: VectorSchemaRoot, + schema: Seq[Attribute]): InternalRow = { + val rowCount = root.getRowCount + val vectors = 
root.getFieldVectors.asScala.toSeq + + // Collect stats for each column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes + val stats = schema.zip(vectors).flatMap { case (attr, vector) => + val nullCount = (0 until rowCount).count(i => vector.isNull(i)) + val sizeInBytes = vector.getBufferSize.toLong + + val (lower, upper) = attr.dataType match { + case BooleanType => calculateMinMaxBoolean(vector, rowCount) + case ByteType => calculateMinMaxByte(vector, rowCount) + case ShortType => calculateMinMaxShort(vector, rowCount) + case IntegerType => calculateMinMaxInt(vector, rowCount) + case DateType => calculateMinMaxDate(vector, rowCount) + case LongType => calculateMinMaxLong(vector, rowCount) + case TimestampType => calculateMinMaxTimestamp(vector, rowCount) + case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount) + case FloatType => calculateMinMaxFloat(vector, rowCount) + case DoubleType => calculateMinMaxDouble(vector, rowCount) + case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId) + case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType) + case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount) + case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount) + case _: TimeType => calculateMinMaxTime(vector, rowCount) + case _ => (null, null) // Skip for binary, complex, and other unsupported types + } + + Seq(lower, upper, nullCount, rowCount, sizeInBytes) + } + + new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(stats.toArray) + } + + def calculateMinMaxBoolean( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = true + var max = false + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BitVector].get(i) != 0 + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < 
min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxByte( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Byte.MaxValue + var max = Byte.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TinyIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxShort( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Short.MaxValue + var max = Short.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.SmallIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxInt( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDate( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = 
vector.asInstanceOf[org.apache.arrow.vector.DateDayVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxLong( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BigIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestamp( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroTZVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestampNTZ( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxFloat( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { 
+ var min = Float.MaxValue + var max = Float.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float4Vector].get(i) + // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never + // updates min/max in the row-based path (FloatColumnStats.gatherValueStats). + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDouble( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Double.MaxValue + var max = Double.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float8Vector].get(i) + // Skip NaN to match DoubleColumnStats.gatherValueStats. + if (!value.isNaN) { Review Comment: Thanks for catching this and for the clear repro. You're right that the NaN bounds are not conservative under Spark SQL's ordering (where NaN sorts above everything), so a batch like `[1.0, NaN]` gets incorrectly pruned for `v > 100.0`. I confirmed the same wrong result with the default serializer: caching `[1.0, NaN]` and applying `filter("v > 100.0")` returns 0 rows instead of the one `NaN` row. I'd prefer not to fix this in the Arrow path alone here. This serializer's row path reuses the existing `FloatColumnStats`/`DoubleColumnStats` collectors directly (via `createColumnStats`), and the columnar `calculateMinMax*` path was deliberately written to match them. Diverging only the columnar path would make the two paths within this serializer inconsistent, and the underlying defect lives in the shared, pre-existing statistics collectors used by `DefaultCachedBatchSerializer`.
Since this PR is an additive feature rather than a bug fix, I think the right scope is a separate ticket that makes NaN bounds conservative across both the row and columnar paths (default and Arrow together), with pruning tests covering both mixed (NaN plus non-NaN) and all-NaN batches. I'll keep the current behavior aligned with the default serializer here and file a follow-up. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
