This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new def03c2b9583 [SPARK-54226][SQL] Extend Arrow compression to Pandas UDF
def03c2b9583 is described below
commit def03c2b9583c71d483d305b5bbba9ade39ed29a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Nov 7 06:46:16 2025 -0800
[SPARK-54226][SQL] Extend Arrow compression to Pandas UDF
### What changes were proposed in this pull request?
This is an extension to https://github.com/apache/spark/pull/52747. In
https://github.com/apache/spark/pull/52747, we added support for Arrow
compression to `toArrow` and `toPandas` to reduce memory usage. This PR
extends that memory optimization to the Pandas UDF case.
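As a rough usage sketch (not part of this patch): the codec is selected with the
existing `spark.sql.execution.arrow.compressionCodec` config, and the zstd level with
the `spark.sql.execution.arrow.zstd.compressionLevel` config added here. The snippet
below assumes an existing SparkSession named `spark` with pandas and PyArrow available.

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

# Compress Arrow IPC batches sent to Python workers for Pandas UDF evaluation.
spark.conf.set("spark.sql.execution.arrow.compressionCodec", "zstd")
# Optional zstd level (1-22, default 3).
spark.conf.set("spark.sql.execution.arrow.zstd.compressionLevel", "3")

@pandas_udf("long")
def plus_one(v: pd.Series) -> pd.Series:
    return v + 1

spark.range(100).select(plus_one("id").alias("result")).show()
```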
### Why are the changes needed?
To optimize memory usage for the Pandas UDF case.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code v2.0.14
Closes #52925 from viirya/arrow_compress_udf.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 96ed48db3879fc9e2d250c4548eec4409061d2de)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/tests/pandas/test_pandas_grouped_map.py | 60 ++++++++++++++++
.../tests/pandas/test_pandas_udf_grouped_agg.py | 36 ++++++++++
.../sql/tests/pandas/test_pandas_udf_scalar.py | 56 +++++++++++++++
.../sql/execution/arrow/ArrowWriterWrapper.scala | 8 ++-
.../org/apache/spark/sql/internal/SQLConf.scala | 15 ++++
.../sql/execution/arrow/ArrowConverters.scala | 6 +-
.../python/CoGroupedArrowPythonRunner.scala | 39 ++++++++++-
.../sql/execution/python/PythonArrowInput.scala | 79 ++++++++++++++++++++--
8 files changed, 285 insertions(+), 14 deletions(-)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index ef84673179dc..b60c5a187fbf 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -1406,6 +1406,66 @@ class ApplyInPandasTestsMixin:
actual = grouped_df.applyInPandas(func, "value long").collect()
self.assertEqual(actual, expected)
+ def test_grouped_map_pandas_udf_with_compression_codec(self):
+ # Test grouped map Pandas UDF with different compression codec settings
+ @pandas_udf("id long, v int, v1 double", PandasUDFType.GROUPED_MAP)
+ def foo(pdf):
+ return pdf.assign(v1=pdf.v * pdf.id * 1.0)
+
+ df = self.data
+ pdf = df.toPandas()
+ expected = pdf.groupby("id", as_index=False).apply(foo.func).reset_index(drop=True)
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = df.groupby("id").apply(foo).sort("id").toPandas()
+ assert_frame_equal(expected, result)
+
+ def test_apply_in_pandas_with_compression_codec(self):
+ # Test applyInPandas with different compression codec settings
+ def stats(key, pdf):
+ return pd.DataFrame([(key[0], pdf.v.mean())], columns=["id", "mean"])
+
+ df = self.data
+ expected = df.select("id").distinct().withColumn("mean", sf.lit(24.5)).toPandas()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = (
+ df.groupby("id")
+ .applyInPandas(stats, schema="id long, mean double")
+ .sort("id")
+ .toPandas()
+ )
+ assert_frame_equal(expected, result)
+
+ def test_apply_in_pandas_iterator_with_compression_codec(self):
+ # Test applyInPandas with iterator and compression
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+ total = 0
+ for batch in batches:
+ total += batch["v"].sum()
+ yield pd.DataFrame({"v": [total]})
+
+ expected = [Row(v=3.0), Row(v=18.0)]
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = (
+ df.groupby("id")
+ .applyInPandas(sum_func, schema="v double")
+ .orderBy("v")
+ .collect()
+ )
+ self.assertEqual(result, expected)
+
class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 2b3e42312df9..2958d0e67f1e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -861,6 +861,42 @@ class GroupedAggPandasUDFTestsMixin:
],
)
+ def test_grouped_agg_pandas_udf_with_compression_codec(self):
+ # Test grouped agg Pandas UDF with different compression codec settings
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def sum_udf(v):
+ return v.sum()
+
+ df = self.data
+ expected = df.groupby("id").agg(sum_udf(df.v)).sort("id").toPandas()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = df.groupby("id").agg(sum_udf(df.v)).sort("id").toPandas()
+ assert_frame_equal(expected, result)
+
+ def test_grouped_agg_pandas_udf_with_compression_codec_complex(self):
+ # Test grouped agg with multiple UDFs and compression
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def mean_udf(v):
+ return v.mean()
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def sum_udf(v):
+ return v.sum()
+
+ df = self.data
+ expected = df.groupby("id").agg(mean_udf(df.v), sum_udf(df.v)).sort("id").toPandas()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = (
+ df.groupby("id").agg(mean_udf(df.v), sum_udf(df.v)).sort("id").toPandas()
+ )
+ assert_frame_equal(expected, result)
+
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index fbfe1a226b5e..554c994afc1e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1988,6 +1988,62 @@ class ScalarPandasUDFTestsMixin:
],
)
+ def test_scalar_pandas_udf_with_compression_codec(self):
+ # Test scalar Pandas UDF with different compression codec settings
+ @pandas_udf("long")
+ def plus_one(v):
+ return v + 1
+
+ df = self.spark.range(100)
+ expected = [Row(result=i + 1) for i in range(100)]
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = df.select(plus_one("id").alias("result")).collect()
+ self.assertEqual(expected, result)
+
+ def test_scalar_pandas_udf_with_compression_codec_complex_types(self):
+ # Test scalar Pandas UDF with compression for complex types (strings, arrays)
+ @pandas_udf("string")
+ def concat_string(v):
+ return v.apply(lambda x: "value_" + str(x))
+
+ @pandas_udf(ArrayType(IntegerType()))
+ def create_array(v):
+ return v.apply(lambda x: [x, x * 2, x * 3])
+
+ df = self.spark.range(50)
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ # Test string UDF
+ result = df.select(concat_string("id").alias("result")).collect()
+ expected = [Row(result=f"value_{i}") for i in range(50)]
+ self.assertEqual(expected, result)
+
+ # Test array UDF
+ result = df.select(create_array("id").alias("result")).collect()
+ expected = [Row(result=[i, i * 2, i * 3]) for i in range(50)]
+ self.assertEqual(expected, result)
+
+ def test_scalar_iter_pandas_udf_with_compression_codec(self):
+ # Test scalar iterator Pandas UDF with compression
+ @pandas_udf("long", PandasUDFType.SCALAR_ITER)
+ def plus_two(iterator):
+ for s in iterator:
+ yield s + 2
+
+ df = self.spark.range(100)
+ expected = [Row(result=i + 2) for i in range(100)]
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+ result = df.select(plus_two("id").alias("result")).collect()
+ self.assertEqual(expected, result)
+
class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
index 6c5799bd241b..c04bae07f67d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow
import java.io.DataOutputStream
import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.TaskContext
@@ -34,6 +34,7 @@ case class ArrowWriterWrapper(
var arrowWriter: SparkArrowWriter,
var root: VectorSchemaRoot,
var allocator: BufferAllocator,
+ var unloader: VectorUnloader,
context: TaskContext) {
@volatile var isClosed = false
@@ -58,6 +59,7 @@ case class ArrowWriterWrapper(
arrowWriter = null
root = null
allocator = null
+ unloader = null
}
}
}
@@ -77,8 +79,10 @@ object ArrowWriterWrapper {
s"stdout writer for $allocatorOwner", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val arrowWriter = SparkArrowWriter.create(root)
+
val streamWriter = new ArrowStreamWriter(root, null, dataOut)
streamWriter.start()
- ArrowWriterWrapper(streamWriter, arrowWriter, root, allocator, context)
+ // Unloader will be set by the caller after creation
+ ArrowWriterWrapper(streamWriter, arrowWriter, root, allocator, null, context)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8907629ad37..a367971a2fc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4001,6 +4001,19 @@ object SQLConf {
.checkValues(Set("none", "zstd", "lz4"))
.createWithDefault("none")
+ val ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL =
+ buildConf("spark.sql.execution.arrow.zstd.compressionLevel")
+ .doc("Compression level for Zstandard (zstd) codec when compressing Arrow IPC data. " +
+ "This config is only used when spark.sql.execution.arrow.compressionCodec is set to " +
+ "'zstd'. Valid values are integers from 1 (fastest, lowest compression) to 22 " +
+ "(slowest, highest compression). The default value 3 provides a good balance between " +
+ "compression speed and compression ratio.")
+ .version("4.1.0")
+ .intConf
+ .checkValue(level => level >= 1 && level <= 22,
+ "Zstd compression level must be between 1 and 22")
+ .createWithDefault(3)
+
val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch")
.doc("When using TransformWithState in PySpark (both Python Row and Pandas), limit " +
@@ -7348,6 +7361,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def arrowCompressionCodec: String = getConf(ARROW_EXECUTION_COMPRESSION_CODEC)
+ def arrowZstdCompressionLevel: Int = getConf(ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL)
+
def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int =
getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 7f260bd2efd0..8b031af14e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -31,6 +31,7 @@ import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}
+import org.apache.spark.SparkException
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
@@ -102,15 +103,16 @@ private[sql] object ArrowConverters extends Logging {
private val codec = compressionCodecName match {
case "none" => NoCompressionCodec.INSTANCE
case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
val factory = CompressionCodec.Factory.INSTANCE
- val codecType = new ZstdCompressionCodec().getCodecType()
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
factory.createCodec(codecType)
case "lz4" =>
val factory = CompressionCodec.Factory.INSTANCE
val codecType = new Lz4CompressionCodec().getCodecType()
factory.createCodec(codecType)
case other =>
- throw new IllegalArgumentException(
+ throw SparkException.internalError(
s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
}
protected val unloader = new VectorUnloader(root, true, codec, true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 00eb9039d05c..50013e533819 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
import java.util
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
+
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions,
PythonRDD, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
@@ -76,6 +80,27 @@ class CoGroupedArrowPythonRunner(
if (v > 0) v else Int.MaxValue
}
private val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
+ private val compressionCodecName: String = SQLConf.get.arrowCompressionCodec
+
+ // Helper method to create VectorUnloader with compression
+ private def createUnloader(root: VectorSchemaRoot): VectorUnloader = {
+ val codec = compressionCodecName match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ new VectorUnloader(root, true, codec, true)
+ }
protected def newWriter(
env: SparkEnv,
@@ -136,13 +161,17 @@ class CoGroupedArrowPythonRunner(
leftGroupArrowWriter = ArrowWriterWrapper.createAndStartArrowWriter(leftSchema,
timeZoneId, pythonExec + " (left)", errorOnDuplicatedFieldNames = true,
largeVarTypes, dataOut, context)
+ // Set the unloader with compression after creating the writer
+ leftGroupArrowWriter.unloader = createUnloader(leftGroupArrowWriter.root)
}
numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
leftGroupArrowWriter.arrowWriter,
leftGroupArrowWriter.streamWriter,
nextBatchInLeftGroup,
maxBytesPerBatch,
- maxRecordsPerBatch)
+ maxRecordsPerBatch,
+ leftGroupArrowWriter.unloader,
+ dataOut)
if (!nextBatchInLeftGroup.hasNext) {
leftGroupArrowWriter.streamWriter.end()
@@ -155,13 +184,17 @@ class CoGroupedArrowPythonRunner(
rightGroupArrowWriter = ArrowWriterWrapper.createAndStartArrowWriter(rightSchema,
timeZoneId, pythonExec + " (right)", errorOnDuplicatedFieldNames = true,
largeVarTypes, dataOut, context)
+ // Set the unloader with compression after creating the writer
+ rightGroupArrowWriter.unloader = createUnloader(rightGroupArrowWriter.root)
}
numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
rightGroupArrowWriter.arrowWriter,
rightGroupArrowWriter.streamWriter,
nextBatchInRightGroup,
maxBytesPerBatch,
- maxRecordsPerBatch)
+ maxRecordsPerBatch,
+ rightGroupArrowWriter.unloader,
+ dataOut)
if (!nextBatchInRightGroup.hasNext) {
rightGroupArrowWriter.streamWriter.end()
rightGroupArrowWriter.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index b2ec96c5b29f..f77b0a9342b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -17,11 +17,16 @@
package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
+import java.nio.channels.Channels
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.arrow.vector.ipc.WriteChannel
+import org.apache.arrow.vector.ipc.message.MessageSerializer
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow
@@ -70,6 +75,26 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected val allocator =
ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+ // Create compression codec based on config
+ private val compressionCodecName = SQLConf.get.arrowCompressionCodec
+ private val codec = compressionCodecName match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ protected val unloader = new VectorUnloader(root, true, codec, true)
+
protected var writer: ArrowStreamWriter = _
protected def close(): Unit = {
@@ -137,7 +162,14 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
}
arrowWriter.finish()
- writer.writeBatch()
+ // Use unloader to get compressed batch and write it manually
+ val batch = unloader.getRecordBatch()
+ try {
+ val writeChannel = new WriteChannel(Channels.newChannel(dataOut))
+ MessageSerializer.serialize(writeChannel, batch)
+ } finally {
+ batch.close()
+ }
arrowWriter.reset()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
@@ -169,7 +201,8 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
val startData = dataOut.size()
val numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
- arrowWriter, writer, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch)
+ arrowWriter, writer, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch, unloader,
+ dataOut)
assert(0 < numRowsInBatch && numRowsInBatch <= maxRecordsPerBatch, numRowsInBatch)
val deltaData = dataOut.size() - startData
@@ -209,7 +242,9 @@ private[python] object BatchedPythonArrowInput {
writer: ArrowStreamWriter,
rowIter: Iterator[InternalRow],
maxBytesPerBatch: Long,
- maxRecordsPerBatch: Int): Int = {
+ maxRecordsPerBatch: Int,
+ unloader: VectorUnloader,
+ dataOut: DataOutputStream): Int = {
var numRowsInBatch: Int = 0
def underBatchSizeLimit: Boolean =
@@ -221,7 +256,14 @@ private[python] object BatchedPythonArrowInput {
numRowsInBatch += 1
}
arrowWriter.finish()
- writer.writeBatch()
+ // Use unloader to get compressed batch and write it manually
+ val batch = unloader.getRecordBatch()
+ try {
+ val writeChannel = new WriteChannel(Channels.newChannel(dataOut))
+ MessageSerializer.serialize(writeChannel, batch)
+ } finally {
+ batch.close()
+ }
arrowWriter.reset()
numRowsInBatch
}
@@ -231,6 +273,26 @@ private[python] object BatchedPythonArrowInput {
* Enables an optimization that splits each group into the sized batches.
*/
private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner =>
+
+ // Helper method to create VectorUnloader with compression for grouped operations
+ private def createUnloaderForGroup(root: VectorSchemaRoot): VectorUnloader = {
+ val codec = SQLConf.get.arrowCompressionCodec match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ new VectorUnloader(root, true, codec, true)
+ }
protected override def newWriter(
env: SparkEnv,
worker: PythonWorker,
@@ -255,13 +317,16 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
writer = ArrowWriterWrapper.createAndStartArrowWriter(
schema, timeZoneId, pythonExec,
errorOnDuplicatedFieldNames, largeVarTypes, dataOut, context)
+ // Set the unloader with compression after creating the writer
+ writer.unloader = createUnloaderForGroup(writer.root)
nextBatchStart = inputIterator.next()
}
}
if (nextBatchStart.hasNext) {
val startData = dataOut.size()
val numRowsInBatch: Int = BatchedPythonArrowInput.writeSizedBatch(writer.arrowWriter,
- writer.streamWriter, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch)
+ writer.streamWriter, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch,
+ writer.unloader, dataOut)
if (!nextBatchStart.hasNext) {
writer.streamWriter.end()
// We don't need a try catch block here as the close() method is registered with
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]