sunchao commented on code in PR #56334: URL: https://github.com/apache/spark/pull/56334#discussion_r3503140394
########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import 
org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. + * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + // supportsColumnarInput selects the columnar-vs-row input path; it does not gate which schemas + // this serializer accepts. The cache framework has no per-type fallback to another serializer + // (whatever spark.sql.cache.serializer selects handles every cached relation), so returning + // false here only routes input through convertInternalRowToCachedBatch, which is still this + // serializer. 
Type support is enforced once per partition by checkSupportedSchema below; the + // only real precondition for columnar input is that the plan can produce columnar output, which + // InMemoryRelation already checks via cachedPlan.supportsColumnar before calling this. + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val maxBytesPerBatch = conf.arrowMaxBytesPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + maxBytesPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // 
Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. 
+ val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. + case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + /** + * Run an Arrow write block, translating a CalendarInterval microsecond overflow into a clear + * error. 
Arrow's IntervalMonthDayNano representation is nanosecond-based, so writing a + * CalendarInterval multiplies its microseconds by 1000 with Math.multiplyExact; Spark allows the + * full Long microsecond domain, so values beyond Long.MaxValue/1000 overflow and otherwise abort + * with an opaque "long overflow" ArithmeticException. The catch is only installed when the schema + * actually contains a CalendarInterval column (hasInterval), so there is no per-row cost and no + * effect on schemas without intervals; the try is entered once per batch, not per row. + */ + def withIntervalOverflowTranslation[T](hasInterval: Boolean)(block: => T): T = { + if (!hasInterval) { + block + } else { + try { + block + } catch { + case e: ArithmeticException => + throw SparkException.internalError( + "Arrow cache cannot represent a CalendarInterval whose microseconds exceed " + + "+/-(Long.MaxValue / 1000): Arrow stores intervals in nanoseconds and the " + + s"conversion overflows. Original error: ${e.getMessage}") + } + } + } + + /** Whether the schema has a top-level CalendarInterval column (the only overflow-prone type). */ + def hasCalendarInterval(schema: Seq[Attribute]): Boolean = + schema.exists(_.dataType == CalendarIntervalType) + + /** + * Fail fast, once per partition on the driver-facing entry points, if any column type cannot be + * represented by the Arrow cache. This is the actual capability gate (supportsColumnarInput only + * chooses the input path). Without it, an unsupported type would otherwise surface as a less + * obvious failure deeper in schema conversion or statistics collection. + */ + def checkSupportedSchema(schema: Seq[Attribute]): Unit = { + schema.find(attr => !ArrowUtils.isSupportedByArrow(attr.dataType)).foreach { attr => + throw SparkException.internalError( Review Comment: [P2] Please surface unsupported schemas with the documented user-facing condition rather than `INTERNAL_ERROR`. 
`checkSupportedSchema` explicitly handles unsupported types such as `ObjectType`, but `SparkException.internalError` sets condition `INTERNAL_ERROR`, while the guide promises `UNSUPPORTED_DATATYPE`. A focused invocation confirms `getCondition == INTERNAL_ERROR`, and the current test asserts only message text. Please use the existing structured unsupported-datatype error (or an equivalent structured error preserving column context) and assert the condition. _\[ :robot: posted by Codex on behalf of sunchao using the code-review-for-me skill :robot: \]_ ########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. 
+ * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + // supportsColumnarInput selects the columnar-vs-row input path; it does not gate which schemas + // this serializer accepts. The cache framework has no per-type fallback to another serializer + // (whatever spark.sql.cache.serializer selects handles every cached relation), so returning + // false here only routes input through convertInternalRowToCachedBatch, which is still this + // serializer. Type support is enforced once per partition by checkSupportedSchema below; the + // only real precondition for columnar input is that the plan can produce columnar output, which + // InMemoryRelation already checks via cachedPlan.supportsColumnar before calling this. 
+ override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val maxBytesPerBatch = conf.arrowMaxBytesPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + maxBytesPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def 
convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. 
+ case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + /** + * Run an Arrow write block, translating a CalendarInterval microsecond overflow into a clear + * error. Arrow's IntervalMonthDayNano representation is nanosecond-based, so writing a + * CalendarInterval multiplies its microseconds by 1000 with Math.multiplyExact; Spark allows the + * full Long microsecond domain, so values beyond Long.MaxValue/1000 overflow and otherwise abort + * with an opaque "long overflow" ArithmeticException. The catch is only installed when the schema + * actually contains a CalendarInterval column (hasInterval), so there is no per-row cost and no + * effect on schemas without intervals; the try is entered once per batch, not per row. 
+ */ + def withIntervalOverflowTranslation[T](hasInterval: Boolean)(block: => T): T = { + if (!hasInterval) { + block + } else { + try { + block + } catch { + case e: ArithmeticException => + throw SparkException.internalError( + "Arrow cache cannot represent a CalendarInterval whose microseconds exceed " + + "+/-(Long.MaxValue / 1000): Arrow stores intervals in nanoseconds and the " + + s"conversion overflows. Original error: ${e.getMessage}") + } + } + } + + /** Whether the schema has a top-level CalendarInterval column (the only overflow-prone type). */ + def hasCalendarInterval(schema: Seq[Attribute]): Boolean = + schema.exists(_.dataType == CalendarIntervalType) + + /** + * Fail fast, once per partition on the driver-facing entry points, if any column type cannot be + * represented by the Arrow cache. This is the actual capability gate (supportsColumnarInput only + * chooses the input path). Without it, an unsupported type would otherwise surface as a less + * obvious failure deeper in schema conversion or statistics collection. + */ + def checkSupportedSchema(schema: Seq[Attribute]): Unit = { + schema.find(attr => !ArrowUtils.isSupportedByArrow(attr.dataType)).foreach { attr => + throw SparkException.internalError( + s"Arrow cache does not support column '${attr.name}' of type ${attr.dataType.sql}. " + + "Use the default cache serializer for this data, or cast the column to a supported type.") + } + } + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + // The codec instance must be constructed directly so that compressionLevel is honored: + // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at + // the default level. The level only matters on the write side; the read side looks up the + // codec by the type recorded in the IPC message. 
+ case "zstd" => new ZstdCompressionCodec(compressionLevel) + case "lz4" => new Lz4CompressionCodec() + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) Review Comment: [P2] This serializes only a schema-less IPC `RecordBatch` message, not the IPC stream claimed by `ArrowCachedBatch`, this serializer's class comment, and the guide. I verified the exact payload starts with `MessageHeader.RecordBatch` (3), and Arrow 19's standard `ArrowStreamReader` rejects it with `Expected schema but header was 3`. Spark's paired reader works only because it reconstructs the schema out of band and calls `deserializeRecordBatch` directly. Please either emit a complete stream (Schema, batch, and end marker) or describe this narrowly as an internal schema-less RecordBatch payload and remove the ecosystem data-sharing/interoperability claim. _\[ :robot: posted by Codex on behalf of sunchao using the code-review-for-me skill :robot: \]_ ########## docs/sql-arrow-cache-format.md: ########## @@ -0,0 +1,378 @@ +--- +layout: global +title: Apache Arrow Cache Format +displayTitle: Apache Arrow Cache Format +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +## Overview + +Apache Spark supports using Apache Arrow as an alternative cache format for in-memory Dataset caching. This format provides improved performance for certain workloads, especially when working with columnar data sources like Parquet and ORC. + +## Benefits + +The Arrow cache format offers several advantages over the default cache format: + +- **Zero-copy reads** when input is already in Arrow format (e.g., Arrow-based data sources, re-caching Arrow cached data) +- **Better filter pushdown** with min/max statistics for partition pruning +- **Off-heap memory management** via Arrow allocators +- **Efficient compression** with zstd and lz4 codecs +- **Arrow ecosystem interoperability** for data sharing + +**Note**: Spark's built-in Parquet/ORC readers use internal column vectors (`OnHeapColumnVector`/`OffHeapColumnVector`), not Arrow format, so they don't benefit from zero-copy optimization. + +## Configuration + +`spark.sql.cache.serializer` is a static SQL configuration, so it must be set when the +SparkSession is built and cannot be changed on a running session (`spark.conf.set` rejects static +keys with `CANNOT_MODIFY_CONFIG`): + +```scala +val spark = SparkSession.builder() + .appName("MyApp") + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .getOrCreate() +``` + +**Note**: This config selects the cache serializer for the whole session; once set, this +serializer handles every cached relation. 
There is no automatic per-relation fallback to another +cache serializer based on the data types involved (see +[Supported Data Types](#supported-data-types) for how unsupported types are handled). The chosen +serializer is also cached process-wide on first use, so switching cache formats within a JVM that +has already materialized a cache requires a fresh JVM (see +[Migration from Default Cache](#migration-from-default-cache)). + +## Usage + +Once configured, use cache operations as normal: + +```scala +// Cache a DataFrame +val df = spark.read.parquet("data.parquet") +df.cache() + +// Use cached data +df.filter("age > 30").count() + +// Uncache when done +df.unpersist() +``` + +## Compression + +Arrow cache supports multiple compression codecs. Configure compression with: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") +``` + +Available options: +- `none` - No compression (fastest, largest size, **default**) +- `lz4` - LZ4 compression (fast, good compression) Review Comment: [P2] This is still not fixed, and Arrow 19 makes the native-library qualification inaccurate. The primary codec list still calls LZ4 `fast`, but Arrow 19's `Lz4CompressionCodec` unconditionally uses Commons Compress's `FramedLZ4Compressor` streams; it does not detect or switch to `lz4-java`, and this PR instantiates that codec directly. Spark already carries `at.yawk.lz4`, which does not change this path. Please remove the speed claim and the `unless the native LZ4 library is on the classpath` guidance (including the benchmark-source note), or replace and benchmark the codec implementation that actually provides a fast path. 
_\[ :robot: posted by Codex on behalf of sunchao using the code-review-for-me skill :robot: \]_ ########## sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,1533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. 
+ * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + // supportsColumnarInput selects the columnar-vs-row input path; it does not gate which schemas + // this serializer accepts. The cache framework has no per-type fallback to another serializer + // (whatever spark.sql.cache.serializer selects handles every cached relation), so returning + // false here only routes input through convertInternalRowToCachedBatch, which is still this + // serializer. Type support is enforced once per partition by checkSupportedSchema below; the + // only real precondition for columnar input is that the plan can produce columnar output, which + // InMemoryRelation already checks via cachedPlan.supportsColumnar before calling this. 
+ override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val maxBytesPerBatch = conf.arrowMaxBytesPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + maxBytesPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def 
convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. 
+ case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + /** + * Run an Arrow write block, translating a CalendarInterval microsecond overflow into a clear + * error. Arrow's IntervalMonthDayNano representation is nanosecond-based, so writing a + * CalendarInterval multiplies its microseconds by 1000 with Math.multiplyExact; Spark allows the + * full Long microsecond domain, so values beyond Long.MaxValue/1000 overflow and otherwise abort + * with an opaque "long overflow" ArithmeticException. The catch is only installed when the schema + * actually contains a CalendarInterval column (hasInterval), so there is no per-row cost and no + * effect on schemas without intervals; the try is entered once per batch, not per row. 
+ */ + def withIntervalOverflowTranslation[T](hasInterval: Boolean)(block: => T): T = { + if (!hasInterval) { + block + } else { + try { + block + } catch { + case e: ArithmeticException => + throw SparkException.internalError( + "Arrow cache cannot represent a CalendarInterval whose microseconds exceed " + + "+/-(Long.MaxValue / 1000): Arrow stores intervals in nanoseconds and the " + + s"conversion overflows. Original error: ${e.getMessage}") + } + } + } + + /** Whether the schema has a top-level CalendarInterval column (the only overflow-prone type). */ + def hasCalendarInterval(schema: Seq[Attribute]): Boolean = + schema.exists(_.dataType == CalendarIntervalType) + + /** + * Fail fast on the driver-facing entry points, before any per-partition work is set up, if any + * column type cannot be represented by the Arrow cache. This is the actual capability gate + * (supportsColumnarInput only chooses the input path). Without it, an unsupported type would + * surface as a less obvious failure deeper in schema conversion or statistics collection. + */ + def checkSupportedSchema(schema: Seq[Attribute]): Unit = { + schema.find(attr => !ArrowUtils.isSupportedByArrow(attr.dataType)).foreach { attr => + throw SparkException.internalError( + s"Arrow cache does not support column '${attr.name}' of type ${attr.dataType.sql}. " + + "Use the default cache serializer for this data, or cast the column to a supported type.") + } + } + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + // The codec instance must be constructed directly so that compressionLevel is honored: + // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at + // the default level. The level only matters on the write side; the read side looks up the + // codec by the type recorded in the IPC message. 
+ case "zstd" => new ZstdCompressionCodec(compressionLevel) + case "lz4" => new Lz4CompressionCodec() + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + /** + * Shut down a prefetch worker during task cleanup without leaking the root it may have produced. + * + * The prefetch worker deserializes the next batch into a fresh [[VectorSchemaRoot]] off-thread. + * If task completion runs while a result is in flight (e.g. a LIMIT consumer stops early), + * cancelling and discarding the future would drop a root that was already (or is about to be) + * produced, and the subsequent `allocator.close()` would fail with "Memory was leaked by query". + * + * This stops accepting new work, waits for the worker to finish so no root is produced after we + * stop looking, then closes any completed result. Always returns null so the caller can null out + * its future reference. Safe to call with a null executor or future. + */ + def drainAndClosePrefetch( + executor: java.util.concurrent.ExecutorService, + future: java.util.concurrent.Future[VectorSchemaRoot]): java.util.concurrent.Future[ + VectorSchemaRoot] = { + // Drain and join the worker uninterruptibly, then close any root it produced, before the + // caller closes the allocator. This runs from a task-completion listener, which can fire with + // the task thread already interrupted (e.g. a killed task). 
If we let awaitTermination or + // future.get observe the interrupt and bail early, the worker could still be allocating into, + // or have already returned, a root that we then neither join nor close -- and the subsequent + // allocator.close() would race the worker or fail with "Memory was leaked by query". So we + // clear the interrupt for the duration and restore it only at the end. + val wasInterrupted = Thread.interrupted() + try { + if (executor != null) { + executor.shutdown() + var terminated = false + while (!terminated) { + try { + terminated = + executor.awaitTermination(Long.MaxValue, java.util.concurrent.TimeUnit.NANOSECONDS) + } catch { + // Re-clear and keep waiting: we must not leave the worker running. + case _: InterruptedException => Thread.interrupted() Review Comment: [P2] An interrupt delivered while this await is blocked is still lost. `InterruptedException` clears the status before entering this catch, so `Thread.interrupted()` observes false, and `wasInterrupted` records only the entry state. I reproduced this against the exact method by interrupting the caller after `awaitTermination` began; it returned with the interrupt flag clear. The current test pre-interrupts before entry, so it cannot cover this path. Please accumulate every caught interruption in a mutable flag, finish joining/closing, and restore that flag in `finally`. _\[ :robot: posted by Codex on behalf of sunchao using the code-review-for-me skill :robot: \]_ -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
