This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new def03c2b9583 [SPARK-54226][SQL] Extend Arrow compression to Pandas UDF
def03c2b9583 is described below

commit def03c2b9583c71d483d305b5bbba9ade39ed29a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Nov 7 06:46:16 2025 -0800

    [SPARK-54226][SQL] Extend Arrow compression to Pandas UDF
    
    ### What changes were proposed in this pull request?
    
    This is an extension to https://github.com/apache/spark/pull/52747, which
added support for Arrow compression to `toArrow` and `toPandas` to reduce memory
usage. This PR extends the same memory optimization to Pandas UDFs, and adds the
`spark.sql.execution.arrow.zstd.compressionLevel` config to control the zstd
compression level.
    
    ### Why are the changes needed?
    
    To optimize memory usage for the Pandas UDF case.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code v2.0.14
    
    Closes #52925 from viirya/arrow_compress_udf.
    
    Authored-by: Liang-Chi Hsieh <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit 96ed48db3879fc9e2d250c4548eec4409061d2de)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 60 ++++++++++++++++
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 36 ++++++++++
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 56 +++++++++++++++
 .../sql/execution/arrow/ArrowWriterWrapper.scala   |  8 ++-
 .../org/apache/spark/sql/internal/SQLConf.scala    | 15 ++++
 .../sql/execution/arrow/ArrowConverters.scala      |  6 +-
 .../python/CoGroupedArrowPythonRunner.scala        | 39 ++++++++++-
 .../sql/execution/python/PythonArrowInput.scala    | 79 ++++++++++++++++++++--
 8 files changed, 285 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index ef84673179dc..b60c5a187fbf 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -1406,6 +1406,66 @@ class ApplyInPandasTestsMixin:
             actual = grouped_df.applyInPandas(func, "value long").collect()
             self.assertEqual(actual, expected)
 
+    def test_grouped_map_pandas_udf_with_compression_codec(self):
+        # Test grouped map Pandas UDF with different compression codec settings
+        @pandas_udf("id long, v int, v1 double", PandasUDFType.GROUPED_MAP)
+        def foo(pdf):
+            return pdf.assign(v1=pdf.v * pdf.id * 1.0)
+
+        df = self.data
+        pdf = df.toPandas()
+        expected = pdf.groupby("id", 
as_index=False).apply(foo.func).reset_index(drop=True)
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = df.groupby("id").apply(foo).sort("id").toPandas()
+                    assert_frame_equal(expected, result)
+
+    def test_apply_in_pandas_with_compression_codec(self):
+        # Test applyInPandas with different compression codec settings
+        def stats(key, pdf):
+            return pd.DataFrame([(key[0], pdf.v.mean())], columns=["id", 
"mean"])
+
+        df = self.data
+        expected = df.select("id").distinct().withColumn("mean", 
sf.lit(24.5)).toPandas()
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = (
+                        df.groupby("id")
+                        .applyInPandas(stats, schema="id long, mean double")
+                        .sort("id")
+                        .toPandas()
+                    )
+                    assert_frame_equal(expected, result)
+
+    def test_apply_in_pandas_iterator_with_compression_codec(self):
+        # Test applyInPandas with iterator and compression
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(batches: Iterator[pd.DataFrame]) -> 
Iterator[pd.DataFrame]:
+            total = 0
+            for batch in batches:
+                total += batch["v"].sum()
+            yield pd.DataFrame({"v": [total]})
+
+        expected = [Row(v=3.0), Row(v=18.0)]
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = (
+                        df.groupby("id")
+                        .applyInPandas(sum_func, schema="v double")
+                        .orderBy("v")
+                        .collect()
+                    )
+                    self.assertEqual(result, expected)
+
 
 class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 2b3e42312df9..2958d0e67f1e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -861,6 +861,42 @@ class GroupedAggPandasUDFTestsMixin:
             ],
         )
 
