This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8bf664071856 [SPARK-53638][SS][PYTHON] Limit the byte size of arrow batch for TWS to avoid OOM
8bf664071856 is described below
commit 8bf664071856ace125c6ce1c2eff388a93536ca7
Author: zeruibao <[email protected]>
AuthorDate: Wed Oct 8 10:14:53 2025 +0800
[SPARK-53638][SS][PYTHON] Limit the byte size of arrow batch for TWS to avoid OOM
### What changes were proposed in this pull request?
Limit the byte size of Arrow batches sent to the Python worker for `transformWithState` (TWS) to avoid OOM.
### Why are the changes needed?
On the Python worker side, the Pandas execution path converts each Arrow batch into a Pandas DataFrame in memory. If an Arrow batch is too large, this conversion can cause the Python worker to run out of memory. To mitigate this risk, we enforce a limit on the byte size of each Arrow batch via `spark.sql.execution.arrow.maxBytesPerBatch`. Similarly, the Pandas DataFrame handed to `handleInputRows` is also processed entirely in memory, so applying the same size limit when building each DataFrame further helps prevent OOM.
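The Python-side chunking can be sketched as follows (illustrative only; `chunk_rows`, its parameters, and the average-row-size estimate are simplified stand-ins for the logic added to `TransformWithStateInPandasSerializer` below):

```python
# Sketch of the chunking rule, assuming an estimated average Arrow row size
# (avg_row_size) is tracked while reading Arrow batches. Illustrative only;
# see the serializers.py change below for the real implementation.
import pandas as pd


def chunk_rows(rows, max_records, max_bytes, avg_row_size):
    """Yield pandas DataFrames for a group, cutting a new DataFrame once the
    record-count limit or the estimated byte-size limit is reached."""
    buffered = []
    for row in rows:
        buffered.append(row)
        if (
            len(buffered) >= max_records
            or len(buffered) * avg_row_size >= max_bytes
        ):
            yield pd.DataFrame(buffered)
            buffered = []
    if buffered:  # flush any remaining rows as a final, smaller DataFrame
        yield pd.DataFrame(buffered)
```

When `spark.sql.execution.arrow.maxBytesPerBatch` is left at its default (`Int.MaxValue`), the byte-size check is skipped, so the existing record-count-only behavior is unchanged.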
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests added in `test_pandas_transform_with_state.py` and `BaseStreamingArrowWriterSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52391 from zeruibao/zeruibao/SPARK-53638-limit-the-byte-size-of-arrow-batch.
Lead-authored-by: zeruibao <[email protected]>
Co-authored-by: Zerui Bao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 32 +++++++-
.../helper/helper_pandas_transform_with_state.py | 43 ++++++++++
.../pandas/test_pandas_transform_with_state.py | 95 ++++++++++++++++++++++
python/pyspark/worker.py | 12 +++
.../ApplyInPandasWithStatePythonRunner.scala | 10 ++-
.../streaming/ApplyInPandasWithStateWriter.scala | 7 +-
.../streaming/BaseStreamingArrowWriter.scala | 12 ++-
.../TransformWithStateInPySparkPythonRunner.scala | 18 +++-
.../streaming/BaseStreamingArrowWriterSuite.scala | 35 +++++++-
9 files changed, 251 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 35eeb11861a6..bff7f337314b 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1672,6 +1672,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
+ arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
):
super(TransformWithStateInPandasSerializer, self).__init__(
@@ -1682,7 +1683,11 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
arrow_cast=True,
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
+ self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
self.key_offsets = None
+ self.average_arrow_row_size = 0
+ self.total_bytes = 0
+ self.total_rows = 0
def load_stream(self, stream):
"""
@@ -1711,6 +1716,18 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
def row_stream():
for batch in batches:
+ # Short circuit batch size calculation if the batch size is
+ # unlimited as computing batch size is computationally expensive.
+ if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+ batch_bytes = sum(
+ buf.size
+ for col in batch.columns
+ for buf in col.buffers()
+ if buf is not None
+ )
+ self.total_bytes += batch_bytes
+ self.total_rows += batch.num_rows
+ self.average_arrow_row_size = self.total_bytes / self.total_rows
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1720,8 +1737,17 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
yield (batch_key, row)
for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
- df = pd.DataFrame([row for _, row in group_rows])
- yield (batch_key, df)
+ rows = []
+ for _, row in group_rows:
+ rows.append(row)
+ if (
+ len(rows) >= self.arrow_max_records_per_batch
or len(rows) * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+ ):
+ yield (batch_key, pd.DataFrame(rows))
+ rows = []
+ if rows:
+ yield (batch_key, pd.DataFrame(rows))
_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
@@ -1766,6 +1792,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
+ arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
):
super(TransformWithStateInPandasInitStateSerializer, self).__init__(
@@ -1773,6 +1800,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
+ arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
)
self.init_key_offsets = None
diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index a35bae88bedb..38119390940f 100644
--- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -237,6 +237,16 @@ class StatefulProcessorCompositeTypeFactory(StatefulProcessorFactory):
return RowStatefulProcessorCompositeType()
+class ChunkCountProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasChunkCountProcessor()
+
+
+class ChunkCountProcessorWithInitialStateFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasChunkCountWithInitialStateProcessor()
+
+
# StatefulProcessor implementations
@@ -1830,3 +1840,36 @@ class RowStatefulProcessorCompositeType(StatefulProcessor):
def close(self) -> None:
pass
+
+
+class PandasChunkCountProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ pass
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ chunk_count = 0
+ for _ in rows:
+ chunk_count += 1
+ yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})
+
+ def close(self) -> None:
+ pass
+
+
+class PandasChunkCountWithInitialStateProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.value_state = handle.getValueState("value_state", state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ chunk_count = 0
+ for _ in rows:
+ chunk_count += 1
+ yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ init_val = initialState.at[0, "initVal"]
+ self.value_state.update((init_val,))
+
+ def close(self) -> None:
+ pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index af44093c512d..576c0cf6e6e1 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -70,6 +70,8 @@ from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
UpcastProcessorFactory,
MinEventTimeStatefulProcessorFactory,
StatefulProcessorCompositeTypeFactory,
+ ChunkCountProcessorFactory,
+ ChunkCountProcessorWithInitialStateFactory,
)
@@ -1864,6 +1866,99 @@ class TransformWithStateTestsMixin:
.collect()
)
+ def test_transform_with_state_with_bytes_limit(self):
+ if not self.use_pandas():
+ return
+
+ def make_check_results(expected_per_batch):
+ def check_results(batch_df, batch_id):
+ batch_df.collect()
+ if batch_id == 0:
assert set(batch_df.sort("id").collect()) == expected_per_batch[0]
+ else:
assert set(batch_df.sort("id").collect()) == expected_per_batch[1]
+
+ return check_results
+
+ result_with_small_limit = [
+ {
+ Row(id="0", chunkCount=2),
+ Row(id="1", chunkCount=2),
+ },
+ {
+ Row(id="0", chunkCount=3),
+ Row(id="1", chunkCount=2),
+ },
+ ]
+
+ result_with_large_limit = [
+ {
+ Row(id="0", chunkCount=1),
+ Row(id="1", chunkCount=1),
+ },
+ {
+ Row(id="0", chunkCount=1),
+ Row(id="1", chunkCount=1),
+ },
+ ]
+
+ data = [("0", 789), ("3", 987)]
+ initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
+
+ with self.sql_conf(
# Set it to a very small number so that every row would be a separate pandas df
+ {"spark.sql.execution.arrow.maxBytesPerBatch": "2"}
+ ):
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorFactory(),
+ make_check_results(result_with_small_limit),
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorWithInitialStateFactory(),
+ make_check_results(result_with_small_limit),
+ initial_state=initial_state,
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
+ with self.sql_conf(
# Set it to a very large number so that every row would be in the same pandas df
+ {"spark.sql.execution.arrow.maxBytesPerBatch": "100000"}
+ ):
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorFactory(),
+ make_check_results(result_with_large_limit),
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorWithInitialStateFactory(),
+ make_check_results(result_with_large_limit),
+ initial_state=initial_state,
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
@unittest.skipIf(
not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index a15d59f04e1e..e06034bada9f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2646,11 +2646,17 @@ def read_udfs(pickleSer, infile, eval_type):
)
arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+ arrow_max_bytes_per_batch = runner_conf.get(
+ "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
+ )
+ arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)
+
ser = TransformWithStateInPandasSerializer(
timezone,
safecheck,
_assign_cols_by_name,
arrow_max_records_per_batch,
+ arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
@@ -2659,11 +2665,17 @@ def read_udfs(pickleSer, infile, eval_type):
)
arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+ arrow_max_bytes_per_batch = runner_conf.get(
+ "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
+ )
+ arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)
+
ser = TransformWithStateInPandasInitStateSerializer(
timezone,
safecheck,
_assign_cols_by_name,
arrow_max_records_per_batch,
+ arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index b6f6a4cbc30b..51d9f6f523a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -106,12 +106,14 @@ class ApplyInPandasWithStatePythonRunner(
}
private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+ private val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
// applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
// Configurations are both applied to executor and Python worker, set them to the worker conf
// to let Python worker read the config properly.
override protected val workerConf: Map[String, String] = initialWorkerConf +
- (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
+ (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
+ (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
private val stateRowDeserializer = stateEncoder.createDeserializer()
@@ -142,7 +144,11 @@ class ApplyInPandasWithStatePythonRunner(
dataOut: DataOutputStream,
inputIterator: Iterator[InType]): Boolean = {
if (pandasWriter == null) {
- pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
+ pandasWriter = new ApplyInPandasWithStateWriter(
+ root,
+ writer,
+ arrowMaxRecordsPerBatch,
+ arrowMaxBytesPerBatch)
}
if (inputIterator.hasNext) {
val startData = dataOut.size()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
index f55ca749112f..cd83270bb4c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala
@@ -50,8 +50,9 @@ import org.apache.spark.unsafe.types.UTF8String
class ApplyInPandasWithStateWriter(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
- arrowMaxRecordsPerBatch: Int)
- extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) {
+ arrowMaxRecordsPerBatch: Int,
+ arrowMaxBytesPerBatch: Long)
+ extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch) {
import ApplyInPandasWithStateWriter._
@@ -144,7 +145,7 @@ class ApplyInPandasWithStateWriter(
// If it exceeds the condition of batch (number of records) once the all data is received for
// same group, finalize and construct a new batch.
- if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+ if (isBatchSizeLimitReached) {
finalizeCurrentArrowBatch()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
index ba8b2c3ac7da..f0371cafb72a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
@@ -32,6 +32,7 @@ class BaseStreamingArrowWriter(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
arrowMaxRecordsPerBatch: Int,
+ arrowMaxBytesPerBatch: Long,
arrowWriterForTest: ArrowWriter = null) {
protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == null) {
ArrowWriter.create(root)
@@ -54,7 +55,7 @@ class BaseStreamingArrowWriter(
// If it exceeds the condition of batch (number of records) and there is more data for the
// same group, finalize and construct a new batch.
- val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch
+ val isCurrentBatchFull = isBatchSizeLimitReached
if (isCurrentBatchFull) {
finalizeCurrentChunk(isLastChunkForGroup = false)
finalizeCurrentArrowBatch()
@@ -84,4 +85,13 @@ class BaseStreamingArrowWriter(
protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
numRowsForCurrentChunk = 0
}
+
+ protected def isBatchSizeLimitReached: Boolean = {
+ // If we have either reached the records or bytes limit
+ totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
+ // Short circuit batch size calculation if the batch size is unlimited as computing batch
+ // size is computationally expensive.
+ ((arrowMaxBytesPerBatch != Int.MaxValue)
+ && (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index f0df3e1f7d15..42d4ad68c29a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -75,7 +75,12 @@ class TransformWithStateInPySparkPythonRunner(
dataOut: DataOutputStream,
inputIterator: Iterator[InType]): Boolean = {
if (pandasWriter == null) {
- pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ pandasWriter = new BaseStreamingArrowWriter(
+ root,
+ writer,
+ arrowMaxRecordsPerBatch,
+ arrowMaxBytesPerBatch
+ )
}
// If we don't have data left for the current group, move to the next group.
@@ -145,7 +150,12 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
dataOut: DataOutputStream,
inputIterator: Iterator[GroupedInType]): Boolean = {
if (pandasWriter == null) {
- pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ pandasWriter = new BaseStreamingArrowWriter(
+ root,
+ writer,
+ arrowMaxRecordsPerBatch,
+ arrowMaxBytesPerBatch
+ )
}
if (inputIterator.hasNext) {
@@ -200,9 +210,11 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
protected val sqlConf = SQLConf.get
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+ protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
override protected val workerConf: Map[String, String] = initialWorkerConf +
- (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
+ (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
+ (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
// Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
// constructor.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
index f0fee2b9b0d9..fc10a102b4f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.python.streaming
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
-import org.mockito.Mockito.{mock, never, times, verify}
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, never, times, verify, when}
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkFunSuite
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
// Setting the maximum number of records per batch to 2 to make test easier.
val arrowMaxRecordsPerBatch = 2
+ val arrowMaxBytesPerBatch = Int.MaxValue
var transformWithStateInPySparkWriter: BaseStreamingArrowWriter = _
var arrowWriter: ArrowWriter = _
var writer: ArrowStreamWriter = _
@@ -37,7 +39,7 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac
writer = mock(classOf[ArrowStreamWriter])
arrowWriter = mock(classOf[ArrowWriter])
transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
- root, writer, arrowMaxRecordsPerBatch, arrowWriter)
+ root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch, arrowWriter)
}
test("test writeRow") {
@@ -64,4 +66,33 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac
verify(writer).writeBatch()
verify(arrowWriter).reset()
}
+
+ test("test maxBytesPerBatch can work") {
+ val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+
+ var sizeCounter = 0
+ when(arrowWriter.write(any[InternalRow])).thenAnswer { _ =>
+ sizeCounter += 1
+ ()
+ }
+
+ when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter }
+
+ // Set arrowMaxBytesPerBatch to 1
+ transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+ root, writer, arrowMaxRecordsPerBatch, 1, arrowWriter)
+ val dataRow = mock(classOf[InternalRow])
+ transformWithStateInPySparkWriter.writeRow(dataRow)
+ verify(arrowWriter).write(dataRow)
+ verify(writer, never()).writeBatch()
+ transformWithStateInPySparkWriter.writeRow(dataRow)
+ verify(arrowWriter, times(2)).write(dataRow)
+ // Write batch is called since we reach arrowMaxBytesPerBatch
+ verify(writer).writeBatch()
+ transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+ verify(arrowWriter, times(2)).finish()
+ // The second record would be written
+ verify(writer, times(2)).writeBatch()
+ verify(arrowWriter, times(2)).reset()
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]