viirya commented on code in PR #56334:
URL: https://github.com/apache/spark/pull/56334#discussion_r3392575590


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala:
##########
@@ -0,0 +1,1373 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.channels.Channels
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, 
NoCompressionCodec}
+import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, 
MessageSerializer}
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.columnar.{CachedBatch, 
SimpleMetricsCachedBatchSerializer}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
/**
 * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format.
 *
 * This serializer:
 *  - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input
 *  - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4)
 *  - Enables zero-copy columnar reads when output is ColumnarBatch
 *  - Uses off-heap memory via Arrow allocators
 *  - Collects per-column statistics for partition pruning
 *  - Provides efficient interoperability with the Arrow ecosystem
 *
 * Configuration options:
 *  - spark.sql.cache.serializer: Set to this class name to enable
 *  - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch
 *  - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4)
 *  - spark.sql.execution.arrow.compression.zstd.level: zstd level, used only when writing
 *  - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output
 *
 * Read-side prefetching of cached batches is controlled by `SQLConf.arrowCachePrefetchEnabled`.
 */
class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {

  /** Columnar input is accepted only when Arrow can represent every column's data type. */
  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean =
    schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType))

  /**
   * Encodes row-based input into Arrow cached batches, one iterator per partition.
   * Configuration is read here, on the driver, so the partition closure does not need
   * to reference `conf`.
   */
  override def convertInternalRowToCachedBatch(
      input: RDD[InternalRow],
      schema: Seq[Attribute],
      storageLevel: StorageLevel,
      conf: SQLConf): RDD[CachedBatch] = {
    val sparkSchema = DataTypeUtils.fromAttributes(schema)
    val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch
    val timeZoneId = conf.sessionLocalTimeZone
    val compressionCodecName = conf.arrowCompressionCodec
    val compressionLevel = conf.arrowZstdCompressionLevel

    input.mapPartitionsInternal { rowIterator =>
      new InternalRowToArrowCachedBatchIterator(
        rowIterator,
        schema,
        sparkSchema,
        maxRecordsPerBatch,
        timeZoneId,
        compressionCodecName,
        compressionLevel)
    }
  }

  /**
   * Encodes columnar input into Arrow cached batches, one iterator per partition.
   * Configuration is read here, on the driver, so the partition closure does not need
   * to reference `conf`.
   */
  override def convertColumnarBatchToCachedBatch(
      input: RDD[ColumnarBatch],
      schema: Seq[Attribute],
      storageLevel: StorageLevel,
      conf: SQLConf): RDD[CachedBatch] = {
    val sparkSchema = DataTypeUtils.fromAttributes(schema)
    val timeZoneId = conf.sessionLocalTimeZone
    val compressionCodecName = conf.arrowCompressionCodec
    val compressionLevel = conf.arrowZstdCompressionLevel

    input.mapPartitionsInternal { batchIterator =>
      new ColumnarBatchToArrowCachedBatchIterator(
        batchIterator,
        schema,
        sparkSchema,
        timeZoneId,
        compressionCodecName,
        compressionLevel)
    }
  }

  /** Arrow cached batches can always be exposed as columnar output, whatever the schema. */
  override def supportsColumnarOutput(schema: StructType): Boolean = true

  /** Every output column is backed by an [[ArrowColumnVector]]. */
  override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] =
    Some(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName))

  /**
   * Decodes cached batches into [[ColumnarBatch]]es projected to `selectedAttributes`.
   *
   * @throws SparkException if a selected attribute is not part of `cacheAttributes`
   */
  override def convertCachedBatchToColumnarBatch(
      input: RDD[CachedBatch],
      cacheAttributes: Seq[Attribute],
      selectedAttributes: Seq[Attribute],
      conf: SQLConf): RDD[ColumnarBatch] = {
    val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes)
    val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes)
    val columnIndices = resolveColumnIndices(cacheAttributes, selectedAttributes)
    // Capture config on the driver.
    val timeZoneId = conf.sessionLocalTimeZone
    val prefetchEnabled = conf.arrowCachePrefetchEnabled

    input.mapPartitionsInternal { batchIterator =>
      new ArrowCachedBatchToColumnarBatchIterator(
        batchIterator,
        cacheSchema,
        selectedSchema,
        columnIndices,
        timeZoneId,
        prefetchEnabled)
    }
  }

  /**
   * Decodes cached batches into rows projected to `selectedAttributes`.
   *
   * If every selected type is handled by the Arrow row reader, rows are read directly
   * from the cached batches. Otherwise, the batches are first decoded as columnar, and
   * each row is converted with an `UnsafeProjection` (see `needsRowFallback`).
   *
   * @throws SparkException if a selected attribute is not part of `cacheAttributes`
   */
  override def convertCachedBatchToInternalRow(
      input: RDD[CachedBatch],
      cacheAttributes: Seq[Attribute],
      selectedAttributes: Seq[Attribute],
      conf: SQLConf): RDD[InternalRow] = {
    val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes)

    if (selectedSchema.fields.exists(field => needsRowFallback(field.dataType))) {
      // Decode columnar first, then turn each ColumnarBatch row into an UnsafeRow.
      convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
        .mapPartitionsInternal[InternalRow] { batchIter =>
          val toUnsafe =
            org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create(selectedSchema)
          batchIter.flatMap { batch =>
            Iterator.range(0, batch.numRows()).map(rowIdx => toUnsafe(batch.getRow(rowIdx)))
          }
        }
    } else {
      val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes)
      val selectedIndices = resolveColumnIndices(cacheAttributes, selectedAttributes)
      val timeZoneId = conf.sessionLocalTimeZone
      val prefetchEnabled = conf.arrowCachePrefetchEnabled

      input.mapPartitionsInternal { batchIterator =>
        new ArrowCachedBatchToInternalRowIterator(
          batchIterator,
          cacheSchema,
          selectedSchema,
          selectedIndices,
          timeZoneId,
          prefetchEnabled)
      }
    }
  }

  /**
   * Whether a selected column of this type must take the columnar-to-row fallback path.
   * These are the types the fast-path `ArrowColumnReader` does not handle.
   */
  private def needsRowFallback(dataType: DataType): Boolean = dataType match {
    case _: ArrayType | _: StructType | _: MapType => true
    case CalendarIntervalType | VariantType | NullType => true
    case _: UserDefinedType[_] => true
    // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path
    // ArrowColumnReader does not handle them, so route them through the fallback.
    case _: GeometryType | _: GeographyType => true
    case _ => false
  }

  /**
   * Maps each selected attribute to its ordinal in `cacheAttributes`, matching on `exprId`.
   *
   * A missing attribute would otherwise be encoded as -1 and handed to the batch
   * iterators. Fail fast on the driver with a descriptive message instead.
   */
  private def resolveColumnIndices(
      cacheAttributes: Seq[Attribute],
      selectedAttributes: Seq[Attribute]): Array[Int] = {
    selectedAttributes.map { attr =>
      val ordinal = cacheAttributes.indexWhere(_.exprId == attr.exprId)
      if (ordinal < 0) {
        throw SparkException.internalError(
          s"Selected attribute $attr is not part of the cached output " +
            cacheAttributes.mkString("[", ", ", "]"))
      }
      ordinal
    }.toArray
  }
}
+
+/**
+ * Companion object with shared utility methods for Arrow cache serialization.
+ */
+private object ArrowCachedBatchSerializer {
+
+  // scalastyle:off caselocale
+  def createCompressionCodec(
+      codecName: String,
+      compressionLevel: Int): CompressionCodec = {
+    codecName.toLowerCase match {
+      case "none" => NoCompressionCodec.INSTANCE
+      case "zstd" =>
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new 
ZstdCompressionCodec(compressionLevel).getCodecType()
+        factory.createCodec(codecType)

Review Comment:
   Fixed by constructing the codec directly as you suggested -- 
`createCompressionCodec` now returns `new 
ZstdCompressionCodec(compressionLevel)` instead of rebuilding through the 
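single-argument factory overload (the lz4 arm is simplified the same way). A 
comment explains why the factory must not be used here: that overload only 
receives the codec type, so it builds a codec at the default level and silently 
discards the configured one. The level only matters on the write side, since 
reads look up the codec by the type recorded in the IPC message.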
   
   Digging into why the committed benchmark results never caught this, I found 
a second, independent bug: `ArrowCacheBenchmark` was setting a nonexistent conf 
key (`spark.sql.execution.arrow.compression.level` instead of 
`spark.sql.execution.arrow.compression.zstd.level`), which `spark.conf.set` 
accepts silently. So the per-level benchmark rows were doubly broken -- even 
with the codec fixed, the level would never have reached it. The benchmark now 
references the `SQLConf` key constants, so a mistyped key is now a compile error 
instead of a silently ignored setting. 
The doc page had the same wrong key plus a few other inaccuracies (codec 
default listed as `zstd` instead of `none`, level range stated as 1-22 though 
negative fast levels are supported, and a `maxRecordsPerBatch` key missing the 
`execution` segment); all fixed.
   
   Added a regression test that compresses the same batch at zstd level -5 and 
19 and asserts the higher level yields a strictly smaller payload; it fails 
against the previous codec construction. The per-level rows in the committed 
benchmark result files are stale (all three levels effectively measured the 
default level -- they are within noise of each other, which corroborates your 
finding); I will regenerate them with the benchmark GitHub Actions workflow.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to