+    def test_grouped_agg_pandas_udf_with_compression_codec(self):
+        # Test grouped agg Pandas UDF with different compression codec settings
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def sum_udf(v):
+            return v.sum()
+
+        df = self.data
+        expected = df.groupby("id").agg(sum_udf(df.v)).sort("id").toPandas()
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = 
df.groupby("id").agg(sum_udf(df.v)).sort("id").toPandas()
+                    assert_frame_equal(expected, result)
+
+    def test_grouped_agg_pandas_udf_with_compression_codec_complex(self):
+        # Test grouped agg with multiple UDFs and compression
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def mean_udf(v):
+            return v.mean()
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def sum_udf(v):
+            return v.sum()
+
+        df = self.data
+        expected = df.groupby("id").agg(mean_udf(df.v), 
sum_udf(df.v)).sort("id").toPandas()
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = (
+                        df.groupby("id").agg(mean_udf(df.v), 
sum_udf(df.v)).sort("id").toPandas()
+                    )
+                    assert_frame_equal(expected, result)
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index fbfe1a226b5e..554c994afc1e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1988,6 +1988,62 @@ class ScalarPandasUDFTestsMixin:
             ],
         )
 
+    def test_scalar_pandas_udf_with_compression_codec(self):
+        # Test scalar Pandas UDF with different compression codec settings
+        @pandas_udf("long")
+        def plus_one(v):
+            return v + 1
+
+        df = self.spark.range(100)
+        expected = [Row(result=i + 1) for i in range(100)]
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = 
df.select(plus_one("id").alias("result")).collect()
+                    self.assertEqual(expected, result)
+
+    def test_scalar_pandas_udf_with_compression_codec_complex_types(self):
+        # Test scalar Pandas UDF with compression for complex types (strings, 
arrays)
+        @pandas_udf("string")
+        def concat_string(v):
+            return v.apply(lambda x: "value_" + str(x))
+
+        @pandas_udf(ArrayType(IntegerType()))
+        def create_array(v):
+            return v.apply(lambda x: [x, x * 2, x * 3])
+
+        df = self.spark.range(50)
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    # Test string UDF
+                    result = 
df.select(concat_string("id").alias("result")).collect()
+                    expected = [Row(result=f"value_{i}") for i in range(50)]
+                    self.assertEqual(expected, result)
+
+                    # Test array UDF
+                    result = 
df.select(create_array("id").alias("result")).collect()
+                    expected = [Row(result=[i, i * 2, i * 3]) for i in 
range(50)]
+                    self.assertEqual(expected, result)
+
+    def test_scalar_iter_pandas_udf_with_compression_codec(self):
+        # Test scalar iterator Pandas UDF with compression
+        @pandas_udf("long", PandasUDFType.SCALAR_ITER)
+        def plus_two(iterator):
+            for s in iterator:
+                yield s + 2
+
+        df = self.spark.range(100)
+        expected = [Row(result=i + 2) for i in range(100)]
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with 
self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    result = 
df.select(plus_two("id").alias("result")).collect()
+                    self.assertEqual(expected, result)
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
index 6c5799bd241b..c04bae07f67d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow
 import java.io.DataOutputStream
 
 import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
 import org.apache.spark.TaskContext
@@ -34,6 +34,7 @@ case class ArrowWriterWrapper(
     var arrowWriter: SparkArrowWriter,
     var root: VectorSchemaRoot,
     var allocator: BufferAllocator,
+    var unloader: VectorUnloader,
     context: TaskContext) {
   @volatile var isClosed = false
 
@@ -58,6 +59,7 @@ case class ArrowWriterWrapper(
       arrowWriter = null
       root = null
       allocator = null
+      unloader = null
     }
   }
 }
@@ -77,8 +79,10 @@ object ArrowWriterWrapper {
       s"stdout writer for $allocatorOwner", 0, Long.MaxValue)
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     val arrowWriter = SparkArrowWriter.create(root)
+
     val streamWriter = new ArrowStreamWriter(root, null, dataOut)
     streamWriter.start()
