HeartSaVioR commented on code in PR #53122:
URL: https://github.com/apache/spark/pull/53122#discussion_r2576123807
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala:
##########
@@ -158,30 +161,52 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
)
}
- if (inputIterator.hasNext) {
- val startData = dataOut.size()
- // a new grouping key with data & init state iter
- val next = inputIterator.next()
- val dataIter = next._2
- val initIter = next._3
-
- while (dataIter.hasNext || initIter.hasNext) {
- val dataRow =
- if (dataIter.hasNext) dataIter.next()
- else InternalRow.empty
- val initRow =
- if (initIter.hasNext) initIter.next()
- else InternalRow.empty
- pandasWriter.writeRow(InternalRow(dataRow, initRow))
+
+ // If we don't have data left for the current group, move to the next
group.
+ if (currentDataIterator == null && inputIterator.hasNext) {
+ val ((_, data), isInitState) = inputIterator.next()
+ currentDataIterator = data
+ val isPrevIterFromInitState = isCurrentIterFromInitState
+ isCurrentIterFromInitState = Some(isInitState)
+ if (isPrevIterFromInitState.isDefined &&
+ isPrevIterFromInitState.get != isInitState &&
+ pandasWriter.getTotalNumRowsForBatch > 0) {
+ // So we won't have batches with mixed data and init state.
+ pandasWriter.finalizeCurrentArrowBatch()
+ return true
}
- pandasWriter.finalizeCurrentArrowBatch()
- val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
+ }
+
+ val startData = dataOut.size()
+
+ val hasInput = if (currentDataIterator != null) {
+ var isCurrentBatchFull = false
+ // Stop writing when the current arrowBatch is finalized/full. If we
have rows left
+ while (currentDataIterator.hasNext && !isCurrentBatchFull) {
+ val dataRow = currentDataIterator.next()
+ isCurrentBatchFull = if (isCurrentIterFromInitState.get) {
Review Comment:
nit: isCurrentIterFromInitState.get can be evaluated once before the while
statement.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala:
##########
@@ -144,6 +144,9 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
private var pandasWriter: BaseStreamingArrowWriter = _
+ private var currentDataIterator: Iterator[InternalRow] = _
+ private var isCurrentIterFromInitState: Option[Boolean] = None
+
override protected def writeNextBatchToArrowStream(
Review Comment:
Thanks for making the change to return control once one Arrow RecordBatch
has been filled and flushed. I thought we had already dealt with this, but it
looks like we only did so for the non-initial-state case.
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1981,30 +1995,67 @@ def flatten_columns(cur_batch, col_name):
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and
pass into the Python
- data generator. All rows in the same batch have the same grouping
key.
+ data generator. Rows in the same batch may have different
grouping keys,
+ but each batch will have either init_data or input_data, not mix.
"""
- for batch in batches:
- flatten_state_table = flatten_columns(batch, "inputData")
- data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_state_table.itercolumns())
- ]
- flatten_init_table = flatten_columns(batch, "initState")
- init_data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_init_table.itercolumns())
- ]
- key_series = [data_pandas[o] for o in self.key_offsets]
- init_key_series = [init_data_pandas[o] for o in
self.init_key_offsets]
+ def row_stream():
+ for batch in batches:
+ self._update_batch_size_stats(batch)
- if any(s.empty for s in key_series):
- # If any row is empty, assign batch_key using
init_key_series
- batch_key = tuple(s[0] for s in init_key_series)
- else:
- # If all rows are non-empty, create batch_key from
key_series
- batch_key = tuple(s[0] for s in key_series)
- yield (batch_key, data_pandas, init_data_pandas)
+ flatten_state_table = flatten_columns(batch, "inputData")
+ data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in
enumerate(flatten_state_table.itercolumns())
+ ]
+
+ if bool(data_pandas):
+ for row in pd.concat(data_pandas,
axis=1).itertuples(index=False):
+ batch_key = tuple(row[s] for s in self.key_offsets)
+ yield (batch_key, row, None)
+ else:
+ flatten_init_table = flatten_columns(batch,
"initState")
+ init_data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in
enumerate(flatten_init_table.itercolumns())
+ ]
+ if bool(init_data_pandas):
Review Comment:
Can there be a case where neither `data_pandas` nor `init_data_pandas` is
available? If that should never happen, an assert would probably be better so
that it fails fast.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala:
##########
@@ -95,4 +95,43 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite
with BeforeAndAfterEac
verify(writer, times(2)).writeBatch()
verify(arrowWriter, times(2)).reset()
}
+
+ test("test negative or zero arrowMaxRecordsPerBatch is unlimited") {
+ val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+ val dataRow = mock(classOf[InternalRow])
+
+ // Test with negative value
+ transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+ root, writer, -1, arrowMaxBytesPerBatch, arrowWriter)
+
+ // Write many rows (more than typical batch size)
Review Comment:
I wonder how we can test this practically - the default Arrow RecordBatch
size is actually 10K, and 10 is far lower than the default. It's also
questionable whether we have to go with 10K writes (maybe it's OK to just do
that if it doesn't add seconds to the test time), so this is probably a matter
of balance.
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -2153,68 +2213,62 @@ def extract_rows(cur_batch, col_name, key_offsets):
table = pa.Table.from_arrays(data_field_arrays,
names=data_field_names)
if table.num_rows == 0:
- return (None, iter([]))
- else:
- batch_key = tuple(table.column(o)[0].as_py() for o in
key_offsets)
+ return None
- rows = []
+ def row_iterator():
for row_idx in range(table.num_rows):
+ key = tuple(table.column(o)[row_idx].as_py() for o in
key_offsets)
row = DataRow(
*(table.column(i)[row_idx].as_py() for i in
range(table.num_columns))
)
- rows.append(row)
+ yield (key, row)
- return (batch_key, iter(rows))
+ return row_iterator()
"""
The arrow batch is written in the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and
pass into the Python
- data generator. All rows in the same batch have the same grouping
key.
+ data generator. Each batch will have either init_data or
input_data, not mix.
"""
for batch in batches:
- (input_batch_key, input_data_iter) = extract_rows(
- batch, "inputData", self.key_offsets
- )
- (init_batch_key, init_state_iter) = extract_rows(
- batch, "initState", self.init_key_offsets
- )
+ # Detect which column has data - each batch contains only one
type
+ input_result = extract_rows(batch, "inputData",
self.key_offsets)
- if input_batch_key is None:
- batch_key = init_batch_key
+ if input_result is not None:
+ for key, input_data_row in input_result:
+ yield (key, input_data_row, None)
else:
- batch_key = input_batch_key
-
- for init_state_row in init_state_iter:
- yield (batch_key, None, init_state_row)
-
- for input_data_row in input_data_iter:
- yield (batch_key, input_data_row, None)
+ init_result = extract_rows(batch, "initState",
self.init_key_offsets)
+ if init_result is not None:
Review Comment:
same here; do we want to assert here?
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1981,30 +1995,67 @@ def flatten_columns(cur_batch, col_name):
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and
pass into the Python
- data generator. All rows in the same batch have the same grouping
key.
+ data generator. Rows in the same batch may have different
grouping keys,
+ but each batch will have either init_data or input_data, not mix.
"""
- for batch in batches:
- flatten_state_table = flatten_columns(batch, "inputData")
- data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_state_table.itercolumns())
- ]
- flatten_init_table = flatten_columns(batch, "initState")
- init_data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_init_table.itercolumns())
- ]
- key_series = [data_pandas[o] for o in self.key_offsets]
- init_key_series = [init_data_pandas[o] for o in
self.init_key_offsets]
+ def row_stream():
+ for batch in batches:
+ self._update_batch_size_stats(batch)
- if any(s.empty for s in key_series):
- # If any row is empty, assign batch_key using
init_key_series
- batch_key = tuple(s[0] for s in init_key_series)
- else:
- # If all rows are non-empty, create batch_key from
key_series
- batch_key = tuple(s[0] for s in key_series)
- yield (batch_key, data_pandas, init_data_pandas)
+ flatten_state_table = flatten_columns(batch, "inputData")
+ data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in
enumerate(flatten_state_table.itercolumns())
+ ]
+
+ if bool(data_pandas):
+ for row in pd.concat(data_pandas,
axis=1).itertuples(index=False):
+ batch_key = tuple(row[s] for s in self.key_offsets)
+ yield (batch_key, row, None)
+ else:
+ flatten_init_table = flatten_columns(batch,
"initState")
+ init_data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in
enumerate(flatten_init_table.itercolumns())
+ ]
+ if bool(init_data_pandas):
+ for row in pd.concat(init_data_pandas,
axis=1).itertuples(index=False):
+ batch_key = tuple(row[s] for s in
self.init_key_offsets)
+ yield (batch_key, None, row)
+
+ EMPTY_DATAFRAME = pd.DataFrame()
+ for batch_key, group_rows in groupby(row_stream(), key=lambda x:
x[0]):
+ rows = []
+ init_state_rows = []
+ for _, row, init_state_row in group_rows:
+ if row is not None:
+ rows.append(row)
+ if init_state_row is not None:
+ init_state_rows.append(init_state_row)
+
+ total_len = len(rows) + len(init_state_rows)
+ if (
+ total_len >= self.arrow_max_records_per_batch
+ or total_len * self.average_arrow_row_size >=
self.arrow_max_bytes_per_batch
+ ):
+ yield (
Review Comment:
Just to confirm, since we now separate data and init state when composing
the Arrow RecordBatch sent from the task thread to the Python worker:
here we could still produce a tuple that has both data and init state for the
same key, and we will separate them out again in a later phase. Do I
understand correctly?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]