sunchao commented on code in PR #56334:
URL: https://github.com/apache/spark/pull/56334#discussion_r3486643398


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala:
##########
@@ -0,0 +1,1371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.channels.Channels
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, 
NoCompressionCodec}
+import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, 
MessageSerializer}
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.columnar.{CachedBatch, 
SimpleMetricsCachedBatchSerializer}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+/**
+ * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format.
+ *
+ * This serializer:
+ *  - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input
+ *  - Stores data in Arrow IPC streaming format with optional compression 
(zstd/lz4)
+ *  - Enables zero-copy columnar reads when output is ColumnarBatch
+ *  - Uses off-heap memory via Arrow allocators
+ *  - Collects per-column statistics for partition pruning
+ *  - Provides efficient interoperability with Arrow ecosystem
+ *
+ * Configuration options:
+ *  - spark.sql.cache.serializer: Set to this class name to enable
+ *  - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch
+ *  - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4)
+ *  - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable 
columnar output
+ */
+class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
+
+  /** Columnar input is accepted only if `ArrowUtils.isSupportedByArrow` holds for every column. */
+  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = {
+    schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType))
+  }
+
+  /**
+   * Encodes a row-based RDD into Arrow cached batches, one converting iterator per partition.
+   *
+   * All settings are read from `conf` up front so that the partition closure captures plain
+   * values instead of the `SQLConf` instance.
+   */
+  override def convertInternalRowToCachedBatch(
+      input: RDD[InternalRow],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    val sparkSchema = DataTypeUtils.fromAttributes(schema)
+    val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch
+    val timeZoneId = conf.sessionLocalTimeZone
+    val compressionCodecName = conf.arrowCompressionCodec
+    val compressionLevel = conf.arrowZstdCompressionLevel
+
+    input.mapPartitionsInternal { rowIterator =>
+      new InternalRowToArrowCachedBatchIterator(
+        rowIterator,
+        schema,
+        sparkSchema,
+        maxRecordsPerBatch,
+        timeZoneId,
+        compressionCodecName,
+        compressionLevel)
+    }
+  }
+
+  /**
+   * Encodes a columnar RDD into Arrow cached batches, one converting iterator per partition.
+   *
+   * As with the row-based variant, settings are read from `conf` before the closure is built.
+   */
+  override def convertColumnarBatchToCachedBatch(
+      input: RDD[ColumnarBatch],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    val sparkSchema = DataTypeUtils.fromAttributes(schema)
+    val timeZoneId = conf.sessionLocalTimeZone
+    val compressionCodecName = conf.arrowCompressionCodec
+    val compressionLevel = conf.arrowZstdCompressionLevel
+
+    input.mapPartitionsInternal { batchIterator =>
+      new ColumnarBatchToArrowCachedBatchIterator(
+        batchIterator,
+        schema,
+        sparkSchema,
+        timeZoneId,
+        compressionCodecName,
+        compressionLevel)
+    }
+  }
+
+  /** Columnar output is always offered, regardless of the schema. */
+  override def supportsColumnarOutput(schema: StructType): Boolean = true
+
+  /** Reports [[ArrowColumnVector]] as the vector class of every output column. */
+  override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = {
+    Some(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName))
+  }
+
+  /**
+   * Decodes cached batches into [[ColumnarBatch]]es restricted to `selectedAttributes`.
+   */
+  override def convertCachedBatchToColumnarBatch(
+      input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[ColumnarBatch] = {
+    val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes)
+    val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes)
+    val columnIndices = projectionIndices(cacheAttributes, selectedAttributes)
+    // Read the settings here so the closure below does not capture `conf`.
+    val timeZoneId = conf.sessionLocalTimeZone
+    val prefetchEnabled = conf.arrowCachePrefetchEnabled
+
+    input.mapPartitionsInternal { batchIterator =>
+      new ArrowCachedBatchToColumnarBatchIterator(
+        batchIterator,
+        cacheSchema,
+        selectedSchema,
+        columnIndices,
+        timeZoneId,
+        prefetchEnabled)
+    }
+  }
+
+  /**
+   * Decodes cached batches into rows restricted to `selectedAttributes`.
+   *
+   * Column types that the direct row reader does not handle are routed through the columnar
+   * decoder and then projected to unsafe rows.
+   */
+  override def convertCachedBatchToInternalRow(
+      input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[InternalRow] = {
+    val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes)
+    val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes)
+    val timeZoneId = conf.sessionLocalTimeZone
+    val selectedIndices = projectionIndices(cacheAttributes, selectedAttributes)
+
+    // Types not handled by ArrowColumnReader must use the fallback path.
+    val needsFallback = selectedSchema.fields.exists { f =>
+      f.dataType match {
+        case _: ArrayType | _: StructType | _: MapType => true
+        case CalendarIntervalType | VariantType | NullType => true
+        case _: UserDefinedType[_] => true
+        // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path
+        // ArrowColumnReader does not handle them, so route them through the fallback.
+        case _: GeometryType | _: GeographyType => true
+        case _ => false
+      }
+    }
+
+    if (needsFallback) {
+      // Decode to ColumnarBatch first, then turn each batch row into an UnsafeRow.
+      convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
+        .mapPartitionsInternal { batchIter =>
+          val toUnsafe =
+            org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create(selectedSchema)
+          batchIter.flatMap { batch =>
+            // The row count is read once per batch; rows are produced lazily, in order.
+            Iterator.range(0, batch.numRows())
+              .map[InternalRow](rowIdx => toUnsafe(batch.getRow(rowIdx)))
+          }
+        }
+    } else {
+      val prefetchEnabled = conf.arrowCachePrefetchEnabled
+      input.mapPartitionsInternal { batchIterator =>
+        new ArrowCachedBatchToInternalRowIterator(
+          batchIterator,
+          cacheSchema,
+          selectedSchema,
+          selectedIndices,
+          timeZoneId,
+          prefetchEnabled)
+      }
+    }
+  }
+
+  /**
+   * Position of each selected attribute within `cacheAttributes`, matched by `exprId`. The first
+   * matching position is used; an attribute with no match maps to -1.
+   */
+  private def projectionIndices(
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute]): Array[Int] = {
+    selectedAttributes.map(attr => cacheAttributes.indexWhere(_.exprId == attr.exprId)).toArray
+  }
+}
+
+/**
+ * Helpers shared by the Arrow cache iterators: compression codec construction, record batch
+ * serialization, and per-column statistics in the layout
+ * [lowerBound, upperBound, nullCount, count, sizeInBytes].
+ */
+private object ArrowCachedBatchSerializer {
+  import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector}
+  import org.apache.arrow.vector.{DurationVector, FieldVector, Float4Vector, Float8Vector}
+  import org.apache.arrow.vector.{IntervalYearVector, IntVector, SmallIntVector, TimeNanoVector}
+  import org.apache.arrow.vector.{TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector}
+  import org.apache.arrow.vector.VarCharVector
+  import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+
+  /**
+   * Builds the Arrow compression codec for `codecName` (matched case-insensitively).
+   *
+   * @param codecName one of "none", "zstd" or "lz4"
+   * @param compressionLevel level passed to the zstd codec; not used by the other codecs
+   * @throws SparkException (internal error) for any other codec name
+   */
+  def createCompressionCodec(
+      codecName: String,
+      compressionLevel: Int): CompressionCodec = {
+    // Locale.ROOT keeps the name matching independent of the JVM default locale.
+    codecName.toLowerCase(java.util.Locale.ROOT) match {
+      case "none" => NoCompressionCodec.INSTANCE
+      // The codec instance must be constructed directly so that compressionLevel is honored:
+      // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at
+      // the default level. The level only matters on the write side; the read side looks up the
+      // codec by the type recorded in the IPC message.
+      case "zstd" => new ZstdCompressionCodec(compressionLevel)
+      case "lz4" => new Lz4CompressionCodec()
+      case other =>
+        throw SparkException.internalError(
+          s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+    }
+  }
+
+  /** Serializes one record batch through Arrow's `MessageSerializer` into a byte array. */
+  def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = {
+    val out = new ByteArrayOutputStream()
+    MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch)
+    out.toByteArray
+  }
+
+  /** Picks the row-based statistics collector matching `dataType`. */
+  def createColumnStats(dataType: DataType): ColumnStats = {
+    dataType match {
+      case BooleanType => new BooleanColumnStats
+      case ByteType => new ByteColumnStats
+      case ShortType => new ShortColumnStats
+      case IntegerType => new IntColumnStats
+      case DateType => new IntColumnStats  // Date is stored as Int
+      case LongType => new LongColumnStats
+      case TimestampType => new LongColumnStats  // Timestamp is stored as Long
+      case TimestampNTZType => new LongColumnStats  // TimestampNTZ is stored as Long
+      case FloatType => new FloatColumnStats
+      case DoubleType => new DoubleColumnStats
+      case st: StringType => new StringColumnStats(st)
+      case BinaryType => new BinaryColumnStats
+      case dt: DecimalType => new DecimalColumnStats(dt)
+      case CalendarIntervalType => new IntervalColumnStats
+      case _: YearMonthIntervalType => new IntColumnStats   // stored as Int
+      case _: DayTimeIntervalType => new LongColumnStats  // stored as Long
+      case _: TimeType => new LongColumnStats  // Time is stored as Long (nanoseconds)
+      case VariantType => new VariantColumnStats
+      // Geometry/Geography are stored as binary (WKB) internally, so reuse BinaryColumnStats
+      // to collect size/count without min/max bounds. They are AtomicTypes that ColumnType
+      // (used by ObjectColumnStats) does not handle, so they must be matched explicitly here.
+      case _: GeometryType | _: GeographyType => new BinaryColumnStats
+      case _ => new ObjectColumnStats(dataType)
+    }
+  }
+
+  /**
+   * Concatenates the first five collected values of every collector, in collector order, into
+   * one statistics row. `schema` is accepted for signature compatibility and is not read.
+   */
+  def buildStatisticsFromCollectors(
+      collectors: Array[ColumnStats],
+      schema: Seq[Attribute]): InternalRow = {
+    val stats = collectors.flatMap { collector =>
+      val collected = collector.collectedStatistics
+      // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes]
+      Seq(collected(0), collected(1), collected(2), collected(3), collected(4))
+    }
+    InternalRow.fromSeq(stats.toSeq)
+  }
+
+  /**
+   * Computes the statistics row directly from the vectors of `root`, pairing them positionally
+   * with `schema`. Bounds are null for types without a dedicated scan below.
+   */
+  def collectStatistics(
+      root: VectorSchemaRoot,
+      schema: Seq[Attribute]): InternalRow = {
+    val rowCount = root.getRowCount
+    val vectors = root.getFieldVectors.asScala.toSeq
+
+    // Per column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes
+    val stats = schema.zip(vectors).flatMap { case (attr, vector) =>
+      val nullCount = (0 until rowCount).count(i => vector.isNull(i))
+      val sizeInBytes = vector.getBufferSize.toLong
+
+      val (lower, upper) = attr.dataType match {
+        case BooleanType => calculateMinMaxBoolean(vector, rowCount)
+        case ByteType => calculateMinMaxByte(vector, rowCount)
+        case ShortType => calculateMinMaxShort(vector, rowCount)
+        case IntegerType => calculateMinMaxInt(vector, rowCount)
+        case DateType => calculateMinMaxDate(vector, rowCount)
+        case LongType => calculateMinMaxLong(vector, rowCount)
+        case TimestampType => calculateMinMaxTimestamp(vector, rowCount)
+        case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount)
+        case FloatType => calculateMinMaxFloat(vector, rowCount)
+        case DoubleType => calculateMinMaxDouble(vector, rowCount)
+        case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId)
+        case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType)
+        case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount)
+        case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount)
+        case _: TimeType => calculateMinMaxTime(vector, rowCount)
+        case _ => (null, null) // Skip for binary, complex, and other unsupported types
+      }
+
+      Seq(lower, upper, nullCount, rowCount, sizeInBytes)
+    }
+
+    new GenericInternalRow(stats.toArray)
+  }
+
+  /**
+   * Min/max over the non-null slots of an Int-backed vector; `(null, null)` if there are none.
+   * `read` is only invoked for non-null slots, so a cast inside it is never reached for an
+   * all-null column.
+   */
+  private def intMinMax(vector: FieldVector, rowCount: Int)(read: Int => Int): (Any, Any) = {
+    var min = Int.MaxValue
+    var max = Int.MinValue
+    var seen = false
+    var i = 0
+    while (i < rowCount) {
+      if (!vector.isNull(i)) {
+        val value = read(i)
+        if (!seen) {
+          min = value
+          max = value
+          seen = true
+        } else {
+          if (value < min) min = value
+          if (value > max) max = value
+        }
+      }
+      i += 1
+    }
+    if (seen) (min, max) else (null, null)
+  }
+
+  /** Long-backed counterpart of [[intMinMax]]. */
+  private def longMinMax(vector: FieldVector, rowCount: Int)(read: Int => Long): (Any, Any) = {
+    var min = Long.MaxValue
+    var max = Long.MinValue
+    var seen = false
+    var i = 0
+    while (i < rowCount) {
+      if (!vector.isNull(i)) {
+        val value = read(i)
+        if (!seen) {
+          min = value
+          max = value
+          seen = true
+        } else {
+          if (value < min) min = value
+          if (value > max) max = value
+        }
+      }
+      i += 1
+    }
+    if (seen) (min, max) else (null, null)
+  }
+
+  /**
+   * Min/max over the non-null slots using a caller-supplied strict ordering; `(null, null)` if
+   * there are none. Used for value types that have no specialized scan above.
+   */
+  private def genericMinMax[T](vector: FieldVector, rowCount: Int)(read: Int => T)(
+      lessThan: (T, T) => Boolean): (Any, Any) = {
+    var min: T = null.asInstanceOf[T]
+    var max: T = null.asInstanceOf[T]
+    var seen = false
+    var i = 0
+    while (i < rowCount) {
+      if (!vector.isNull(i)) {
+        val value = read(i)
+        if (!seen) {
+          min = value
+          max = value
+          seen = true
+        } else {
+          if (lessThan(value, min)) min = value
+          if (lessThan(max, value)) max = value
+        }
+      }
+      i += 1
+    }
+    if (seen) (min, max) else (null, null)
+  }
+
+  /** Bounds of a [[BitVector]] column; a non-zero bit is read as `true`. */
+  def calculateMinMaxBoolean(vector: FieldVector, rowCount: Int): (Any, Any) =
+    genericMinMax(vector, rowCount)(i => vector.asInstanceOf[BitVector].get(i) != 0)(_ < _)
+
+  /** Bounds of a [[TinyIntVector]] column. */
+  def calculateMinMaxByte(vector: FieldVector, rowCount: Int): (Any, Any) =
+    genericMinMax(vector, rowCount)(i => vector.asInstanceOf[TinyIntVector].get(i))(_ < _)
+
+  /** Bounds of a [[SmallIntVector]] column. */
+  def calculateMinMaxShort(vector: FieldVector, rowCount: Int): (Any, Any) =
+    genericMinMax(vector, rowCount)(i => vector.asInstanceOf[SmallIntVector].get(i))(_ < _)
+
+  /** Bounds of an [[IntVector]] column. */
+  def calculateMinMaxInt(vector: FieldVector, rowCount: Int): (Any, Any) =
+    intMinMax(vector, rowCount)(i => vector.asInstanceOf[IntVector].get(i))
+
+  /** Bounds of a [[DateDayVector]] column, as the stored Int values. */
+  def calculateMinMaxDate(vector: FieldVector, rowCount: Int): (Any, Any) =
+    intMinMax(vector, rowCount)(i => vector.asInstanceOf[DateDayVector].get(i))
+
+  /** Bounds of a [[BigIntVector]] column. */
+  def calculateMinMaxLong(vector: FieldVector, rowCount: Int): (Any, Any) =
+    longMinMax(vector, rowCount)(i => vector.asInstanceOf[BigIntVector].get(i))
+
+  /** Bounds of a [[TimeStampMicroTZVector]] column, as the stored Long values. */
+  def calculateMinMaxTimestamp(vector: FieldVector, rowCount: Int): (Any, Any) =
+    longMinMax(vector, rowCount)(i => vector.asInstanceOf[TimeStampMicroTZVector].get(i))
+
+  /** Bounds of a [[TimeStampMicroVector]] column, as the stored Long values. */
+  def calculateMinMaxTimestampNTZ(vector: FieldVector, rowCount: Int): (Any, Any) =
+    longMinMax(vector, rowCount)(i => vector.asInstanceOf[TimeStampMicroVector].get(i))
+
+  /** Bounds of a [[Float4Vector]] column, ignoring NaN values. */
+  def calculateMinMaxFloat(vector: FieldVector, rowCount: Int): (Any, Any) = {
+    var min = Float.MaxValue
+    var max = Float.MinValue
+    var seen = false
+    var i = 0
+    while (i < rowCount) {
+      if (!vector.isNull(i)) {
+        val value = vector.asInstanceOf[Float4Vector].get(i)
+        // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never
+        // updates min/max in the row-based path (FloatColumnStats.gatherValueStats).
+        if (!value.isNaN) {
+          if (!seen) {
+            min = value
+            max = value
+            seen = true
+          } else {
+            if (value < min) min = value
+            if (value > max) max = value
+          }
+        }
+      }
+      i += 1
+    }
+    if (seen) (min, max) else (null, null)
+  }
+
+  /** Bounds of a [[Float8Vector]] column, ignoring NaN values. */
+  def calculateMinMaxDouble(vector: FieldVector, rowCount: Int): (Any, Any) = {
+    var min = Double.MaxValue
+    var max = Double.MinValue
+    var seen = false
+    var i = 0
+    while (i < rowCount) {
+      if (!vector.isNull(i)) {
+        val value = vector.asInstanceOf[Float8Vector].get(i)
+        // Skip NaN to match DoubleColumnStats.gatherValueStats.
+        if (!value.isNaN) {
+          if (!seen) {
+            min = value
+            max = value
+            seen = true
+          } else {
+            if (value < min) min = value
+            if (value > max) max = value
+          }
+        }
+      }
+      i += 1
+    }
+    if (seen) (min, max) else (null, null)
+  }
+
+  /**
+   * Bounds of a [[VarCharVector]] column, ordered by `UTF8String.semanticCompare` under
+   * `collationId`.
+   */
+  def calculateMinMaxString(
+      vector: FieldVector,
+      rowCount: Int,
+      collationId: Int = StringType.collationId): (Any, Any) = {
+    var min: UTF8String = null
+    var max: UTF8String = null
+    var seen = false
+    var i = 0
+    while (i < rowCount) {
+      if (!vector.isNull(i)) {
+        val value = UTF8String.fromBytes(vector.asInstanceOf[VarCharVector].get(i))
+        if (!seen) {
+          min = value.clone()
+          max = value.clone()
+          seen = true
+        } else {
+          if (value.semanticCompare(min, collationId) < 0) min = value.clone()
+          if (value.semanticCompare(max, collationId) > 0) max = value.clone()
+        }
+      }
+      i += 1
+    }
+    if (seen) (min, max) else (null, null)
+  }
+
+  /**
+   * Bounds of a [[DecimalVector]] column as `Decimal` values carrying the precision and scale
+   * of `dataType`, which must be a [[DecimalType]].
+   */
+  def calculateMinMaxDecimal(
+      vector: FieldVector,
+      rowCount: Int,
+      dataType: DataType): (Any, Any) = {
+    val decimalType = dataType.asInstanceOf[DecimalType]
+    genericMinMax(vector, rowCount) { i =>
+      val bigDecimal = vector.asInstanceOf[DecimalVector].getObject(i)
+      Decimal(bigDecimal, decimalType.precision, decimalType.scale)
+    }((a, b) => a.compareTo(b) < 0)
+  }
+
+  /** Bounds of an [[IntervalYearVector]] column, as the stored Int values. */
+  def calculateMinMaxYearMonthInterval(vector: FieldVector, rowCount: Int): (Any, Any) =
+    intMinMax(vector, rowCount)(i => vector.asInstanceOf[IntervalYearVector].get(i))
+
+  /** Bounds of a [[DurationVector]] column, read as raw Long values from its data buffer. */
+  def calculateMinMaxDayTimeInterval(vector: FieldVector, rowCount: Int): (Any, Any) =
+    longMinMax(vector, rowCount) { i =>
+      DurationVector.get(vector.asInstanceOf[DurationVector].getDataBuffer, i)
+    }
+
+  /** Bounds of a [[TimeNanoVector]] column, as the stored Long values. */
+  def calculateMinMaxTime(vector: FieldVector, rowCount: Int): (Any, Any) =
+    longMinMax(vector, rowCount)(i => vector.asInstanceOf[TimeNanoVector].get(i))
+}
+
+/**
+ * Iterator that converts InternalRow to ArrowCachedBatch.
+ *
+ * Rows are written into a single reusable `VectorSchemaRoot`; each call to `next()` drains up to
+ * `maxRecordsPerBatch` rows (all remaining rows when that limit is nonpositive), serializes them
+ * as one compressed record batch, and resets the writer. The root and its allocator are released
+ * when the input is exhausted or when the task completes, whichever happens first.
+ */
+private class InternalRowToArrowCachedBatchIterator(
+    rowIter: Iterator[InternalRow],
+    schema: Seq[Attribute],
+    sparkSchema: StructType,
+    maxRecordsPerBatch: Long,
+    timeZoneId: String,
+    compressionCodecName: String,
+    compressionLevel: Int) extends Iterator[ArrowCachedBatch] {
+
+  private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec(
+    compressionCodecName,
+    compressionLevel)
+
+  // TaskContext.get() can be null, so the attempt id used in the allocator name falls back
+  // to -1 instead of dereferencing a missing context.
+  private val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+
+  private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+    s"InternalRowToArrowCachedBatchIterator-$taskAttemptId",
+    0,
+    Long.MaxValue)
+
+  private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false)
+  private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+  private val arrowWriter = ArrowWriter.create(root)
+  private val unloader = new VectorUnloader(root, true, compressionCodec, true)
+
+  // Set once the Arrow resources have been released, making close() safe to call repeatedly.
+  private var closed = false
+
+  // Release resources at task completion even if the iterator is never fully consumed.
+  Option(TaskContext.get()).foreach { tc =>
+    tc.addTaskCompletionListener[Unit] { _ =>
+      close()
+    }
+  }
+
+  /** True while input rows remain; releases the Arrow resources on the first `false`. */
+  override def hasNext: Boolean = rowIter.hasNext || {
+    close()
+    false
+  }
+
+  /**
+   * Builds the next cached batch from the remaining input rows.
+   *
+   * @throws NoSuchElementException if no input rows remain
+   */
+  override def next(): ArrowCachedBatch = {
+    if (!hasNext) {
+      throw new NoSuchElementException("No more rows to convert into an Arrow cached batch")
+    }
+
+    // Fresh collectors per batch so statistics never carry over from the previous batch.
+    val statsCollectors: Array[ColumnStats] = schema.map { attr =>
+      ArrowCachedBatchSerializer.createColumnStats(attr.dataType)
+    }.toArray
+    var rowCount = 0
+
+    Utils.tryWithSafeFinally {
+      // A nonpositive maxRecordsPerBatch means unlimited (one batch per partition), matching
+      // ArrowConverters; without the `<= 0` guard the loop would emit empty batches forever.
+      while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+        val row = rowIter.next()
+        arrowWriter.write(row)
+
+        // Gather statistics from the same row while it is still valid.
+        var i = 0
+        while (i < statsCollectors.length) {
+          statsCollectors(i).gatherStats(row, i)
+          i += 1
+        }
+
+        rowCount += 1
+      }
+      arrowWriter.finish()
+
+      // Unload the written vectors as a record batch using the configured codec.
+      val recordBatch = unloader.getRecordBatch()
+
+      Utils.tryWithSafeFinally {
+        val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch)
+        val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors(
+          statsCollectors, schema)
+
+        ArrowCachedBatch(rowCount, arrowData, stats)
+      } {
+        recordBatch.close()
+      }
+    } {
+      // Always reset the writer so the shared root can be reused for the next batch.
+      arrowWriter.reset()
+    }
+  }
+
+  /** Releases the root and then the allocator; does nothing after the first call. */
+  private def close(): Unit = {
+    if (!closed) {
+      closed = true
+      Utils.tryWithSafeFinally {
+        root.close()
+      } {
+        allocator.close()
+      }
+    }
+  }
+}
+
+/**
+ * Iterator that converts ColumnarBatch to ArrowCachedBatch.
+ *
+ * Batches whose columns are all [[ArrowColumnVector]]s without large var-width vectors are
+ * serialized straight from their existing Arrow vectors; every other batch is re-encoded row by
+ * row. The allocator backing the row-by-row path is released when the input is exhausted or when
+ * the task completes, whichever happens first.
+ */
+private class ColumnarBatchToArrowCachedBatchIterator(
+    batchIter: Iterator[ColumnarBatch],
+    schema: Seq[Attribute],
+    sparkSchema: StructType,
+    timeZoneId: String,
+    compressionCodecName: String,
+    compressionLevel: Int) extends Iterator[ArrowCachedBatch] {
+
+  private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec(
+    compressionCodecName,
+    compressionLevel)
+
+  // TaskContext.get() can be null, so the attempt id used in the allocator name falls back
+  // to -1 instead of dereferencing a missing context.
+  private val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+
+  private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+    s"ColumnarBatchToArrowCachedBatchIterator-$taskAttemptId",
+    0,
+    Long.MaxValue)
+
+  private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false)
+
+  // Set once the allocator has been released, making closeAllocator() safe to call repeatedly.
+  private var allocatorClosed = false
+
+  // Release the allocator at task completion even if the iterator is never fully consumed.
+  Option(TaskContext.get()).foreach { tc =>
+    tc.addTaskCompletionListener[Unit] { _ =>
+      closeAllocator()
+    }
+  }
+
+  /** True while input batches remain; releases the allocator on the first `false`. */
+  override def hasNext: Boolean = batchIter.hasNext || {
+    closeAllocator()
+    false
+  }
+
+  /** Converts the next input batch, choosing between the zero-copy and row-based paths. */
+  override def next(): ArrowCachedBatch = {
+    val batch = batchIter.next()
+    val rowCount = batch.numRows()
+
+    // Check if batch is already Arrow-based for zero-copy path. The zero-copy path reuses the
+    // input vectors but serializes them under a schema built with largeVarTypes=false, and the
+    // read path reconstructs that same non-large schema. Large var-width vectors use 64-bit
+    // offsets, so reading them back under a 32-bit-offset schema would silently corrupt data.
+    // Fall back to the row-based conversion (which always produces standard var-width vectors)
+    // whenever any input vector is, or nests, a large var-width vector.
+    val vectors = (0 until batch.numCols()).map(batch.column)
+    val zeroCopyEligible = vectors.forall {
+      case acv: ArrowColumnVector =>
+        !ColumnarBatchToArrowCachedBatchIterator.containsLargeVarType(acv.getValueVector)
+      case _ => false
+    }
+    if (zeroCopyEligible) {
+      convertArrowBatchZeroCopy(rowCount, vectors)
+    } else {
+      convertToArrowBatch(batch, rowCount)
+    }
+  }
+
+  /**
+   * Zero-copy path: wraps the batch's existing Arrow vectors in a root and serializes them.
+   * Statistics are computed from the vectors themselves.
+   */
+  private def convertArrowBatchZeroCopy(
+      rowCount: Int,
+      vectors: Seq[ColumnVector]): ArrowCachedBatch = {
+    val arrowVectors = vectors.map(
+      _.asInstanceOf[ArrowColumnVector].getValueVector.asInstanceOf[
+        org.apache.arrow.vector.FieldVector])
+
+    // The root only borrows the vectors: they remain owned by the input ColumnarBatch, so the
+    // root is deliberately never closed here.
+    val root = new VectorSchemaRoot(arrowSchema, arrowVectors.asJava, rowCount)
+    val recordBatch = new VectorUnloader(root, true, compressionCodec, true).getRecordBatch()
+
+    Utils.tryWithSafeFinally {
+      val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch)
+      val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema)
+      ArrowCachedBatch(rowCount, arrowData, stats)
+    } {
+      recordBatch.close()
+    }
+  }
+
+  /**
+   * Row-based path: copies every row of `batch` into a freshly allocated root, gathering
+   * statistics along the way, then serializes and releases that root.
+   */
+  private def convertToArrowBatch(
+      batch: ColumnarBatch,
+      rowCount: Int): ArrowCachedBatch = {
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val arrowWriter = ArrowWriter.create(root)
+    val unloader = new VectorUnloader(root, true, compressionCodec, true)
+
+    // Collect statistics inline during row iteration, same as InternalRowToArrow path
+    val statsCollectors: Array[ColumnStats] = schema.map { attr =>
+      ArrowCachedBatchSerializer.createColumnStats(attr.dataType)
+    }.toArray
+
+    Utils.tryWithSafeFinally {
+      val rowIterator = batch.rowIterator().asScala
+      while (rowIterator.hasNext) {
+        val row = rowIterator.next()
+        arrowWriter.write(row)
+
+        var i = 0
+        while (i < statsCollectors.length) {
+          statsCollectors(i).gatherStats(row, i)
+          i += 1
+        }
+      }
+      arrowWriter.finish()
+
+      val recordBatch = unloader.getRecordBatch()
+      Utils.tryWithSafeFinally {
+        val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch)
+        val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors(
+          statsCollectors, schema)
+        ArrowCachedBatch(rowCount, arrowData, stats)
+      } {
+        recordBatch.close()
+      }
+    } {
+      // This root is owned by this method, so it is reset and released on every exit path.
+      arrowWriter.reset()
+      root.close()
+    }
+  }
+
+  /** Releases the allocator; does nothing after the first call. */
+  private def closeAllocator(): Unit = {
+    if (!allocatorClosed) {
+      allocatorClosed = true
+      allocator.close()
+    }
+  }
+}
+
+private object ColumnarBatchToArrowCachedBatchIterator {
+  import org.apache.arrow.vector.{FieldVector, LargeVarBinaryVector, LargeVarCharVector}
+
+  /**
+   * Reports whether `vector` itself, or any vector nested beneath it, is a large var-width
+   * vector (64-bit offsets). Such vectors are not eligible for the zero-copy path because
+   * that path serializes and reloads under a schema built with largeVarTypes=false;
+   * reinterpreting 64-bit offset buffers as 32-bit would corrupt data.
+   */
+  def containsLargeVarType(vector: org.apache.arrow.vector.ValueVector): Boolean = {
+    vector match {
+      case _: LargeVarCharVector => true
+      case _: LargeVarBinaryVector => true
+      case parent: FieldVector =>
+        // Not large itself: recurse into its child vectors.
+        parent.getChildrenFromFields.asScala.exists(child => containsLargeVarType(child))
+      case _ => false
+    }
+  }
+}
+
+/**
+ * Iterator that converts ArrowCachedBatch to ColumnarBatch.
+ */
+private class ArrowCachedBatchToColumnarBatchIterator(
+    batchIter: Iterator[CachedBatch],
+    cacheSchema: StructType,
+    selectedSchema: StructType,
+    columnIndices: Array[Int],
+    timeZoneId: String,
+    prefetchEnabled: Boolean = false) extends Iterator[ColumnarBatch] {
+
+  import java.util.concurrent.{Callable, ExecutionException, Executors, 
ExecutorService, Future}
+
+  private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+    
s"ArrowCachedBatchToColumnarBatchIterator-${TaskContext.get().taskAttemptId()}",
+    0,
+    Long.MaxValue)
+
+  private val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, 
false, false)
+
+  // Track only the previous root to close it when next batch is produced
+  private var previousRoot: VectorSchemaRoot = null
+
+  // Prefetch support: deserialize the next batch into its own root in a background thread while
+  // the current batch is being consumed. Only the deserialization (IPC read + decompression +
+  // loading into a fresh root) happens off-thread; closing the previous root stays on the
+  // consumer thread in next(), so the vectors backing a returned ColumnarBatch are never released
+  // while the consumer may still read them.
+  private val prefetchExecutor: ExecutorService = if (prefetchEnabled) {
+    Executors.newSingleThreadExecutor(r => {
+      val t = new Thread(r, "arrow-cache-prefetch")
+      t.setDaemon(true)
+      t
+    })
+  } else {
+    null
+  }
+  private var prefetchFuture: Future[VectorSchemaRoot] = _
+
+  // Register cleanup - close remaining root and allocator when task completes
+  Option(TaskContext.get()).foreach { tc =>
+    tc.addTaskCompletionListener[Unit] { _ =>
+      if (prefetchFuture != null) {
+        prefetchFuture.cancel(true)

Review Comment:
   [P1] I rechecked the cleanup change on the current head, and an interrupted 
task can still race allocator shutdown. If the completion listener runs with 
the task thread interrupted, `awaitTermination` throws; the catch restores the 
interrupt and calls `shutdownNow()` without joining the worker. When the future 
is still in flight, the subsequent `future.get()` immediately throws 
`InterruptedException`, so the caller proceeds to `allocator.close()` while 
Arrow deserialization may still allocate or return an unclosed root. Please 
drain/join uninterruptibly, close any result, then restore the interrupt, with 
killed-task coverage for both row and columnar readers.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala:
##########
@@ -0,0 +1,1371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.channels.Channels
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, 
NoCompressionCodec}
+import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, 
MessageSerializer}
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.columnar.{CachedBatch, 
SimpleMetricsCachedBatchSerializer}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+/**
+ * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format.
+ *
+ * This serializer:
+ *  - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input
+ *  - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4)
+ *  - Enables zero-copy columnar reads when output is ColumnarBatch
+ *  - Uses off-heap memory via Arrow allocators
+ *  - Collects per-column statistics for partition pruning
+ *  - Provides efficient interoperability with Arrow ecosystem
+ *
+ * Configuration options:
+ *  - spark.sql.cache.serializer: Set to this class name to enable
+ *  - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch
+ *  - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4)
+ *  - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output
+ */
+class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
+
+  /**
+   * Columnar input is accepted only when every attribute's data type is supported by Arrow.
+   */
+  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = {
+    // Stop at the first attribute whose type Arrow cannot represent.
+    val hasUnsupportedType = schema.exists { attr =>
+      !ArrowUtils.isSupportedByArrow(attr.dataType)
+    }
+    !hasUnsupportedType
+  }
+
+  override def convertInternalRowToCachedBatch(
+      input: RDD[InternalRow],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    // Capture config values on driver before RDD transformation
+    val sparkSchema = DataTypeUtils.fromAttributes(schema)
+    val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch

Review Comment:
   [P1] This remains incomplete in two ways. The columnar path still captures 
neither limit and explicitly maps one input batch to one cached batch. A normal 
repro is to cache relation A with `maxRecordsPerBatch=10000`, change the 
session setting to `1`, then cache a projection of A: the Arrow-columnar 
recache preserves the 10,000-row input batches. On the row path, every 
non-`UnsafeRow` is estimated as only `numFields * 16`, so a one-field 
`GenericInternalRow` containing a 100 MiB string contributes 16 bytes to the 
nominal 64 MiB guard. Please split/slice columnar input and use the actual 
Arrow writer size (or equivalent type-aware sizing) for generic rows.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to