nyaapa commented on code in PR #53122:
URL: https://github.com/apache/spark/pull/53122#discussion_r2591712193
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1981,30 +1995,67 @@ def flatten_columns(cur_batch, col_name):
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and
pass into the Python
- data generator. All rows in the same batch have the same grouping
key.
+ data generator. Rows in the same batch may have different
grouping keys,
+ but each batch will have either init_data or input_data, not mix.
"""
-            for batch in batches:
-                flatten_state_table = flatten_columns(batch, "inputData")
-                data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in enumerate(flatten_state_table.itercolumns())
-                ]
-                flatten_init_table = flatten_columns(batch, "initState")
-                init_data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in enumerate(flatten_init_table.itercolumns())
-                ]
-                key_series = [data_pandas[o] for o in self.key_offsets]
-                init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
+            def row_stream():
+                for batch in batches:
+                    self._update_batch_size_stats(batch)
-                if any(s.empty for s in key_series):
-                    # If any row is empty, assign batch_key using init_key_series
-                    batch_key = tuple(s[0] for s in init_key_series)
-                else:
-                    # If all rows are non-empty, create batch_key from key_series
-                    batch_key = tuple(s[0] for s in key_series)
-                yield (batch_key, data_pandas, init_data_pandas)
+                    flatten_state_table = flatten_columns(batch, "inputData")
+                    data_pandas = [
+                        self.arrow_to_pandas(c, i)
+                        for i, c in enumerate(flatten_state_table.itercolumns())
+                    ]
+
+                    if bool(data_pandas):
+                        for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
+                            batch_key = tuple(row[s] for s in self.key_offsets)
+                            yield (batch_key, row, None)
+                    else:
+                        flatten_init_table = flatten_columns(batch, "initState")
+                        init_data_pandas = [
+                            self.arrow_to_pandas(c, i)
+                            for i, c in enumerate(flatten_init_table.itercolumns())
+                        ]
+                        if bool(init_data_pandas):
+                            for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False):
+                                batch_key = tuple(row[s] for s in self.init_key_offsets)
+                                yield (batch_key, None, row)
+
+            EMPTY_DATAFRAME = pd.DataFrame()
+            for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
+                rows = []
+                init_state_rows = []
+                for _, row, init_state_row in group_rows:
+                    if row is not None:
+                        rows.append(row)
+                    if init_state_row is not None:
+                        init_state_rows.append(init_state_row)
+
+                total_len = len(rows) + len(init_state_rows)
+                if (
+                    total_len >= self.arrow_max_records_per_batch
+                    or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+                ):
+                    yield (
##########
Review Comment:
yep, if, for example, we have only one key, with both init state and data
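
For context, here is a minimal, self-contained sketch of that scenario (not the PR code; `MAX_RECORDS_PER_BATCH` and the `(key, data_row, init_row)` tuple shape are illustrative stand-ins for the serializer's attributes). It shows how a single key carrying both input rows and init-state rows forms one `groupby` group, and how a record budget can force that group to be emitted as more than one chunk:

```python
from itertools import groupby

# Illustrative stand-in for self.arrow_max_records_per_batch (not the PR code).
MAX_RECORDS_PER_BATCH = 3


def row_stream():
    # One grouping key ("k1") arriving with both input-data rows and
    # init-state rows, i.e. the "only one key, with both init state and data" case.
    yield ("k1", ("k1", 1), None)   # input-data row
    yield ("k1", ("k1", 2), None)   # input-data row
    yield ("k1", None, ("k1", 0))   # init-state row
    yield ("k1", None, ("k1", 9))   # init-state row


def grouped_chunks(stream):
    # groupby only merges adjacent rows with equal keys, so the stream must
    # already be ordered by key (as the Arrow batches are per the docstring).
    for batch_key, group_rows in groupby(stream, key=lambda x: x[0]):
        rows, init_state_rows = [], []
        for _, row, init_row in group_rows:
            if row is not None:
                rows.append(row)
            if init_row is not None:
                init_state_rows.append(init_row)
            # Flush a chunk once the record budget is reached; the same key
            # can therefore span more than one emitted chunk.
            if len(rows) + len(init_state_rows) >= MAX_RECORDS_PER_BATCH:
                yield batch_key, rows, init_state_rows
                rows, init_state_rows = [], []
        if rows or init_state_rows:
            yield batch_key, rows, init_state_rows


for key, rows, init_rows in grouped_chunks(row_stream()):
    print(key, rows, init_rows)
# k1 [('k1', 1), ('k1', 2)] [('k1', 0)]
# k1 [] [('k1', 9)]
```

Whether the split happens eagerly inside the loop (as in this sketch) or after accumulating the whole group is an implementation detail; the point is that one key can legitimately produce several (rows, init_state_rows) chunks, so downstream code shouldn't assume one chunk per key.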