This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8bf664071856 [SPARK-53638][SS][PYTHON] Limit the byte size of arrow batch for TWS to avoid OOM
8bf664071856 is described below

commit 8bf664071856ace125c6ce1c2eff388a93536ca7
Author: zeruibao <[email protected]>
AuthorDate: Wed Oct 8 10:14:53 2025 +0800

    [SPARK-53638][SS][PYTHON] Limit the byte size of arrow batch for TWS to avoid OOM
    
    ### What changes were proposed in this pull request?
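    Limit the byte size of each Arrow batch used by transformWithState (TWS)
    to avoid out-of-memory (OOM) errors. The limit comes from
    `spark.sql.execution.arrow.maxBytesPerBatch` and is enforced in two places.
    On the JVM side, the streaming Arrow writers (also used by
    applyInPandasWithState) finalize a batch once it reaches either the record
    limit or the byte limit. On the Python side, the Pandas serializer splits
    each group's rows into multiple DataFrames bounded by the same limits.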
    
    ### Why are the changes needed?
    On the Python worker side, when using the Pandas execution path, Arrow
    batches must be converted into Pandas DataFrames in memory. If an Arrow
    batch is too large, this conversion can lead to OOM errors in the Python
    worker. To mitigate this risk, we need to enforce a limit on the byte size
    of each Arrow batch. Similarly, the Pandas DataFrame passed to
    `handleInputRows` is processed entirely in memory. Splitting a group's rows
    into several smaller DataFrames, each bounded by the record and byte
    limits, therefore further helps prevent OOM issues.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
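    Unit tests: `test_transform_with_state_with_bytes_limit` in
    `test_pandas_transform_with_state.py` and "test maxBytesPerBatch can work"
    in `BaseStreamingArrowWriterSuite`.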
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52391 from zeruibao/zeruibao/SPARK-53638-limit-the-byte-size-of-arrow-batch.
    
    Lead-authored-by: zeruibao <[email protected]>
    Co-authored-by: Zerui Bao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py           | 32 +++++++-
 .../helper/helper_pandas_transform_with_state.py   | 43 ++++++++++
 .../pandas/test_pandas_transform_with_state.py     | 95 ++++++++++++++++++++++
 python/pyspark/worker.py                           | 12 +++
 .../ApplyInPandasWithStatePythonRunner.scala       | 10 ++-
 .../streaming/ApplyInPandasWithStateWriter.scala   |  7 +-
 .../streaming/BaseStreamingArrowWriter.scala       | 12 ++-
 .../TransformWithStateInPySparkPythonRunner.scala  | 18 +++-
 .../streaming/BaseStreamingArrowWriterSuite.scala  | 35 +++++++-
 9 files changed, 251 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 35eeb11861a6..bff7f337314b 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1672,6 +1672,7 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         safecheck,
         assign_cols_by_name,
         arrow_max_records_per_batch,
+        arrow_max_bytes_per_batch,
         int_to_decimal_coercion_enabled,
     ):
         super(TransformWithStateInPandasSerializer, self).__init__(
@@ -1682,7 +1683,11 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
             arrow_cast=True,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
+        self.average_arrow_row_size = 0
+        self.total_bytes = 0
+        self.total_rows = 0
 
     def load_stream(self, stream):
         """
@@ -1711,6 +1716,18 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
 
             def row_stream():
                 for batch in batches:
+                    # Short circuit batch size calculation if the batch size is
+                    # unlimited as computing batch size is computationally 
expensive.
+                    if self.arrow_max_bytes_per_batch != 2**31 - 1 and 
batch.num_rows > 0:
+                        batch_bytes = sum(
+                            buf.size
+                            for col in batch.columns
+                            for buf in col.buffers()
+                            if buf is not None
+                        )
+                        self.total_bytes += batch_bytes
+                        self.total_rows += batch.num_rows
+                        self.average_arrow_row_size = self.total_bytes / 
self.total_rows
                     data_pandas = [
                         self.arrow_to_pandas(c, i)
                         for i, c in 
enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1720,8 +1737,17 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
                         yield (batch_key, row)
 
             for batch_key, group_rows in groupby(row_stream(), key=lambda x: 
x[0]):
-                df = pd.DataFrame([row for _, row in group_rows])
-                yield (batch_key, df)
+                rows = []
+                for _, row in group_rows:
+                    rows.append(row)
+                    if (
+                        len(rows) >= self.arrow_max_records_per_batch
+                        or len(rows) * self.average_arrow_row_size >= 
self.arrow_max_bytes_per_batch
+                    ):
+                        yield (batch_key, pd.DataFrame(rows))
+                        rows = []
+                if rows:
+                    yield (batch_key, pd.DataFrame(rows))
 
         _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)
@@ -1766,6 +1792,7 @@ class 
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
         safecheck,
         assign_cols_by_name,
         arrow_max_records_per_batch,
+        arrow_max_bytes_per_batch,
         int_to_decimal_coercion_enabled,
     ):
         super(TransformWithStateInPandasInitStateSerializer, self).__init__(
@@ -1773,6 +1800,7 @@ class 
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
             safecheck,
             assign_cols_by_name,
             arrow_max_records_per_batch,
+            arrow_max_bytes_per_batch,
             int_to_decimal_coercion_enabled,
         )
         self.init_key_offsets = None
diff --git 
a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index a35bae88bedb..38119390940f 100644
--- 
a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ 
b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -237,6 +237,16 @@ class 
StatefulProcessorCompositeTypeFactory(StatefulProcessorFactory):
         return RowStatefulProcessorCompositeType()
 
 
+class ChunkCountProcessorFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasChunkCountProcessor()
+
+
+class ChunkCountProcessorWithInitialStateFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasChunkCountWithInitialStateProcessor()
+
+
 # StatefulProcessor implementations
 
 
@@ -1830,3 +1840,36 @@ class 
RowStatefulProcessorCompositeType(StatefulProcessor):
 
     def close(self) -> None:
         pass
+
+
+class PandasChunkCountProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        pass
+
+    def handleInputRows(self, key, rows, timerValues) -> 
Iterator[pd.DataFrame]:
+        chunk_count = 0
+        for _ in rows:
+            chunk_count += 1
+        yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})
+
+    def close(self) -> None:
+        pass
+
+
+class PandasChunkCountWithInitialStateProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("value", IntegerType(), True)])
+        self.value_state = handle.getValueState("value_state", state_schema)
+
+    def handleInputRows(self, key, rows, timerValues) -> 
Iterator[pd.DataFrame]:
+        chunk_count = 0
+        for _ in rows:
+            chunk_count += 1
+        yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})
+
+    def handleInitialState(self, key, initialState, timerValues) -> None:
+        init_val = initialState.at[0, "initVal"]
+        self.value_state.update((init_val,))
+
+    def close(self) -> None:
+        pass
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index af44093c512d..576c0cf6e6e1 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -70,6 +70,8 @@ from 
pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
     UpcastProcessorFactory,
     MinEventTimeStatefulProcessorFactory,
     StatefulProcessorCompositeTypeFactory,
+    ChunkCountProcessorFactory,
+    ChunkCountProcessorWithInitialStateFactory,
 )
 
 
@@ -1864,6 +1866,99 @@ class TransformWithStateTestsMixin:
                     .collect()
                 )
 
+    def test_transform_with_state_with_bytes_limit(self):
+        if not self.use_pandas():
+            return
+
+        def make_check_results(expected_per_batch):
+            def check_results(batch_df, batch_id):
+                batch_df.collect()
+                if batch_id == 0:
+                    assert set(batch_df.sort("id").collect()) == 
expected_per_batch[0]
+                else:
+                    assert set(batch_df.sort("id").collect()) == 
expected_per_batch[1]
+
+            return check_results
+
+        result_with_small_limit = [
+            {
+                Row(id="0", chunkCount=2),
+                Row(id="1", chunkCount=2),
+            },
+            {
+                Row(id="0", chunkCount=3),
+                Row(id="1", chunkCount=2),
+            },
+        ]
+
+        result_with_large_limit = [
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+        ]
+
+        data = [("0", 789), ("3", 987)]
+        initial_state = self.spark.createDataFrame(data, "id string, initVal 
int").groupBy("id")
+
+        with self.sql_conf(
+            # Set it to a very small number so that every row would be a 
separate pandas df
+            {"spark.sql.execution.arrow.maxBytesPerBatch": "2"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_small_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_small_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+        with self.sql_conf(
+            # Set it to a very large number so that every row would be in the 
same pandas df
+            {"spark.sql.execution.arrow.maxBytesPerBatch": "100000"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_large_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_large_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
 
 @unittest.skipIf(
     not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index a15d59f04e1e..e06034bada9f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2646,11 +2646,17 @@ def read_udfs(pickleSer, infile, eval_type):
             )
             arrow_max_records_per_batch = int(arrow_max_records_per_batch)
 
+            arrow_max_bytes_per_batch = runner_conf.get(
+                "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
+            )
+            arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)
+
             ser = TransformWithStateInPandasSerializer(
                 timezone,
                 safecheck,
                 _assign_cols_by_name,
                 arrow_max_records_per_batch,
+                arrow_max_bytes_per_batch,
                 
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             )
         elif eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
@@ -2659,11 +2665,17 @@ def read_udfs(pickleSer, infile, eval_type):
             )
             arrow_max_records_per_batch = int(arrow_max_records_per_batch)
 
+            arrow_max_bytes_per_batch = runner_conf.get(
+                "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
+            )
+            arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)
+
             ser = TransformWithStateInPandasInitStateSerializer(
                 timezone,
                 safecheck,
                 _assign_cols_by_name,
                 arrow_max_records_per_batch,
+                arrow_max_bytes_per_batch,
                 
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             )
         elif eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index b6f6a4cbc30b..51d9f6f523a2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -106,12 +106,14 @@ class ApplyInPandasWithStatePythonRunner(
   }
 
   private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
 
   // applyInPandasWithState has its own mechanism to construct the Arrow 
RecordBatch instance.
   // Configurations are both applied to executor and Python worker, set them 
to the worker conf
   // to let Python worker read the config properly.
   override protected val workerConf: Map[String, String] = initialWorkerConf +
-    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
arrowMaxRecordsPerBatch.toString)
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
arrowMaxRecordsPerBatch.toString) +
+    (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> 
arrowMaxBytesPerBatch.toString)
 
   private val stateRowDeserializer = stateEncoder.createDeserializer()
 
@@ -142,7 +144,11 @@ class ApplyInPandasWithStatePythonRunner(
       dataOut: DataOutputStream,
       inputIterator: Iterator[InType]): Boolean = {
     if (pandasWriter == null) {
-      pandasWriter = new ApplyInPandasWithStateWriter(root, writer, 
arrowMaxRecordsPerBatch)
+      pandasWriter = new ApplyInPandasWithStateWriter(
+        root,
+        writer,
+        arrowMaxRecordsPerBatch,
+        arrowMaxBytesPerBatch)
     }
     if (inputIterator.hasNext) {
       val startData = dataOut.size()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
index f55ca749112f..cd83270bb4c4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
@@ -50,8 +50,9 @@ import org.apache.spark.unsafe.types.UTF8String
 class ApplyInPandasWithStateWriter(
     root: VectorSchemaRoot,
     writer: ArrowStreamWriter,
-    arrowMaxRecordsPerBatch: Int)
-  extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) {
+    arrowMaxRecordsPerBatch: Int,
+    arrowMaxBytesPerBatch: Long)
+  extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, 
arrowMaxBytesPerBatch) {
 
   import ApplyInPandasWithStateWriter._
 
@@ -144,7 +145,7 @@ class ApplyInPandasWithStateWriter(
 
     // If it exceeds the condition of batch (number of records) once the all 
data is received for
     // same group, finalize and construct a new batch.
-    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+    if (isBatchSizeLimitReached) {
       finalizeCurrentArrowBatch()
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
index ba8b2c3ac7da..f0371cafb72a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
@@ -32,6 +32,7 @@ class BaseStreamingArrowWriter(
     root: VectorSchemaRoot,
     writer: ArrowStreamWriter,
     arrowMaxRecordsPerBatch: Int,
+    arrowMaxBytesPerBatch: Long,
     arrowWriterForTest: ArrowWriter = null) {
   protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == 
null) {
     ArrowWriter.create(root)
@@ -54,7 +55,7 @@ class BaseStreamingArrowWriter(
     // If it exceeds the condition of batch (number of records) and there is 
more data for the
     // same group, finalize and construct a new batch.
 
-    val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch
+    val isCurrentBatchFull = isBatchSizeLimitReached
     if (isCurrentBatchFull) {
       finalizeCurrentChunk(isLastChunkForGroup = false)
       finalizeCurrentArrowBatch()
@@ -84,4 +85,13 @@ class BaseStreamingArrowWriter(
   protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
     numRowsForCurrentChunk = 0
   }
+
+  protected def isBatchSizeLimitReached: Boolean = {
+    // If we have either reached the records or bytes limit
+    totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
+      // Short circuit batch size calculation if the batch size is unlimited 
as computing batch
+      // size is computationally expensive.
+      ((arrowMaxBytesPerBatch != Int.MaxValue)
+        && (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index f0df3e1f7d15..42d4ad68c29a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -75,7 +75,12 @@ class TransformWithStateInPySparkPythonRunner(
       dataOut: DataOutputStream,
       inputIterator: Iterator[InType]): Boolean = {
     if (pandasWriter == null) {
-      pandasWriter = new BaseStreamingArrowWriter(root, writer, 
arrowMaxRecordsPerBatch)
+      pandasWriter = new BaseStreamingArrowWriter(
+        root,
+        writer,
+        arrowMaxRecordsPerBatch,
+        arrowMaxBytesPerBatch
+      )
     }
 
     // If we don't have data left for the current group, move to the next 
group.
@@ -145,7 +150,12 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
       dataOut: DataOutputStream,
       inputIterator: Iterator[GroupedInType]): Boolean = {
     if (pandasWriter == null) {
-      pandasWriter = new BaseStreamingArrowWriter(root, writer, 
arrowMaxRecordsPerBatch)
+      pandasWriter = new BaseStreamingArrowWriter(
+        root,
+        writer,
+        arrowMaxRecordsPerBatch,
+        arrowMaxBytesPerBatch
+      )
     }
 
     if (inputIterator.hasNext) {
@@ -200,9 +210,11 @@ abstract class 
TransformWithStateInPySparkPythonBaseRunner[I](
 
   protected val sqlConf = SQLConf.get
   protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
 
   override protected val workerConf: Map[String, String] = initialWorkerConf +
-    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
arrowMaxRecordsPerBatch.toString)
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
arrowMaxRecordsPerBatch.toString) +
+    (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> 
arrowMaxBytesPerBatch.toString)
 
   // Use lazy val to initialize the fields before these are accessed in 
[[PythonArrowInput]]'s
   // constructor.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
index f0fee2b9b0d9..fc10a102b4f5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.python.streaming
 
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
-import org.mockito.Mockito.{mock, never, times, verify}
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, never, times, verify, when}
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 class BaseStreamingArrowWriterSuite extends SparkFunSuite with 
BeforeAndAfterEach {
   // Setting the maximum number of records per batch to 2 to make test easier.
   val arrowMaxRecordsPerBatch = 2
+  val arrowMaxBytesPerBatch = Int.MaxValue
   var transformWithStateInPySparkWriter: BaseStreamingArrowWriter = _
   var arrowWriter: ArrowWriter = _
   var writer: ArrowStreamWriter = _
@@ -37,7 +39,7 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite 
with BeforeAndAfterEac
     writer = mock(classOf[ArrowStreamWriter])
     arrowWriter = mock(classOf[ArrowWriter])
     transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
-      root, writer, arrowMaxRecordsPerBatch, arrowWriter)
+      root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch, 
arrowWriter)
   }
 
   test("test writeRow") {
@@ -64,4 +66,33 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite 
with BeforeAndAfterEac
     verify(writer).writeBatch()
     verify(arrowWriter).reset()
   }
+
+  test("test maxBytesPerBatch can work") {
+    val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+
+    var sizeCounter = 0
+    when(arrowWriter.write(any[InternalRow])).thenAnswer { _ =>
+      sizeCounter += 1
+      ()
+    }
+
+    when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter }
+
+    // Set arrowMaxBytesPerBatch to 1
+    transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+      root, writer, arrowMaxRecordsPerBatch, 1, arrowWriter)
+    val dataRow = mock(classOf[InternalRow])
+    transformWithStateInPySparkWriter.writeRow(dataRow)
+    verify(arrowWriter).write(dataRow)
+    verify(writer, never()).writeBatch()
+    transformWithStateInPySparkWriter.writeRow(dataRow)
+    verify(arrowWriter, times(2)).write(dataRow)
+    // Write batch is called since we reach arrowMaxBytesPerBatch
+    verify(writer).writeBatch()
+    transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+    verify(arrowWriter, times(2)).finish()
+    // The second record would be written
+    verify(writer, times(2)).writeBatch()
+    verify(arrowWriter, times(2)).reset()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to