-    ArrowWriterWrapper(streamWriter, arrowWriter, root, allocator, context)
+    // Unloader will be set by the caller after creation
+    ArrowWriterWrapper(streamWriter, arrowWriter, root, allocator, null, 
context)
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8907629ad37..a367971a2fc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4001,6 +4001,19 @@ object SQLConf {
       .checkValues(Set("none", "zstd", "lz4"))
       .createWithDefault("none")
 
+  val ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL =
+    buildConf("spark.sql.execution.arrow.zstd.compressionLevel")
+      .doc("Compression level for Zstandard (zstd) codec when compressing 
Arrow IPC data. " +
+        "This config is only used when 
spark.sql.execution.arrow.compressionCodec is set to " +
+        "'zstd'. Valid values are integers from 1 (fastest, lowest 
compression) to 22 " +
+        "(slowest, highest compression). The default value 3 provides a good 
balance between " +
+        "compression speed and compression ratio.")
+      .version("4.1.0")
+      .intConf
+      .checkValue(level => level >= 1 && level <= 22,
+        "Zstd compression level must be between 1 and 22")
+      .createWithDefault(3)
+
   val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH =
     
buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch")
       .doc("When using TransformWithState in PySpark (both Python Row and 
Pandas), limit " +
@@ -7348,6 +7361,8 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def arrowCompressionCodec: String = 
getConf(ARROW_EXECUTION_COMPRESSION_CODEC)
 
+  def arrowZstdCompressionLevel: Int = 
getConf(ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL)
+
   def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int =
     getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 7f260bd2efd0..8b031af14e8b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -31,6 +31,7 @@ import org.apache.arrow.vector.compression.{CompressionCodec, 
NoCompressionCodec
 import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, 
ReadChannel, WriteChannel}
 import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, 
MessageSerializer}
 
