This is an automated email from the ASF dual-hosted git repository.
zhengruifeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e96995a8d430 [SPARK-56608][PYTHON] Migrate grouped/cogrouped map Arrow
UDF verify checks into enforce_schema
e96995a8d430 is described below
commit e96995a8d430bbd3e2918117a295811be3678154
Author: Yicong Huang <[email protected]>
AuthorDate: Thu May 21 08:25:15 2026 +0800
[SPARK-56608][PYTHON] Migrate grouped/cogrouped map Arrow UDF verify checks
into enforce_schema
### What changes were proposed in this pull request?
Make `ArrowBatchTransformer.enforce_schema` the single entry point for
Arrow UDF output validation, and switch `SQL_GROUPED_MAP_ARROW_UDF`,
`SQL_GROUPED_MAP_ARROW_ITER_UDF`, and `SQL_COGROUPED_MAP_ARROW_UDF` in
`worker.py` to it, replacing `verify_arrow_result` + manual reorder.
`enforce_schema` now accepts both `pa.RecordBatch` and `pa.Table`, adds a
`reorder_by_name: bool = True` parameter (name-based reorder vs positional
matching), aggregates all mismatches before raising, and raises
`PySparkRuntimeError` with the existing `errorClass`es
(`RESULT_COLUMN_NAMES_MISMATCH` / `RESULT_COLUMN_TYPES_MISMATCH` /
`RESULT_COLUMN_SCHEMA_MISMATCH`) instead of bare-string `PySparkTypeError`.
`verify_arrow_result` stays only for `SQL_ARROW_TABLE_UDF` (out of scope).
### Why are the changes needed?
Part of [SPARK-55388](https://issues.apache.org/jira/browse/SPARK-55388).
Validation is currently split between `verify_arrow_result` (friendly errors)
and `enforce_schema` (bare f-string errors). Consolidating them gives one code
path and one error convention, and drops the redundant verify+reorder work.
### Does this PR introduce _any_ user-facing change?
Yes, for `SQL_ARROW_UDTF`:
- Error format switches from bare f-strings to the `errorClass`-templated
format used by other Arrow UDFs.
- Returning target columns plus extras now raises
`RESULT_COLUMN_NAMES_MISMATCH` (with `Unexpected:` populated). This restores
the strict field-name check that existed before
[SPARK-56166](https://issues.apache.org/jira/browse/SPARK-56166), which
inadvertently relaxed it to silent-drop when switching to `enforce_schema`.
Grouped/cogrouped map error formats are unchanged.
### How was this patch tested?
- Existing `test_arrow_grouped_map.py` / `test_arrow_cogrouped_map.py`
integration tests pass unchanged.
- `test_conversion.py` extended for `reorder_by_name`, `pa.Table` input,
and count-mismatch paths.
- `test_arrow_udtf.py` regex updated for the two error tests.
- ASV (`repeat=3`) on the three affected bench classes: 0 regressions at
`-f 1.05`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #55530 from Yicong-Huang/SPARK-56608.
Authored-by: Yicong Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/conversion.py | 117 +++++++++++++++-------
python/pyspark/sql/tests/arrow/test_arrow_udtf.py | 9 +-
python/pyspark/sql/tests/test_conversion.py | 66 +++++++++++-
python/pyspark/worker.py | 116 +++++----------------
4 files changed, 172 insertions(+), 136 deletions(-)
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index a229386f3001..9110a6382725 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -21,7 +21,7 @@ import decimal
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence,
Union, overload
import pyspark
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.pandas.types import (
_dedup_names,
_deduplicate_field_names,
@@ -124,19 +124,20 @@ class ArrowBatchTransformer:
@classmethod
def enforce_schema(
cls,
- batch: "pa.RecordBatch",
+ batch: Union["pa.RecordBatch", "pa.Table"],
arrow_schema: "pa.Schema",
*,
arrow_cast: bool = True,
safecheck: bool = True,
- ) -> "pa.RecordBatch":
+ reorder_by_name: bool = True,
+ ) -> Union["pa.RecordBatch", "pa.Table"]:
"""
- Enforce target schema on a RecordBatch by reordering columns and
coercing types.
+ Enforce a target schema on an Arrow RecordBatch or Table.
Parameters
----------
- batch : pa.RecordBatch
- Input RecordBatch to transform.
+ batch : pa.RecordBatch or pa.Table
+ Input to transform. Output is of the same container type.
arrow_schema : pa.Schema
Target Arrow schema. Callers should pre-compute this once via
to_arrow_schema() to avoid repeated conversion.
@@ -145,11 +146,28 @@ class ArrowBatchTransformer:
If False, raise an error on type mismatch instead of casting.
safecheck : bool, default True
If True, use safe casting (fails on overflow/truncation).
+ reorder_by_name : bool, default True
+ If True, match columns by name and reorder to the target order; any
+ missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``.
Output
+ columns are renamed to target names.
+ If False, match columns by position (ignore names) and preserve the
+ original column names in the output.
Returns
-------
- pa.RecordBatch
- RecordBatch with columns reordered and types coerced to match
target schema.
+ pa.RecordBatch or pa.Table
+ Same container type as ``batch``, with columns matched (and
possibly
+ reordered/cast) per the target schema.
+
+ Raises
+ ------
+ PySparkRuntimeError
+ ``RESULT_COLUMN_NAMES_MISMATCH`` when ``reorder_by_name=True`` and
the
+ batch has missing or extra column names.
+ ``RESULT_COLUMN_TYPES_MISMATCH`` when any column's type does not
match
+ the target (and either ``arrow_cast=False`` or the cast itself
fails).
+ ``RESULT_COLUMN_SCHEMA_MISMATCH`` when ``reorder_by_name=False``
and the
+ batch has a different number of columns than the target schema.
"""
import pyarrow as pa
@@ -160,37 +178,68 @@ class ArrowBatchTransformer:
if batch.schema.equals(arrow_schema, check_metadata=False):
return batch
- # Check if columns are in the same order (by name) as the target
schema.
- # If so, use index-based access (faster than name lookup).
- batch_names = [batch.schema.field(i).name for i in
range(batch.num_columns)]
target_names = [field.name for field in arrow_schema]
- use_index = batch_names == target_names
- coerced_arrays = []
- for i, field in enumerate(arrow_schema):
- try:
- arr = batch.column(i) if use_index else
batch.column(field.name)
- except KeyError:
- raise PySparkTypeError(
- f"Result column '{field.name}' does not exist in the
output. "
- f"Expected schema: {arrow_schema}, got: {batch.schema}."
+ # Step 1: pick source columns from batch to align with target schema
+ if reorder_by_name:
+ batch_names = [batch.schema.field(i).name for i in
range(batch.num_columns)]
+ missing = sorted(set(target_names) - set(batch_names))
+ extra = sorted(set(batch_names) - set(target_names))
+ if missing or extra:
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_NAMES_MISMATCH",
+ messageParameters={
+ "missing": f" Missing: {', '.join(missing)}." if
missing else "",
+ "extra": f" Unexpected: {', '.join(extra)}." if extra
else "",
+ },
)
- if arr.type != field.type:
- if not arrow_cast:
- raise PySparkTypeError(
- f"Result type of column '{field.name}' does not match "
- f"the expected type. Expected: {field.type}, got:
{arr.type}."
- )
+ source_columns = [batch.column(name) for name in target_names]
+ output_names = target_names
+ else:
+ # Positional: require exact column-count match, then take columns
by
+ # index, preserving the batch's original column names.
+ if batch.num_columns != len(arrow_schema):
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
+ messageParameters={
+ "expected": str(len(arrow_schema)),
+ "actual": str(batch.num_columns),
+ },
+ )
+ source_columns = [batch.column(i) for i in
range(len(arrow_schema))]
+ output_names = [batch.schema.field(i).name for i in
range(len(arrow_schema))]
+
+ # Step 2: check types / cast, collect all mismatches
+ type_mismatches = []
+ coerced_arrays = []
+ for field, arr in zip(arrow_schema, source_columns):
+ if arr.type == field.type:
+ coerced_arrays.append(arr)
+ elif not arrow_cast:
+ type_mismatches.append((field.name, field.type, arr.type))
+ coerced_arrays.append(arr)
+ else:
try:
- arr = arr.cast(target_type=field.type, safe=safecheck)
- except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
- raise PySparkTypeError(
- f"Result type of column '{field.name}' does not match "
- f"the expected type. Expected: {field.type}, got:
{arr.type}."
- ) from e
- coerced_arrays.append(arr)
+ coerced_arrays.append(arr.cast(target_type=field.type,
safe=safecheck))
+ except (pa.ArrowInvalid, pa.ArrowTypeError):
+ type_mismatches.append((field.name, field.type, arr.type))
+ coerced_arrays.append(arr)
+
+ if type_mismatches:
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_TYPES_MISMATCH",
+ messageParameters={
+ "mismatch": ", ".join(
+ f"column '{name}' (expected {expected}, actual
{actual})"
+ for name, expected, actual in type_mismatches
+ )
+ },
+ )
- return pa.RecordBatch.from_arrays(coerced_arrays, names=target_names)
+ # Preserve input container type (Table vs RecordBatch)
+ if isinstance(batch, pa.Table):
+ return pa.Table.from_arrays(coerced_arrays, names=output_names)
+ return pa.RecordBatch.from_arrays(coerced_arrays, names=output_names)
@classmethod
def to_pandas(
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index f41b7613ec42..b82523005ac7 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -211,9 +211,8 @@ class ArrowUDTFTestsMixin:
with self.assertRaisesRegex(
PythonException,
- r"(?s)Result column 'x' does not exist in the output\. "
- r"Expected schema: x: int32\ny: string, "
- r"got: wrong_col: int32\nanother_wrong_col: double\.",
+ r"(?s)\[RESULT_COLUMN_NAMES_MISMATCH\].*"
+ r"Missing: x, y\..*Unexpected: another_wrong_col, wrong_col\.",
):
result_df = MismatchedSchemaUDTF()
result_df.collect()
@@ -375,8 +374,8 @@ class ArrowUDTFTestsMixin:
# Should fail with Arrow cast exception since string cannot be cast to
int
with self.assertRaisesRegex(
PythonException,
- "Result type of column 'id' does not match "
- "the expected type. Expected: int32, got: string.",
+ r"(?s)\[RESULT_COLUMN_TYPES_MISMATCH\].*"
+ r"column 'id' \(expected int32, actual string\)",
):
result_df = StringToIntUDTF()
result_df.collect()
diff --git a/python/pyspark/sql/tests/test_conversion.py
b/python/pyspark/sql/tests/test_conversion.py
index 304d8be740d4..dd5c7f44d281 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -18,7 +18,7 @@ import datetime
import unittest
from zoneinfo import ZoneInfo
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError,
PySparkValueError
from pyspark.sql.conversion import (
ArrowArrayToPandasConversion,
ArrowTableToRowsConversion,
@@ -185,8 +185,9 @@ class ArrowBatchTransformerTests(unittest.TestCase):
batch = pa.RecordBatch.from_arrays([pa.array([1], type=pa.int32())],
names=["x"])
target = pa.schema([("x", pa.int64())])
- with self.assertRaises(PySparkTypeError):
+ with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, target,
arrow_cast=False)
+ self.assertEqual(cm.exception.getCondition(),
"RESULT_COLUMN_TYPES_MISMATCH")
def test_enforce_schema_safecheck(self):
"""safecheck=True rejects overflow; safecheck=False allows it."""
@@ -194,18 +195,73 @@ class ArrowBatchTransformerTests(unittest.TestCase):
batch = pa.RecordBatch.from_arrays([pa.array([999], type=pa.int64())],
names=["x"])
target = pa.schema([("x", pa.int8())])
- with self.assertRaises(PySparkTypeError):
+ with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, target, safecheck=True)
+ self.assertEqual(cm.exception.getCondition(),
"RESULT_COLUMN_TYPES_MISMATCH")
result = ArrowBatchTransformer.enforce_schema(batch, target,
safecheck=False)
self.assertEqual(result.schema, target)
def test_enforce_schema_missing_column(self):
- """Missing column raises PySparkTypeError."""
+ """Missing column raises RESULT_COLUMN_NAMES_MISMATCH."""
import pyarrow as pa
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
- with self.assertRaises(PySparkTypeError):
+ with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, pa.schema([("missing",
pa.int64())]))
+ self.assertEqual(cm.exception.getCondition(),
"RESULT_COLUMN_NAMES_MISMATCH")
+
+ def test_enforce_schema_extra_column(self):
+ """Extra column raises RESULT_COLUMN_NAMES_MISMATCH with the extra
name listed."""
+ import pyarrow as pa
+
+ batch = pa.RecordBatch.from_arrays([pa.array([1]), pa.array([2])],
names=["a", "b"])
+ with self.assertRaises(PySparkRuntimeError) as cm:
+ ArrowBatchTransformer.enforce_schema(batch, pa.schema([("a",
pa.int64())]))
+ self.assertEqual(cm.exception.getCondition(),
"RESULT_COLUMN_NAMES_MISMATCH")
+ self.assertIn("b", str(cm.exception))
+
+ def test_enforce_schema_reorder_by_name(self):
+ """reorder_by_name=True reorders input columns to match target schema
order."""
+ import pyarrow as pa
+
+ batch = pa.RecordBatch.from_arrays([pa.array(["x"]), pa.array([1])],
names=["b", "a"])
+ target = pa.schema([("a", pa.int64()), ("b", pa.string())])
+ result = ArrowBatchTransformer.enforce_schema(batch, target)
+ self.assertEqual(result.schema.names, ["a", "b"])
+ self.assertEqual(result.column(0).to_pylist(), [1])
+ self.assertEqual(result.column(1).to_pylist(), ["x"])
+
+ def test_enforce_schema_positional(self):
+ """reorder_by_name=False matches columns by index, preserving input
names."""
+ import pyarrow as pa
+
+ batch = pa.RecordBatch.from_arrays([pa.array([1]), pa.array(["x"])],
names=["foo", "bar"])
+ target = pa.schema([("a", pa.int64()), ("b", pa.string())])
+ result = ArrowBatchTransformer.enforce_schema(batch, target,
reorder_by_name=False)
+ # Input column names are preserved
+ self.assertEqual(result.schema.names, ["foo", "bar"])
+ self.assertEqual(result.column(0).to_pylist(), [1])
+ self.assertEqual(result.column(1).to_pylist(), ["x"])
+
+ def test_enforce_schema_positional_count_mismatch(self):
+ """reorder_by_name=False with wrong column count raises
RESULT_COLUMN_SCHEMA_MISMATCH."""
+ import pyarrow as pa
+
+ batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
+ target = pa.schema([("x", pa.int64()), ("y", pa.int64())])
+ with self.assertRaises(PySparkRuntimeError) as cm:
+ ArrowBatchTransformer.enforce_schema(batch, target,
reorder_by_name=False)
+ self.assertEqual(cm.exception.getCondition(),
"RESULT_COLUMN_SCHEMA_MISMATCH")
+
+ def test_enforce_schema_table_input(self):
+ """enforce_schema accepts pa.Table and returns pa.Table."""
+ import pyarrow as pa
+
+ table = pa.table({"x": pa.array([1], type=pa.int32())})
+ target = pa.schema([("x", pa.int64())])
+ result = ArrowBatchTransformer.enforce_schema(table, target)
+ self.assertIsInstance(result, pa.Table)
+ self.assertEqual(result.schema, target)
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index daa804a718e5..306d4f80dbe5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2685,27 +2685,7 @@ def read_udfs(pickleSer, udf_info_list, eval_type,
runner_conf, eval_conf):
arrow_return_type = to_arrow_type(
return_type, timezone="UTC",
prefers_large_types=runner_conf.use_large_var_types
)
- if runner_conf.assign_cols_by_name:
- expected_cols_and_types = {
- col.name: to_arrow_type(
- col.dataType,
- timezone="UTC",
- prefers_large_types=runner_conf.use_large_var_types,
- )
- for col in return_type.fields
- }
- else:
- expected_cols_and_types = [
- (
- col.name,
- to_arrow_type(
- col.dataType,
- timezone="UTC",
- prefers_large_types=runner_conf.use_large_var_types,
- ),
- )
- for col in return_type.fields
- ]
+ arrow_return_schema = pa.schema(list(arrow_return_type))
key_offsets = parsed_offsets[0][0]
value_offsets = parsed_offsets[0][1]
@@ -2741,17 +2721,15 @@ def read_udfs(pickleSer, udf_info_list, eval_type,
runner_conf, eval_conf):
result = grouped_udf(key, value_table)
verify_return_type(result, pa.Table)
- verify_arrow_result(
- result, runner_conf.assign_cols_by_name,
expected_cols_and_types
+ # Verify types (and reorder by name when configured).
+ result = ArrowBatchTransformer.enforce_schema(
+ result,
+ arrow_return_schema,
+ arrow_cast=False,
+ reorder_by_name=runner_conf.assign_cols_by_name,
)
- # Reorder columns if needed and wrap into struct
for batch in result.to_batches():
- if runner_conf.assign_cols_by_name:
- batch = pa.RecordBatch.from_arrays(
- [batch.column(field.name) for field in
arrow_return_type],
- names=[field.name for field in arrow_return_type],
- )
yield ArrowBatchTransformer.wrap_struct(batch)
# profiling is not supported for UDF
@@ -2770,27 +2748,7 @@ def read_udfs(pickleSer, udf_info_list, eval_type,
runner_conf, eval_conf):
arrow_return_type = to_arrow_type(
return_type, timezone="UTC",
prefers_large_types=runner_conf.use_large_var_types
)
- if runner_conf.assign_cols_by_name:
- expected_cols_and_types = {
- col.name: to_arrow_type(
- col.dataType,
- timezone="UTC",
- prefers_large_types=runner_conf.use_large_var_types,
- )
- for col in return_type.fields
- }
- else:
- expected_cols_and_types = [
- (
- col.name,
- to_arrow_type(
- col.dataType,
- timezone="UTC",
- prefers_large_types=runner_conf.use_large_var_types,
- ),
- )
- for col in return_type.fields
- ]
+ arrow_return_schema = pa.schema(list(arrow_return_type))
key_offsets = parsed_offsets[0][0]
value_offsets = parsed_offsets[0][1]
@@ -2824,16 +2782,14 @@ def read_udfs(pickleSer, udf_info_list, eval_type,
runner_conf, eval_conf):
key = tuple(c[0] for c in keys.columns)
result = grouped_udf(key, value_batches)
- # Verify, reorder, and wrap each output batch
+ # Verify (and reorder by name when configured) each output
batch
for batch in verify_return_type(result,
Iterator[pa.RecordBatch]):
- verify_arrow_result(
- batch, runner_conf.assign_cols_by_name,
expected_cols_and_types
+ batch = ArrowBatchTransformer.enforce_schema(
+ batch,
+ arrow_return_schema,
+ arrow_cast=False,
+ reorder_by_name=runner_conf.assign_cols_by_name,
)
- if runner_conf.assign_cols_by_name:
- batch = pa.RecordBatch.from_arrays(
- [batch.column(field.name) for field in
arrow_return_type],
- names=[field.name for field in arrow_return_type],
- )
yield ArrowBatchTransformer.wrap_struct(batch)
# Drain remaining input batches to maintain stream position
@@ -2998,32 +2954,10 @@ def read_udfs(pickleSer, udf_info_list, eval_type,
runner_conf, eval_conf):
parsed_offsets = extract_key_value_indexes(arg_offsets)
- # Pre-compute expected column names/types for strict result validation.
- # Cogrouped map has a strict contract: missing, extra, or
type-mismatched
- # columns must raise; no silent coercion.
- if runner_conf.assign_cols_by_name:
- expected_cols_and_types = {
- col.name: to_arrow_type(
- col.dataType,
- timezone="UTC",
- prefers_large_types=runner_conf.use_large_var_types,
- )
- for col in return_type.fields
- }
- reorder_names = [col.name for col in return_type.fields]
- else:
- expected_cols_and_types = [
- (
- col.name,
- to_arrow_type(
- col.dataType,
- timezone="UTC",
- prefers_large_types=runner_conf.use_large_var_types,
- ),
- )
- for col in return_type.fields
- ]
- reorder_names = None
+ arrow_return_type = to_arrow_type(
+ return_type, timezone="UTC",
prefers_large_types=runner_conf.use_large_var_types
+ )
+ arrow_return_schema = pa.schema(list(arrow_return_type))
select_columns = ArrowBatchTransformer.select_columns
left_key_cols, left_val_cols = parsed_offsets[0]
@@ -3051,17 +2985,15 @@ def read_udfs(pickleSer, udf_info_list, eval_type,
runner_conf, eval_conf):
result = cogrouped_udf(key, left_values, right_values)
verify_return_type(result, pa.Table)
- verify_arrow_result(
- result, runner_conf.assign_cols_by_name,
expected_cols_and_types
+ # Verify types (and reorder by name when configured).
+ result = ArrowBatchTransformer.enforce_schema(
+ result,
+ arrow_return_schema,
+ arrow_cast=False,
+ reorder_by_name=runner_conf.assign_cols_by_name,
)
for batch in result.to_batches():
- if reorder_names is not None:
- # Names and types already validated equal; pure
reorder, no cast.
- batch = pa.RecordBatch.from_arrays(
- [batch.column(name) for name in reorder_names],
- names=reorder_names,
- )
yield ArrowBatchTransformer.wrap_struct(batch)
# profiling is not supported for UDF
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]