nyaapa commented on code in PR #53122:
URL: https://github.com/apache/spark/pull/53122#discussion_r2551042509
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1981,30 +1988,71 @@ def flatten_columns(cur_batch, col_name):
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
- data generator. All rows in the same batch have the same grouping key.
+ data generator. Rows in the same batch may have different grouping keys,
+ but each batch will have either init_data or input_data, not mix.
"""
- for batch in batches:
- flatten_state_table = flatten_columns(batch, "inputData")
- data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_state_table.itercolumns())
- ]
+ def row_stream():
+ for batch in batches:
+ if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+ batch_bytes = sum(
+ buf.size
+ for col in batch.columns
+ for buf in col.buffers()
+ if buf is not None
+ )
+ self.total_bytes += batch_bytes
+ self.total_rows += batch.num_rows
+ self.average_arrow_row_size = self.total_bytes / self.total_rows
Review Comment:
Done
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -2141,6 +2188,11 @@ def generate_data_batches(batches):
def extract_rows(cur_batch, col_name, key_offsets):
data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+ # Check if the entire column is null
+ if data_column.null_count == len(data_column):
+ return None
Review Comment:
Done
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -2153,68 +2205,62 @@ def extract_rows(cur_batch, col_name, key_offsets):
table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)
if table.num_rows == 0:
- return (None, iter([]))
- else:
- batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
+ return None
- rows = []
+ def row_iterator():
for row_idx in range(table.num_rows):
+ key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
row = DataRow(
*(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
)
- rows.append(row)
+ yield (key, row)
- return (batch_key, iter(rows))
+ return row_iterator()
"""
The arrow batch is written in the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
- data generator. All rows in the same batch have the same grouping key.
+ data generator. Each batch will have either init_data or input_data, not mix.
"""
for batch in batches:
- (input_batch_key, input_data_iter) = extract_rows(
- batch, "inputData", self.key_offsets
- )
- (init_batch_key, init_state_iter) = extract_rows(
- batch, "initState", self.init_key_offsets
- )
+ # Detect which column has data - each batch contains only one type
+ input_result = extract_rows(batch, "inputData", self.key_offsets)
- if input_batch_key is None:
- batch_key = init_batch_key
+ if input_result is not None:
+ for key, input_data_row in input_result:
+ yield (key, input_data_row, None)
else:
- batch_key = input_batch_key
-
- for init_state_row in init_state_iter:
- yield (batch_key, None, init_state_row)
-
- for input_data_row in input_data_iter:
- yield (batch_key, input_data_row, None)
+ init_result = extract_rows(batch, "initState", self.init_key_offsets)
+ if init_result is not None:
+ for key, init_state_row in init_result:
+ yield (key, None, init_state_row)
_batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
for k, g in groupby(data_batches, key=lambda x: x[0]):
- # g: list(batch_key, input_data_iter, init_state_iter)
-
- # they are sharing the iterator, hence need to copy
- input_values_iter, init_state_iter = itertools.tee(g, 2)
-
- chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
- chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))
-
- chained_input_values_without_none = filter(
- lambda x: x is not None, chained_input_values
- )
- chained_init_state_values_without_none = filter(
- lambda x: x is not None, chained_init_state_values
- )
-
- ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)
-
- yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+ input_rows = []
+ init_rows = []
+
+ for batch_key, input_row, init_row in g:
+ if input_row is not None:
+ input_rows.append(input_row)
+ if init_row is not None:
+ init_rows.append(init_row)
+
+ total_len = len(input_rows) + len(init_rows)
+ if total_len >= self.arrow_max_records_per_batch:
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]