+import org.apache.spark.SparkException
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
@@ -102,15 +103,16 @@ private[sql] object ArrowConverters extends Logging {
     private val codec = compressionCodecName match {
       case "none" => NoCompressionCodec.INSTANCE
       case "zstd" =>
+        val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
         val factory = CompressionCodec.Factory.INSTANCE
-        val codecType = new ZstdCompressionCodec().getCodecType()
+        val codecType = new 
ZstdCompressionCodec(compressionLevel).getCodecType()
         factory.createCodec(codecType)
       case "lz4" =>
         val factory = CompressionCodec.Factory.INSTANCE
         val codecType = new Lz4CompressionCodec().getCodecType()
         factory.createCodec(codecType)
       case other =>
-        throw new IllegalArgumentException(
+        throw SparkException.internalError(
           s"Unsupported Arrow compression codec: $other. Supported values: 
none, zstd, lz4")
     }
     protected val unloader = new VectorUnloader(root, true, codec, true)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 00eb9039d05c..50013e533819 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.execution.python
 import java.io.DataOutputStream
 import java.util
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, 
NoCompressionCodec}
+
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
@@ -76,6 +80,27 @@ class CoGroupedArrowPythonRunner(
     if (v > 0) v else Int.MaxValue
   }
   private val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
+  private val compressionCodecName: String = SQLConf.get.arrowCompressionCodec
+
+  // Helper method to create VectorUnloader with compression
+  private def createUnloader(root: VectorSchemaRoot): VectorUnloader = {
+    val codec = compressionCodecName match {
+      case "none" => NoCompressionCodec.INSTANCE
+      case "zstd" =>
+        val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new 
ZstdCompressionCodec(compressionLevel).getCodecType()
+        factory.createCodec(codecType)
+      case "lz4" =>
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new Lz4CompressionCodec().getCodecType()
+        factory.createCodec(codecType)
+      case other =>
+        throw SparkException.internalError(
+          s"Unsupported Arrow compression codec: $other. Supported values: 
none, zstd, lz4")
+    }
+    new VectorUnloader(root, true, codec, true)
+  }
 
   protected def newWriter(
       env: SparkEnv,
@@ -136,13 +161,17 @@ class CoGroupedArrowPythonRunner(
             leftGroupArrowWriter = 
ArrowWriterWrapper.createAndStartArrowWriter(leftSchema,
               timeZoneId, pythonExec + " (left)", errorOnDuplicatedFieldNames 
= true,
               largeVarTypes, dataOut, context)
+            // Set the unloader with compression after creating the writer
+            leftGroupArrowWriter.unloader = 
createUnloader(leftGroupArrowWriter.root)
           }
           numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
             leftGroupArrowWriter.arrowWriter,
             leftGroupArrowWriter.streamWriter,
             nextBatchInLeftGroup,
             maxBytesPerBatch,
-            maxRecordsPerBatch)
+            maxRecordsPerBatch,
+            leftGroupArrowWriter.unloader,
+            dataOut)
 
           if (!nextBatchInLeftGroup.hasNext) {
             leftGroupArrowWriter.streamWriter.end()
@@ -155,13 +184,17 @@ class CoGroupedArrowPythonRunner(
             rightGroupArrowWriter = 
ArrowWriterWrapper.createAndStartArrowWriter(rightSchema,
               timeZoneId, pythonExec + " (right)", errorOnDuplicatedFieldNames 
= true,
               largeVarTypes, dataOut, context)
+            // Set the unloader with compression after creating the writer
+            rightGroupArrowWriter.unloader = 
createUnloader(rightGroupArrowWriter.root)
           }
           numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
             rightGroupArrowWriter.arrowWriter,
             rightGroupArrowWriter.streamWriter,
             nextBatchInRightGroup,
             maxBytesPerBatch,
-            maxRecordsPerBatch)
+            maxRecordsPerBatch,
+            rightGroupArrowWriter.unloader,
+            dataOut)
           if (!nextBatchInRightGroup.hasNext) {
             rightGroupArrowWriter.streamWriter.end()
             rightGroupArrowWriter.close()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index b2ec96c5b29f..f77b0a9342b0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -17,11 +17,16 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
+import java.nio.channels.Channels
 
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, 
NoCompressionCodec}
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.arrow.vector.ipc.WriteChannel
+import org.apache.arrow.vector.ipc.message.MessageSerializer
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow
@@ -70,6 +75,26 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
   protected val allocator =
     ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for 
$pythonExec", 0, Long.MaxValue)
   protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+  // Create compression codec based on config
+  private val compressionCodecName = SQLConf.get.arrowCompressionCodec
+  private val codec = compressionCodecName match {
+    case "none" => NoCompressionCodec.INSTANCE
+    case "zstd" =>
+      val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+      val factory = CompressionCodec.Factory.INSTANCE
+      val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+      factory.createCodec(codecType)
+    case "lz4" =>
+      val factory = CompressionCodec.Factory.INSTANCE
+      val codecType = new Lz4CompressionCodec().getCodecType()
+      factory.createCodec(codecType)
+    case other =>
+      throw SparkException.internalError(
+        s"Unsupported Arrow compression codec: $other. Supported values: none, 
zstd, lz4")
+  }
+  protected val unloader = new VectorUnloader(root, true, codec, true)
+
   protected var writer: ArrowStreamWriter = _
 
   protected def close(): Unit = {
@@ -137,7 +162,14 @@ private[python] trait BasicPythonArrowInput extends 
PythonArrowInput[Iterator[In
       }
 
       arrowWriter.finish()
-      writer.writeBatch()
+      // Use unloader to get compressed batch and write it manually
+      val batch = unloader.getRecordBatch()
+      try {
+        val writeChannel = new WriteChannel(Channels.newChannel(dataOut))
+        MessageSerializer.serialize(writeChannel, batch)
+      } finally {
+        batch.close()
+      }
       arrowWriter.reset()
       val deltaData = dataOut.size() - startData
       pythonMetrics("pythonDataSent") += deltaData
@@ -169,7 +201,8 @@ private[python] trait BatchedPythonArrowInput extends 
BasicPythonArrowInput {
       val startData = dataOut.size()
 
       val numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
-        arrowWriter, writer, nextBatchStart, maxBytesPerBatch, 
maxRecordsPerBatch)
+        arrowWriter, writer, nextBatchStart, maxBytesPerBatch, 
maxRecordsPerBatch, unloader,
+        dataOut)
       assert(0 < numRowsInBatch && numRowsInBatch <= maxRecordsPerBatch, 
numRowsInBatch)
 
       val deltaData = dataOut.size() - startData
@@ -209,7 +242,9 @@ private[python] object BatchedPythonArrowInput {
       writer: ArrowStreamWriter,
       rowIter: Iterator[InternalRow],
       maxBytesPerBatch: Long,
-      maxRecordsPerBatch: Int): Int = {
+      maxRecordsPerBatch: Int,
+      unloader: VectorUnloader,
+      dataOut: DataOutputStream): Int = {
     var numRowsInBatch: Int = 0
 
     def underBatchSizeLimit: Boolean =
@@ -221,7 +256,14 @@ private[python] object BatchedPythonArrowInput {
       numRowsInBatch += 1
     }
     arrowWriter.finish()
-    writer.writeBatch()
+    // Use unloader to get compressed batch and write it manually
+    val batch = unloader.getRecordBatch()
+    try {
+      val writeChannel = new WriteChannel(Channels.newChannel(dataOut))
+      MessageSerializer.serialize(writeChannel, batch)
+    } finally {
+      batch.close()
+    }
     arrowWriter.reset()
     numRowsInBatch
   }
@@ -231,6 +273,26 @@ private[python] object BatchedPythonArrowInput {
  * Enables an optimization that splits each group into the sized batches.
  */
 private[python] trait GroupedPythonArrowInput { self: 
RowInputArrowPythonRunner =>
+
+  // Helper method to create VectorUnloader with compression for grouped 
operations
+  private def createUnloaderForGroup(root: VectorSchemaRoot): VectorUnloader = 
{
+    val codec = SQLConf.get.arrowCompressionCodec match {
+      case "none" => NoCompressionCodec.INSTANCE
+      case "zstd" =>
+        val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new 
ZstdCompressionCodec(compressionLevel).getCodecType()
+        factory.createCodec(codecType)
+      case "lz4" =>
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new Lz4CompressionCodec().getCodecType()
+        factory.createCodec(codecType)
+      case other =>
+        throw SparkException.internalError(
+          s"Unsupported Arrow compression codec: $other. Supported values: 
none, zstd, lz4")
+    }
+    new VectorUnloader(root, true, codec, true)
+  }
   protected override def newWriter(
       env: SparkEnv,
       worker: PythonWorker,
@@ -255,13 +317,16 @@ private[python] trait GroupedPythonArrowInput { self: 
RowInputArrowPythonRunner
             writer = ArrowWriterWrapper.createAndStartArrowWriter(
               schema, timeZoneId, pythonExec,
               errorOnDuplicatedFieldNames, largeVarTypes, dataOut, context)
+            // Set the unloader with compression after creating the writer
+            writer.unloader = createUnloaderForGroup(writer.root)
             nextBatchStart = inputIterator.next()
           }
         }
         if (nextBatchStart.hasNext) {
           val startData = dataOut.size()
           val numRowsInBatch: Int = 
BatchedPythonArrowInput.writeSizedBatch(writer.arrowWriter,
-            writer.streamWriter, nextBatchStart, maxBytesPerBatch, 
maxRecordsPerBatch)
+            writer.streamWriter, nextBatchStart, maxBytesPerBatch, 
maxRecordsPerBatch,
+            writer.unloader, dataOut)
           if (!nextBatchStart.hasNext) {
             writer.streamWriter.end()
             // We don't need a try catch block here as the close() method is 
registered with


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to