viirya commented on code in PR #56334: URL: https://github.com/apache/spark/pull/56334#discussion_r3472346656
########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import 
org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. + * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + // Check if all data types in the schema are supported by Arrow + schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType)) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + 
val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + 
prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. + case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. 
+ convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + // The codec instance must be constructed directly so that compressionLevel is honored: + // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at + // the default level. The level only matters on the write side; the read side looks up the + // codec by the type recorded in the IPC message. + case "zstd" => new ZstdCompressionCodec(compressionLevel) + case "lz4" => new Lz4CompressionCodec() + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. 
Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + def createColumnStats(dataType: DataType): ColumnStats = { + dataType match { + case BooleanType => new BooleanColumnStats + case ByteType => new ByteColumnStats + case ShortType => new ShortColumnStats + case IntegerType => new IntColumnStats + case DateType => new IntColumnStats // Date is stored as Int + case LongType => new LongColumnStats + case TimestampType => new LongColumnStats // Timestamp is stored as Long + case TimestampNTZType => new LongColumnStats // TimestampNTZ is stored as Long + case FloatType => new FloatColumnStats + case DoubleType => new DoubleColumnStats + case st: StringType => new StringColumnStats(st) + case BinaryType => new BinaryColumnStats + case dt: DecimalType => new DecimalColumnStats(dt) + case CalendarIntervalType => new IntervalColumnStats + case _: YearMonthIntervalType => new IntColumnStats // stored as Int + case _: DayTimeIntervalType => new LongColumnStats // stored as Long + case _: TimeType => new LongColumnStats // Time is stored as Long (nanoseconds) + case VariantType => new VariantColumnStats + // Geometry/Geography are stored as binary (WKB) internally, so reuse BinaryColumnStats + // to collect size/count without min/max bounds. They are AtomicTypes that ColumnType + // (used by ObjectColumnStats) does not handle, so they must be matched explicitly here. 
+ case _: GeometryType | _: GeographyType => new BinaryColumnStats + case _ => new ObjectColumnStats(dataType) + } + } + + def buildStatisticsFromCollectors( + collectors: Array[ColumnStats], + schema: Seq[Attribute]): InternalRow = { + val stats = collectors.flatMap { collector => + val collected = collector.collectedStatistics + // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes] + Seq(collected(0), collected(1), collected(2), collected(3), collected(4)) + } + InternalRow.fromSeq(stats.toSeq) + } + + def collectStatistics( + root: VectorSchemaRoot, + schema: Seq[Attribute]): InternalRow = { + val rowCount = root.getRowCount + val vectors = root.getFieldVectors.asScala.toSeq + + // Collect stats for each column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes + val stats = schema.zip(vectors).flatMap { case (attr, vector) => + val nullCount = (0 until rowCount).count(i => vector.isNull(i)) + val sizeInBytes = vector.getBufferSize.toLong + + val (lower, upper) = attr.dataType match { + case BooleanType => calculateMinMaxBoolean(vector, rowCount) + case ByteType => calculateMinMaxByte(vector, rowCount) + case ShortType => calculateMinMaxShort(vector, rowCount) + case IntegerType => calculateMinMaxInt(vector, rowCount) + case DateType => calculateMinMaxDate(vector, rowCount) + case LongType => calculateMinMaxLong(vector, rowCount) + case TimestampType => calculateMinMaxTimestamp(vector, rowCount) + case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount) + case FloatType => calculateMinMaxFloat(vector, rowCount) + case DoubleType => calculateMinMaxDouble(vector, rowCount) + case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId) + case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType) + case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount) + case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount) + case _: TimeType => 
calculateMinMaxTime(vector, rowCount) + case _ => (null, null) // Skip for binary, complex, and other unsupported types + } + + Seq(lower, upper, nullCount, rowCount, sizeInBytes) + } + + new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(stats.toArray) + } + + def calculateMinMaxBoolean( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = true + var max = false + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BitVector].get(i) != 0 + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxByte( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Byte.MaxValue + var max = Byte.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TinyIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxShort( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Short.MaxValue + var max = Short.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.SmallIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxInt( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = 
Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDate( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.DateDayVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxLong( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BigIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestamp( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroTZVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def 
calculateMinMaxTimestampNTZ( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxFloat( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Float.MaxValue + var max = Float.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float4Vector].get(i) + // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never + // updates min/max in the row-based path (FloatColumnStats.gatherValueStats). + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDouble( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Double.MaxValue + var max = Double.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float8Vector].get(i) + // Skip NaN to match DoubleColumnStats.gatherValueStats. 
+ if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxString( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + collationId: Int = StringType.collationId): (Any, Any) = { + var min: org.apache.spark.unsafe.types.UTF8String = null + var max: org.apache.spark.unsafe.types.UTF8String = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bytes = vector.asInstanceOf[org.apache.arrow.vector.VarCharVector].get(i) + val value = org.apache.spark.unsafe.types.UTF8String.fromBytes(bytes) + if (!hasValue) { + min = value.clone() + max = value.clone() + hasValue = true + } else { + if (value.semanticCompare(min, collationId) < 0) min = value.clone() + if (value.semanticCompare(max, collationId) > 0) max = value.clone() + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDecimal( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + dataType: org.apache.spark.sql.types.DataType): (Any, Any) = { + val decimalType = dataType.asInstanceOf[DecimalType] + var min: org.apache.spark.sql.types.Decimal = null + var max: org.apache.spark.sql.types.Decimal = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bigDecimal = vector.asInstanceOf[ + org.apache.arrow.vector.DecimalVector].getObject(i) + val value = org.apache.spark.sql.types.Decimal( + bigDecimal, decimalType.precision, decimalType.scale) + + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value.compareTo(min) < 0) min = value + if (value.compareTo(max) > 0) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxYearMonthInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { 
+ var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntervalYearVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDayTimeInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = org.apache.arrow.vector.DurationVector.get( + vector.asInstanceOf[org.apache.arrow.vector.DurationVector].getDataBuffer, i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTime( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TimeNanoVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } +} + +/** + * Iterator that converts InternalRow to ArrowCachedBatch. 
+ */ +private class InternalRowToArrowCachedBatchIterator( + rowIter: Iterator[InternalRow], + schema: Seq[Attribute], + sparkSchema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"InternalRowToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + private val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Create statistics collectors for each column + private val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + close() + } + } + + override def hasNext: Boolean = rowIter.hasNext || { + close() + false + } + + override def next(): ArrowCachedBatch = { + var rowCount = 0 + + // Reset statistics collectors for new batch + var idx = 0 + while (idx < statsCollectors.length) { + statsCollectors(idx) = ArrowCachedBatchSerializer.createColumnStats(schema(idx).dataType) + idx += 1 + } + + Utils.tryWithSafeFinally { + // Write rows to Arrow vectors and collect statistics incrementally. + // A nonpositive maxRecordsPerBatch means unlimited (one batch per partition), matching + // ArrowConverters; without the `<= 0` guard the loop would emit empty batches forever. 
+ while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + + // Collect statistics for this row + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + + rowCount += 1 + } + arrowWriter.finish() + + // Get the Arrow RecordBatch with compression + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + // Serialize to Arrow IPC format + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + + // Build statistics InternalRow from collected stats + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + } + } + + private def close(): Unit = { + root.close() + allocator.close() + } +} + +/** + * Iterator that converts ColumnarBatch to ArrowCachedBatch. + */ +private class ColumnarBatchToArrowCachedBatchIterator( + batchIter: Iterator[ColumnarBatch], + schema: Seq[Attribute], + sparkSchema: StructType, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ColumnarBatchToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + allocator.close() + } + } + + override def hasNext: Boolean = batchIter.hasNext + + override def next(): ArrowCachedBatch = { + val batch = batchIter.next() Review Comment: Fixed. 
`next()` now wraps the conversion in `tryWithSafeFinally { ... } { batch.closeIfFreeable() }`, releasing the consumed input batch in either branch, whether the conversion succeeds or fails. `closeIfFreeable()` is a no-op for reusable writable/constant vectors, as you noted. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. 
+ * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + // Check if all data types in the schema are supported by Arrow + schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType)) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values 
on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = 
selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. + case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. 
+ */ +private object ArrowCachedBatchSerializer { + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + // The codec instance must be constructed directly so that compressionLevel is honored: + // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at + // the default level. The level only matters on the write side; the read side looks up the + // codec by the type recorded in the IPC message. + case "zstd" => new ZstdCompressionCodec(compressionLevel) + case "lz4" => new Lz4CompressionCodec() + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + def createColumnStats(dataType: DataType): ColumnStats = { + dataType match { + case BooleanType => new BooleanColumnStats + case ByteType => new ByteColumnStats + case ShortType => new ShortColumnStats + case IntegerType => new IntColumnStats + case DateType => new IntColumnStats // Date is stored as Int + case LongType => new LongColumnStats + case TimestampType => new LongColumnStats // Timestamp is stored as Long + case TimestampNTZType => new LongColumnStats // TimestampNTZ is stored as Long + case FloatType => new FloatColumnStats + case DoubleType => new DoubleColumnStats + case st: StringType => new StringColumnStats(st) + case BinaryType => new BinaryColumnStats + case dt: DecimalType => new DecimalColumnStats(dt) + case CalendarIntervalType => new IntervalColumnStats + case _: YearMonthIntervalType => new IntColumnStats // stored as Int + case _: DayTimeIntervalType 
=> new LongColumnStats // stored as Long + case _: TimeType => new LongColumnStats // Time is stored as Long (nanoseconds) + case VariantType => new VariantColumnStats + // Geometry/Geography are stored as binary (WKB) internally, so reuse BinaryColumnStats + // to collect size/count without min/max bounds. They are AtomicTypes that ColumnType + // (used by ObjectColumnStats) does not handle, so they must be matched explicitly here. + case _: GeometryType | _: GeographyType => new BinaryColumnStats + case _ => new ObjectColumnStats(dataType) + } + } + + def buildStatisticsFromCollectors( + collectors: Array[ColumnStats], + schema: Seq[Attribute]): InternalRow = { + val stats = collectors.flatMap { collector => + val collected = collector.collectedStatistics + // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes] + Seq(collected(0), collected(1), collected(2), collected(3), collected(4)) + } + InternalRow.fromSeq(stats.toSeq) + } + + def collectStatistics( + root: VectorSchemaRoot, + schema: Seq[Attribute]): InternalRow = { + val rowCount = root.getRowCount + val vectors = root.getFieldVectors.asScala.toSeq + + // Collect stats for each column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes + val stats = schema.zip(vectors).flatMap { case (attr, vector) => + val nullCount = (0 until rowCount).count(i => vector.isNull(i)) + val sizeInBytes = vector.getBufferSize.toLong + + val (lower, upper) = attr.dataType match { + case BooleanType => calculateMinMaxBoolean(vector, rowCount) + case ByteType => calculateMinMaxByte(vector, rowCount) + case ShortType => calculateMinMaxShort(vector, rowCount) + case IntegerType => calculateMinMaxInt(vector, rowCount) + case DateType => calculateMinMaxDate(vector, rowCount) + case LongType => calculateMinMaxLong(vector, rowCount) + case TimestampType => calculateMinMaxTimestamp(vector, rowCount) + case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount) + case FloatType => 
calculateMinMaxFloat(vector, rowCount) + case DoubleType => calculateMinMaxDouble(vector, rowCount) + case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId) + case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType) + case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount) + case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount) + case _: TimeType => calculateMinMaxTime(vector, rowCount) + case _ => (null, null) // Skip for binary, complex, and other unsupported types + } + + Seq(lower, upper, nullCount, rowCount, sizeInBytes) + } + + new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(stats.toArray) + } + + def calculateMinMaxBoolean( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = true + var max = false + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BitVector].get(i) != 0 + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxByte( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Byte.MaxValue + var max = Byte.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TinyIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxShort( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Short.MaxValue + var max = Short.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + 
if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.SmallIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxInt( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDate( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.DateDayVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxLong( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BigIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestamp( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = 
{ + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroTZVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestampNTZ( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxFloat( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Float.MaxValue + var max = Float.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float4Vector].get(i) + // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never + // updates min/max in the row-based path (FloatColumnStats.gatherValueStats). 
+ if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDouble( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Double.MaxValue + var max = Double.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float8Vector].get(i) + // Skip NaN to match DoubleColumnStats.gatherValueStats. + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxString( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + collationId: Int = StringType.collationId): (Any, Any) = { + var min: org.apache.spark.unsafe.types.UTF8String = null + var max: org.apache.spark.unsafe.types.UTF8String = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bytes = vector.asInstanceOf[org.apache.arrow.vector.VarCharVector].get(i) + val value = org.apache.spark.unsafe.types.UTF8String.fromBytes(bytes) + if (!hasValue) { + min = value.clone() + max = value.clone() + hasValue = true + } else { + if (value.semanticCompare(min, collationId) < 0) min = value.clone() + if (value.semanticCompare(max, collationId) > 0) max = value.clone() + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDecimal( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + dataType: org.apache.spark.sql.types.DataType): (Any, Any) = { + val decimalType = dataType.asInstanceOf[DecimalType] + var min: org.apache.spark.sql.types.Decimal = null + var max: org.apache.spark.sql.types.Decimal = null + var hasValue = 
false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bigDecimal = vector.asInstanceOf[ + org.apache.arrow.vector.DecimalVector].getObject(i) + val value = org.apache.spark.sql.types.Decimal( + bigDecimal, decimalType.precision, decimalType.scale) + + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value.compareTo(min) < 0) min = value + if (value.compareTo(max) > 0) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxYearMonthInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntervalYearVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDayTimeInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = org.apache.arrow.vector.DurationVector.get( + vector.asInstanceOf[org.apache.arrow.vector.DurationVector].getDataBuffer, i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTime( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TimeNanoVector].get(i) + if (!hasValue) { + min 
= value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } +} + +/** + * Iterator that converts InternalRow to ArrowCachedBatch. + */ +private class InternalRowToArrowCachedBatchIterator( + rowIter: Iterator[InternalRow], + schema: Seq[Attribute], + sparkSchema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"InternalRowToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + private val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Create statistics collectors for each column + private val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + close() + } + } + + override def hasNext: Boolean = rowIter.hasNext || { + close() + false + } + + override def next(): ArrowCachedBatch = { + var rowCount = 0 + + // Reset statistics collectors for new batch + var idx = 0 + while (idx < statsCollectors.length) { + statsCollectors(idx) = ArrowCachedBatchSerializer.createColumnStats(schema(idx).dataType) + idx += 1 + } + + Utils.tryWithSafeFinally { + // Write rows to Arrow vectors and collect statistics incrementally. 
+ // A nonpositive maxRecordsPerBatch means unlimited (one batch per partition), matching + // ArrowConverters; without the `<= 0` guard the loop would emit empty batches forever. + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + + // Collect statistics for this row + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + + rowCount += 1 + } + arrowWriter.finish() + + // Get the Arrow RecordBatch with compression + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + // Serialize to Arrow IPC format + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + + // Build statistics InternalRow from collected stats + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + } + } + + private def close(): Unit = { + root.close() + allocator.close() + } +} + +/** + * Iterator that converts ColumnarBatch to ArrowCachedBatch. 
+ */ +private class ColumnarBatchToArrowCachedBatchIterator( + batchIter: Iterator[ColumnarBatch], + schema: Seq[Attribute], + sparkSchema: StructType, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ColumnarBatchToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + allocator.close() + } + } + + override def hasNext: Boolean = batchIter.hasNext + + override def next(): ArrowCachedBatch = { + val batch = batchIter.next() + val rowCount = batch.numRows() + + // Check if batch is already Arrow-based for zero-copy path. The zero-copy path reuses the + // input vectors but serializes them under a schema built with largeVarTypes=false, and the + // read path reconstructs that same non-large schema. Large var-width vectors use 64-bit + // offsets, so reading them back under a 32-bit-offset schema would silently corrupt data. + // Fall back to the row-based conversion (which always produces standard var-width vectors) + // whenever any input vector is, or nests, a large var-width vector. 
+ val vectors = (0 until batch.numCols()).map(batch.column) + val zeroCopyEligible = vectors.forall { + case acv: ArrowColumnVector => + !ColumnarBatchToArrowCachedBatchIterator.containsLargeVarType(acv.getValueVector) + case _ => false + } + if (zeroCopyEligible) { + // Fast path: zero-copy extraction of Arrow RecordBatch + convertArrowBatchZeroCopy(batch, rowCount, schema, vectors) + } else { + // Slow path: convert to Arrow via rows + convertToArrowBatch(batch, rowCount, schema) + } + } + + private def convertArrowBatchZeroCopy( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute], + vectors: Seq[ColumnVector]): ArrowCachedBatch = { + // Zero-copy path: extract Arrow vectors directly from ArrowColumnVector + val arrowVectors = vectors.map( + _.asInstanceOf[ArrowColumnVector].getValueVector.asInstanceOf[ + org.apache.arrow.vector.FieldVector]) + + // Create a VectorSchemaRoot from the existing vectors + val root = new VectorSchemaRoot(arrowSchema, arrowVectors.asJava, rowCount) + + Utils.tryWithSafeFinally { + // Use VectorUnloader to create compressed RecordBatch + val unloader = new VectorUnloader(root, true, compressionCodec, true) + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + // Note: We don't close the root here because we don't own the vectors + // They are owned by the input ColumnarBatch + } + } + + private def convertToArrowBatch( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute]): ArrowCachedBatch = { + // Convert columnar batch to rows, then to Arrow + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Collect statistics inline during 
row iteration, same as InternalRowToArrow path + val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + Utils.tryWithSafeFinally { + val rowIterator = batch.rowIterator().asScala + while (rowIterator.hasNext) { + val row = rowIterator.next() + arrowWriter.write(row) + + // Collect statistics for this row inline + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + } + arrowWriter.finish() + + val recordBatch = unloader.getRecordBatch() + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + root.close() + } + } +} + +private object ColumnarBatchToArrowCachedBatchIterator { + import org.apache.arrow.vector.{FieldVector, LargeVarBinaryVector, LargeVarCharVector} + + /** + * Whether the vector is, or nests, a large var-width vector (64-bit offsets). These are not + * eligible for the zero-copy path because that path serializes and reloads under a schema built + * with largeVarTypes=false; reinterpreting 64-bit offset buffers as 32-bit would corrupt data. + */ + def containsLargeVarType(vector: org.apache.arrow.vector.ValueVector): Boolean = vector match { + case _: LargeVarCharVector | _: LargeVarBinaryVector => true + case fv: FieldVector => + fv.getChildrenFromFields.asScala.exists(containsLargeVarType) + case _ => false + } +} + +/** + * Iterator that converts ArrowCachedBatch to ColumnarBatch. 
+ */ +private class ArrowCachedBatchToColumnarBatchIterator( + batchIter: Iterator[CachedBatch], + cacheSchema: StructType, + selectedSchema: StructType, + columnIndices: Array[Int], + timeZoneId: String, + prefetchEnabled: Boolean = false) extends Iterator[ColumnarBatch] { + + import java.util.concurrent.{Callable, ExecutionException, Executors, ExecutorService, Future} + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ArrowCachedBatchToColumnarBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, false, false) + + // Track only the previous root to close it when next batch is produced + private var previousRoot: VectorSchemaRoot = null + + // Prefetch support: deserialize the next batch into its own root in a background thread while + // the current batch is being consumed. Only the deserialization (IPC read + decompression + + // loading into a fresh root) happens off-thread; closing the previous root stays on the + // consumer thread in next(), so the vectors backing a returned ColumnarBatch are never released + // while the consumer may still read them. + private val prefetchExecutor: ExecutorService = if (prefetchEnabled) { + Executors.newSingleThreadExecutor(r => { + val t = new Thread(r, "arrow-cache-prefetch") + t.setDaemon(true) + t + }) + } else { + null + } + private var prefetchFuture: Future[VectorSchemaRoot] = _ + + // Register cleanup - close remaining root and allocator when task completes + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + if (prefetchFuture != null) { + prefetchFuture.cancel(true) Review Comment: Fixed. Cleanup now goes through a shared helper that shuts the prefetch executor down and awaits termination (rather than `cancel(true)`, which could race an in-flight allocation), then retrieves and closes any root the worker already produced before closing the allocator. 
Applied to both the columnar-reader and row-reader listeners. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. 
+ * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + // Check if all data types in the schema are supported by Arrow + schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType)) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch Review Comment: Partially addressed, with a narrower scope after looking into it. The row path now honors `maxBytesPerBatch` in addition to `maxRecordsPerBatch` (stopping at whichever is hit first), matching `ArrowConverters`. For the columnar path I kept the one-batch-in/one-batch-out behavior: the upstream `ColumnarBatch` row count is already bounded by the source's batch-size config (e.g. 
`spark.sql.parquet.columnarReaderBatchSize`, default 4096), and the zero-copy path reuses already-formed Arrow vectors that cannot exceed Arrow's offset limits or balloon memory beyond what the upstream batch already holds. Let me know if you'd still prefer per-row byte enforcement on the columnar/zero-copy paths. ########## sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala: ########## @@ -38,6 +38,50 @@ private[sql] object ArrowUtils { // todo: support more types. + /** + * Check if a Spark DataType is supported by Arrow. This recursively checks complex types + * (Array, Struct, Map). + * + * Note: This checks compatibility with toArrowField(), not toArrowType(). Types like + * GeometryType, GeographyType, and VariantType are not supported by toArrowType() (which only + * handles primitive Arrow types), but ARE supported by toArrowField() which converts them to + * Arrow Struct representations with metadata. Since Arrow cache uses toArrowField() via + * toArrowSchema() to create the schema, these types are supported. 
+ */ + def isSupportedByArrow(dt: DataType): Boolean = { + dt match { + // Primitive types + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: StringType | BinaryType | NullType => + true + + // Decimal + case _: DecimalType => true + + // Temporal types + case DateType | TimestampType | TimestampNTZType | _: TimeType => true + + // Interval types + case _: YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType => true + + // Complex types - recursively check element types + case ArrayType(elementType, _) => isSupportedByArrow(elementType) + case StructType(fields) => fields.forall(f => isSupportedByArrow(f.dataType)) + case MapType(keyType, valueType, _) => + isSupportedByArrow(keyType) && isSupportedByArrow(valueType) + + // Special types + // Note: These are not in toArrowType(), but are handled by toArrowField() + case udt: UserDefinedType[_] => isSupportedByArrow(udt.sqlType) Review Comment: Fixed. `createColumnStats` now unwraps UDTs (`case udt: UserDefinedType[_] => createColumnStats(udt.sqlType)`), so a Variant- or Geometry-backed UDT gets the right collector instead of falling through to `ObjectColumnStats` and throwing `UNSUPPORTED_DATATYPE`. This keeps the capability check and the statistics path in agreement. 
########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala: ########## @@ -358,8 +359,28 @@ private[columnar] final class ObjectColumnStats(dataType: DataType) extends Colu override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { - val size = columnType.actualSize(row, ordinal) - sizeInBytes += size + // Check if this is a columnar complex type that doesn't support getSizeInBytes + val isColumnarComplexType = columnType match { + case _: ARRAY => + row.getArray(ordinal).isInstanceOf[ColumnarArray] + case _: MAP => + row.getMap(ordinal).isInstanceOf[ColumnarMap] + case struct: STRUCT => + row.getStruct(ordinal, struct.dataType.fields.length).isInstanceOf[ColumnarRow] + case _ => + false + } + + if (!isColumnarComplexType) { Review Comment: Fixed at the source: the columnar slow path now derives statistics from the built Arrow root via `collectStatistics(root)` (reading `vector.getBufferSize`) instead of the row collectors, matching the zero-copy path. Columnar complex values therefore contribute their actual byte size rather than zero, so a complex-only relation no longer reports `sizeInBytes=0`. The `ColumnStats` columnar-complex guard is left in place defensively. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. 
+ * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + // Check if all data types in the schema are supported by Arrow + schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType)) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values 
on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = 
selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. + case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. 
+ */ +private object ArrowCachedBatchSerializer { + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + // The codec instance must be constructed directly so that compressionLevel is honored: + // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at + // the default level. The level only matters on the write side; the read side looks up the + // codec by the type recorded in the IPC message. + case "zstd" => new ZstdCompressionCodec(compressionLevel) + case "lz4" => new Lz4CompressionCodec() + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + def createColumnStats(dataType: DataType): ColumnStats = { + dataType match { + case BooleanType => new BooleanColumnStats + case ByteType => new ByteColumnStats + case ShortType => new ShortColumnStats + case IntegerType => new IntColumnStats + case DateType => new IntColumnStats // Date is stored as Int + case LongType => new LongColumnStats + case TimestampType => new LongColumnStats // Timestamp is stored as Long + case TimestampNTZType => new LongColumnStats // TimestampNTZ is stored as Long + case FloatType => new FloatColumnStats + case DoubleType => new DoubleColumnStats + case st: StringType => new StringColumnStats(st) + case BinaryType => new BinaryColumnStats + case dt: DecimalType => new DecimalColumnStats(dt) + case CalendarIntervalType => new IntervalColumnStats + case _: YearMonthIntervalType => new IntColumnStats // stored as Int + case _: DayTimeIntervalType 
=> new LongColumnStats // stored as Long + case _: TimeType => new LongColumnStats // Time is stored as Long (nanoseconds) + case VariantType => new VariantColumnStats + // Geometry/Geography are stored as binary (WKB) internally, so reuse BinaryColumnStats + // to collect size/count without min/max bounds. They are AtomicTypes that ColumnType + // (used by ObjectColumnStats) does not handle, so they must be matched explicitly here. + case _: GeometryType | _: GeographyType => new BinaryColumnStats + case _ => new ObjectColumnStats(dataType) + } + } + + def buildStatisticsFromCollectors( + collectors: Array[ColumnStats], + schema: Seq[Attribute]): InternalRow = { + val stats = collectors.flatMap { collector => + val collected = collector.collectedStatistics + // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes] + Seq(collected(0), collected(1), collected(2), collected(3), collected(4)) + } + InternalRow.fromSeq(stats.toSeq) + } + + def collectStatistics( + root: VectorSchemaRoot, + schema: Seq[Attribute]): InternalRow = { + val rowCount = root.getRowCount + val vectors = root.getFieldVectors.asScala.toSeq + + // Collect stats for each column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes + val stats = schema.zip(vectors).flatMap { case (attr, vector) => + val nullCount = (0 until rowCount).count(i => vector.isNull(i)) + val sizeInBytes = vector.getBufferSize.toLong + + val (lower, upper) = attr.dataType match { + case BooleanType => calculateMinMaxBoolean(vector, rowCount) + case ByteType => calculateMinMaxByte(vector, rowCount) + case ShortType => calculateMinMaxShort(vector, rowCount) + case IntegerType => calculateMinMaxInt(vector, rowCount) + case DateType => calculateMinMaxDate(vector, rowCount) + case LongType => calculateMinMaxLong(vector, rowCount) + case TimestampType => calculateMinMaxTimestamp(vector, rowCount) + case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount) + case FloatType => 
calculateMinMaxFloat(vector, rowCount) + case DoubleType => calculateMinMaxDouble(vector, rowCount) + case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId) + case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType) + case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount) + case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount) + case _: TimeType => calculateMinMaxTime(vector, rowCount) + case _ => (null, null) // Skip for binary, complex, and other unsupported types + } + + Seq(lower, upper, nullCount, rowCount, sizeInBytes) + } + + new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(stats.toArray) + } + + def calculateMinMaxBoolean( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = true + var max = false + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BitVector].get(i) != 0 + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxByte( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Byte.MaxValue + var max = Byte.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TinyIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxShort( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Short.MaxValue + var max = Short.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + 
if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.SmallIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxInt( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDate( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.DateDayVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxLong( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BigIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestamp( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = 
{ + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroTZVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestampNTZ( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxFloat( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Float.MaxValue + var max = Float.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float4Vector].get(i) + // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never + // updates min/max in the row-based path (FloatColumnStats.gatherValueStats). 
+ if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDouble( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Double.MaxValue + var max = Double.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float8Vector].get(i) + // Skip NaN to match DoubleColumnStats.gatherValueStats. + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxString( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + collationId: Int = StringType.collationId): (Any, Any) = { + var min: org.apache.spark.unsafe.types.UTF8String = null + var max: org.apache.spark.unsafe.types.UTF8String = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bytes = vector.asInstanceOf[org.apache.arrow.vector.VarCharVector].get(i) + val value = org.apache.spark.unsafe.types.UTF8String.fromBytes(bytes) + if (!hasValue) { + min = value.clone() + max = value.clone() + hasValue = true + } else { + if (value.semanticCompare(min, collationId) < 0) min = value.clone() + if (value.semanticCompare(max, collationId) > 0) max = value.clone() + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDecimal( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + dataType: org.apache.spark.sql.types.DataType): (Any, Any) = { + val decimalType = dataType.asInstanceOf[DecimalType] + var min: org.apache.spark.sql.types.Decimal = null + var max: org.apache.spark.sql.types.Decimal = null + var hasValue = 
false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bigDecimal = vector.asInstanceOf[ + org.apache.arrow.vector.DecimalVector].getObject(i) + val value = org.apache.spark.sql.types.Decimal( + bigDecimal, decimalType.precision, decimalType.scale) + + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value.compareTo(min) < 0) min = value + if (value.compareTo(max) > 0) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxYearMonthInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntervalYearVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDayTimeInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = org.apache.arrow.vector.DurationVector.get( + vector.asInstanceOf[org.apache.arrow.vector.DurationVector].getDataBuffer, i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTime( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TimeNanoVector].get(i) + if (!hasValue) { + min 
= value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } +} + +/** + * Iterator that converts InternalRow to ArrowCachedBatch. + */ +private class InternalRowToArrowCachedBatchIterator( + rowIter: Iterator[InternalRow], + schema: Seq[Attribute], + sparkSchema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"InternalRowToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + private val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Create statistics collectors for each column + private val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + close() + } + } + + override def hasNext: Boolean = rowIter.hasNext || { + close() + false + } + + override def next(): ArrowCachedBatch = { + var rowCount = 0 + + // Reset statistics collectors for new batch + var idx = 0 + while (idx < statsCollectors.length) { + statsCollectors(idx) = ArrowCachedBatchSerializer.createColumnStats(schema(idx).dataType) + idx += 1 + } + + Utils.tryWithSafeFinally { + // Write rows to Arrow vectors and collect statistics incrementally. 
+ // A nonpositive maxRecordsPerBatch means unlimited (one batch per partition), matching + // ArrowConverters; without the `<= 0` guard the loop would emit empty batches forever. + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + + // Collect statistics for this row + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + + rowCount += 1 + } + arrowWriter.finish() + + // Get the Arrow RecordBatch with compression + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + // Serialize to Arrow IPC format + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + + // Build statistics InternalRow from collected stats + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + } + } + + private def close(): Unit = { + root.close() + allocator.close() + } +} + +/** + * Iterator that converts ColumnarBatch to ArrowCachedBatch. 
+ */ +private class ColumnarBatchToArrowCachedBatchIterator( + batchIter: Iterator[ColumnarBatch], + schema: Seq[Attribute], + sparkSchema: StructType, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ColumnarBatchToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + allocator.close() + } + } + + override def hasNext: Boolean = batchIter.hasNext + + override def next(): ArrowCachedBatch = { + val batch = batchIter.next() + val rowCount = batch.numRows() + + // Check if batch is already Arrow-based for zero-copy path. The zero-copy path reuses the + // input vectors but serializes them under a schema built with largeVarTypes=false, and the + // read path reconstructs that same non-large schema. Large var-width vectors use 64-bit + // offsets, so reading them back under a 32-bit-offset schema would silently corrupt data. + // Fall back to the row-based conversion (which always produces standard var-width vectors) + // whenever any input vector is, or nests, a large var-width vector. 
+ val vectors = (0 until batch.numCols()).map(batch.column) + val zeroCopyEligible = vectors.forall { + case acv: ArrowColumnVector => + !ColumnarBatchToArrowCachedBatchIterator.containsLargeVarType(acv.getValueVector) + case _ => false + } + if (zeroCopyEligible) { + // Fast path: zero-copy extraction of Arrow RecordBatch + convertArrowBatchZeroCopy(batch, rowCount, schema, vectors) + } else { + // Slow path: convert to Arrow via rows + convertToArrowBatch(batch, rowCount, schema) + } + } + + private def convertArrowBatchZeroCopy( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute], + vectors: Seq[ColumnVector]): ArrowCachedBatch = { + // Zero-copy path: extract Arrow vectors directly from ArrowColumnVector + val arrowVectors = vectors.map( + _.asInstanceOf[ArrowColumnVector].getValueVector.asInstanceOf[ + org.apache.arrow.vector.FieldVector]) + + // Create a VectorSchemaRoot from the existing vectors + val root = new VectorSchemaRoot(arrowSchema, arrowVectors.asJava, rowCount) + + Utils.tryWithSafeFinally { + // Use VectorUnloader to create compressed RecordBatch + val unloader = new VectorUnloader(root, true, compressionCodec, true) + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + // Note: We don't close the root here because we don't own the vectors + // They are owned by the input ColumnarBatch + } + } + + private def convertToArrowBatch( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute]): ArrowCachedBatch = { + // Convert columnar batch to rows, then to Arrow + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Collect statistics inline during 
row iteration, same as InternalRowToArrow path + val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + Utils.tryWithSafeFinally { + val rowIterator = batch.rowIterator().asScala + while (rowIterator.hasNext) { + val row = rowIterator.next() + arrowWriter.write(row) + + // Collect statistics for this row inline + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + } + arrowWriter.finish() + + val recordBatch = unloader.getRecordBatch() + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + root.close() + } + } +} + +private object ColumnarBatchToArrowCachedBatchIterator { + import org.apache.arrow.vector.{FieldVector, LargeVarBinaryVector, LargeVarCharVector} + + /** + * Whether the vector is, or nests, a large var-width vector (64-bit offsets). These are not + * eligible for the zero-copy path because that path serializes and reloads under a schema built + * with largeVarTypes=false; reinterpreting 64-bit offset buffers as 32-bit would corrupt data. + */ + def containsLargeVarType(vector: org.apache.arrow.vector.ValueVector): Boolean = vector match { + case _: LargeVarCharVector | _: LargeVarBinaryVector => true + case fv: FieldVector => + fv.getChildrenFromFields.asScala.exists(containsLargeVarType) + case _ => false + } +} + +/** + * Iterator that converts ArrowCachedBatch to ColumnarBatch. 
+ */ +private class ArrowCachedBatchToColumnarBatchIterator( + batchIter: Iterator[CachedBatch], + cacheSchema: StructType, + selectedSchema: StructType, + columnIndices: Array[Int], + timeZoneId: String, + prefetchEnabled: Boolean = false) extends Iterator[ColumnarBatch] { + + import java.util.concurrent.{Callable, ExecutionException, Executors, ExecutorService, Future} + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ArrowCachedBatchToColumnarBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, false, false) + + // Track only the previous root to close it when next batch is produced + private var previousRoot: VectorSchemaRoot = null + + // Prefetch support: deserialize the next batch into its own root in a background thread while + // the current batch is being consumed. Only the deserialization (IPC read + decompression + + // loading into a fresh root) happens off-thread; closing the previous root stays on the + // consumer thread in next(), so the vectors backing a returned ColumnarBatch are never released + // while the consumer may still read them. 
+ private val prefetchExecutor: ExecutorService = if (prefetchEnabled) { + Executors.newSingleThreadExecutor(r => { + val t = new Thread(r, "arrow-cache-prefetch") + t.setDaemon(true) + t + }) + } else { + null + } + private var prefetchFuture: Future[VectorSchemaRoot] = _ + + // Register cleanup - close remaining root and allocator when task completes + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + if (prefetchFuture != null) { + prefetchFuture.cancel(true) + prefetchFuture = null + } + if (prefetchExecutor != null) { + prefetchExecutor.shutdownNow() + } + if (previousRoot != null) { + previousRoot.close() + previousRoot = null + } + allocator.close() + } + } + + override def hasNext: Boolean = prefetchFuture != null || batchIter.hasNext + + override def next(): ColumnarBatch = { + // Close the previous root since the consumer has moved on from the batch it backed. + if (previousRoot != null) { + previousRoot.close() + previousRoot = null + } + + val root = if (prefetchFuture != null) { + val r = try { + prefetchFuture.get() + } catch { + case e: ExecutionException => throw e.getCause + } + prefetchFuture = null + r + } else { + deserializeToRoot(batchIter.next().asInstanceOf[ArrowCachedBatch]) + } + + previousRoot = root + + // Wrap vectors in ArrowColumnVector and project to selected columns. + val allColumns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + val selectedColumns = columnIndices.map(allColumns(_)) + val batch = new ColumnarBatch(selectedColumns, root.getRowCount) + + // Start prefetching the next batch while this one is being consumed. + submitPrefetch() + + batch + } + + /** Deserialize a cached batch into its own freshly-created root. Does not touch other roots. 
*/ + private def deserializeToRoot(cachedBatch: ArrowCachedBatch): VectorSchemaRoot = { + val in = new ByteArrayInputStream(cachedBatch.arrowData) + val readChannel = new ReadChannel(Channels.newChannel(in)) + val recordBatch = MessageSerializer.deserializeRecordBatch(readChannel, allocator) + Utils.tryWithSafeFinally { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val loader = new VectorLoader(root) + loader.load(recordBatch) Review Comment: Fixed. Both `deserializeToRoot` and `deserializeBatch` now close the partially-loaded root if `VectorLoader.load` fails, then rethrow, so a mid-load failure no longer leaves an unreachable root that makes the later `allocator.close()` fail and mask the original error. ########## docs/sql-arrow-cache-format.md: ########## @@ -0,0 +1,339 @@ +# Apache Arrow Cache Format for Spark + +## Overview + +Apache Spark supports using Apache Arrow as an alternative cache format for in-memory Dataset caching. This format provides improved performance for certain workloads, especially when working with columnar data sources like Parquet and ORC. + +## Benefits + +The Arrow cache format offers several advantages over the default cache format: + +- **Zero-copy reads** when input is already in Arrow format (e.g., Arrow-based data sources, re-caching Arrow cached data) +- **Better filter pushdown** with min/max statistics for partition pruning +- **Off-heap memory management** via Arrow allocators +- **Efficient compression** with zstd and lz4 codecs +- **Arrow ecosystem interoperability** for data sharing + +**Note**: Spark's built-in Parquet/ORC readers use internal column vectors (`OnHeapColumnVector`/`OffHeapColumnVector`), not Arrow format, so they don't benefit from zero-copy optimization. 
+ +## Configuration + +To enable Arrow cache format, set the static configuration: + +```scala +spark.conf.set("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") +``` + +**Note**: This is a static configuration that must be set before the SparkSession is created. +It selects the cache serializer for the whole session; once set, this serializer handles every +cached relation. There is no automatic per-relation fallback to another cache serializer based on +the data types involved (see [Supported Data Types](#supported-data-types) for how unsupported +types are handled). + +```scala +val spark = SparkSession.builder() + .appName("MyApp") + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .getOrCreate() +``` + +## Usage + +Once configured, use cache operations as normal: + +```scala +// Cache a DataFrame +val df = spark.read.parquet("data.parquet") +df.cache() + +// Use cached data +df.filter("age > 30").count() + +// Uncache when done +df.unpersist() +``` + +## Compression + +Arrow cache supports multiple compression codecs. Configure compression with: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") +``` + +Available options: +- `none` - No compression (fastest, largest size, **default**) +- `lz4` - LZ4 compression (fast, good compression) +- `zstd` - Zstandard compression (slower, best compression) + +For zstd, you can also configure the compression level. 
Positive values (up to 22) give better +compression but slower speed; negative values give ultra-fast compression with lower ratios: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.zstd.level", "3") // Default: 3 +``` + +## Vectorized Reader + +Enable vectorized reading for better performance with primitive types: + +```scala +spark.conf.set("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") +``` + +When enabled, cached data is read as columnar batches instead of rows, which can significantly improve performance for columnar operations. + +## Performance Characteristics + +In our benchmarks, the Arrow cache format performs best on the following workloads. Actual +results depend on data types, compression settings, and hardware, and the default cache format +can be faster in some cases (for example, with higher compression levels): + +1. **Filter-Heavy Workloads**: Queries with selective filters benefit from min/max statistics. +2. **Columnar Operations**: Aggregations and projections on cached data benefit from the Arrow format. +3. **Parquet/ORC Caching**: Arrow's batch processing helps even without the zero-copy path. +4. **Re-caching with Column Projection**: Dropping columns from Arrow-cached data preserves the + `ArrowColumnVector` format, enabling true zero-copy extraction and the largest gains. + +### Benchmark Results + +The numbers below are illustrative results from one run on an Apple M4 Max (OpenJDK 21.0.8) and +will vary with hardware, JDK, and compression settings. They are not a guarantee. For the +authoritative, regularly regenerated numbers, see +`sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt` and the `ArrowCacheBenchmark` suite. 
+ +| Workload | Default Cache | Arrow Cache | Speedup | +|----------|--------------|-------------|---------| +| Write + Read (5M rows, 3 primitive columns) | 153.7 ns/row | 74.2 ns/row | **~2X faster** | +| Filter with stats (5M rows) | 100.1 ns/row | 70.8 ns/row | **~1.4X faster** | +| Columnar input from Parquet (2M rows, 3 primitive columns) | 195.3 ns/row | 113.1 ns/row | **~1.7X faster** | +| Re-cache with zero-copy (2M rows, 2 columns) | 123.3 ns/row | 38.5 ns/row | **~3.2X faster** | + +**Notes**: +- **Write + Read**: Significant improvement from efficient Arrow serialization and vectorized operations +- **Filter improvement**: Comes from min/max statistics enabling batch skipping during partition pruning +- **Parquet caching**: Shows improvement despite Spark's Parquet reader producing `OnHeapColumnVector`/`OffHeapColumnVector` rather than `ArrowColumnVector`, due to Arrow's efficient batch processing +- **Re-cache with zero-copy**: When caching a subset of columns from Arrow-cached data (e.g., `df.drop("column")`), the remaining columns preserve their `ArrowColumnVector` format, enabling true zero-copy extraction and achieving the best performance +- **Zero-copy benefits** only apply when input is already `ArrowColumnVector` (e.g., Python Arrow sources, re-caching Arrow cached data with column projection) + +## Supported Data Types + +Arrow cache supports the following data types: + +### Primitive Types +- BooleanType +- ByteType, ShortType, IntegerType, LongType +- FloatType, DoubleType +- DecimalType (all precision/scale combinations) +- NullType + +### Temporal Types +- DateType +- TimestampType +- TimestampNTZType +- TimeType + +### Interval Types +- YearMonthIntervalType +- DayTimeIntervalType +- CalendarIntervalType + +### String and Binary +- StringType (including collated strings) +- BinaryType + +### Complex Types +- ArrayType +- StructType +- MapType +- Nested combinations of the above + +### Other Types +- VariantType +- GeometryType, 
GeographyType +- User-defined types (UDTs) whose underlying representation is itself supported + +### Unsupported Types + +Arrow cache covers every type the default cache serializer supports, plus some it +does not (for example geometry and geography). Types that Arrow cannot represent +(such as `ObjectType`) are not silently dropped or routed to a different cache +serializer: there is no per-type fallback, because the cache serializer is chosen +once via the static `spark.sql.cache.serializer` configuration and then handles +every cached relation. Attempting to cache an unsupported type fails with an +`UNSUPPORTED_DATATYPE` error when the cache is materialized. + +## Statistics and Filter Pushdown + +Arrow cache automatically collects min/max statistics for the following types: +- Boolean +- Numeric types (Byte, Short, Int, Long, Float, Double) +- Decimal +- Date, Timestamp, and Timestamp without time zone (TIMESTAMP_NTZ) +- Time +- Year-month and day-time intervals +- String (using collation-aware comparison for collated strings) + +Other types (Binary, Variant, calendar intervals, and complex types such as +Array/Struct/Map) are cached but do not contribute min/max bounds, so they only +record null counts and sizes. + +These statistics enable partition pruning when filtering: + +```scala +val df = spark.range(10000000).cache() + +// This filter can skip batches using min/max statistics +df.filter("id > 5000000").count() +``` + +## Memory Management + +Arrow cache uses off-heap memory managed by Apache Arrow allocators. This is a fundamental design choice in Apache Arrow; the cache cannot be configured to use on-heap memory instead.
+ +**Memory Efficiency**: +- Despite requiring off-heap memory, Arrow cache can be **more memory-efficient** than the default cache, especially when a compression codec is enabled (compression is off by default): + - Efficient compression with zstd/lz4 codecs + - Compact columnar format without Java object overhead + - Better compression ratios, especially for strings and complex types +- If you have limited off-heap memory, increase `spark.executor.memoryOverhead` to allocate more off-heap memory + +**Memory Cleanup**: +Arrow memory is automatically cleaned up when: +- Tasks complete +- DataFrames are unpersisted +- SparkSession is stopped + +You can monitor Arrow memory usage through Spark metrics and the Spark UI. + +## Limitations and Considerations + +1. **Static Configuration**: Cache serializer must be set before SparkSession creation +2. **Memory Overhead**: Arrow format has a small per-batch overhead +3. **Compatibility**: Cannot mix cache formats - recache needed when switching +4. **Compression Trade-off**: Higher compression = lower memory but slower reads + +## Migration from Default Cache + +To migrate from the default cache to Arrow cache: + +1. **Stop your SparkSession** Review Comment: Fixed. The migration section now states that the serializer is resolved once and held process-wide, so switching cache formats requires a fresh JVM; the in-process "stop and reconfigure" steps were removed. ########## docs/sql-arrow-cache-format.md: ########## @@ -0,0 +1,339 @@ +# Apache Arrow Cache Format for Spark + +## Overview + +Apache Spark supports using Apache Arrow as an alternative cache format for in-memory Dataset caching. This format provides improved performance for certain workloads, especially when working with columnar data sources like Parquet and ORC.
+ +## Benefits + +The Arrow cache format offers several advantages over the default cache format: + +- **Zero-copy reads** when input is already in Arrow format (e.g., Arrow-based data sources, re-caching Arrow cached data) +- **Better filter pushdown** with min/max statistics for partition pruning +- **Off-heap memory management** via Arrow allocators +- **Efficient compression** with zstd and lz4 codecs +- **Arrow ecosystem interoperability** for data sharing + +**Note**: Spark's built-in Parquet/ORC readers use internal column vectors (`OnHeapColumnVector`/`OffHeapColumnVector`), not Arrow format, so they don't benefit from zero-copy optimization. + +## Configuration + +To enable Arrow cache format, set the static configuration: + +```scala +spark.conf.set("spark.sql.cache.serializer", Review Comment: Fixed. Removed the `spark.conf.set` snippet (it throws `CANNOT_MODIFY_CONFIG` on the static key) and kept only the `SparkSession.builder.config(...)` example. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
