zhengruifeng commented on code in PR #55530:
URL: https://github.com/apache/spark/pull/55530#discussion_r3222955743


##########
python/pyspark/worker.py:
##########
@@ -2940,20 +2920,9 @@ def dataframe_iter():
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
-        # Pre-compute expected column names/types for strict result validation.
-        # Cogrouped map has a strict contract: missing, extra, or type-mismatched
-        # columns must raise; no silent coercion.
-        if runner_conf.assign_cols_by_name:
-            expected_cols_and_types = {
-                col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
-            }
-            reorder_names = [col.name for col in return_type.fields]
-        else:
-            expected_cols_and_types = [
-                (col.name, to_arrow_type(col.dataType, timezone="UTC"))
-                for col in return_type.fields
-            ]
-            reorder_names = None
+        arrow_return_schema = pa.schema(
+            [(col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields]
+        )

Review Comment:
   The new grouped paths in this PR (`worker.py:2651-2654` and `:2714-2717`) 
thread `runner_conf.use_large_var_types` through `to_arrow_type(return_type, 
...)` for their validation schema, but cogrouped still builds the schema 
per-field without `prefers_large_types`. Under 
`spark.sql.execution.arrow.useLargeVarTypes=true`, the cogrouped validation 
therefore expects regular `string`/`binary` while the rest of the pipeline 
(and the new grouped paths) expects `large_string`/`large_binary`. The result is 
that a UDF returning the large variants is rejected, and a regular-string return 
that should be flagged is accepted. The pre-PR `verify_arrow_result` setup had the 
same omission. However, the grouped paths in this PR now pick up the setting, so 
aligning cogrouped with them is the consistency fix.
   
   ```suggestion
           arrow_return_type = to_arrow_type(
               return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types
           )
           arrow_return_schema = pa.schema(list(arrow_return_type))
   ```
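   To make the two failure directions concrete without depending on PySpark or pyarrow, here is a minimal pure-Python sketch. `arrow_type_name` and `arrow_schema` are hypothetical stand-ins, not PySpark APIs. They model how `prefers_large_types` changes the expected Arrow type for string/binary columns, with schemas reduced to `name -> type-name` dicts. The sketch shows why a validation schema built without the flag disagrees with the rest of the pipeline.

```python
# Hypothetical model of the Spark -> Arrow type mapping (NOT the real
# pyspark to_arrow_type): only the string/binary cases that
# prefers_large_types affects are modeled.
BASE = {"long": "int64", "string": "string", "binary": "binary"}
LARGE = {"string": "large_string", "binary": "large_binary"}


def arrow_type_name(spark_type: str, prefers_large_types: bool) -> str:
    if prefers_large_types and spark_type in LARGE:
        return LARGE[spark_type]
    return BASE[spark_type]


def arrow_schema(fields: dict, prefers_large_types: bool) -> dict:
    # Column name -> Arrow type name, analogous to building pa.schema(...).
    return {name: arrow_type_name(t, prefers_large_types) for name, t in fields.items()}


return_type = {"id": "long", "name": "string", "payload": "binary"}
use_large_var_types = True  # models spark.sql.execution.arrow.useLargeVarTypes=true

# Schema the rest of the pipeline (and the new grouped paths) expects.
pipeline_schema = arrow_schema(return_type, use_large_var_types)
# Cogrouped validation schema as currently built, without prefers_large_types.
cogrouped_schema = arrow_schema(return_type, False)

large_result = dict(pipeline_schema)                # UDF returns large variants
regular_result = arrow_schema(return_type, False)   # UDF returns regular string/binary

print(cogrouped_schema == large_result)    # False: correct output rejected
print(cogrouped_schema == regular_result)  # True: output that should be flagged is accepted
print(pipeline_schema == large_result)     # True once the flag is threaded through
print(pipeline_schema == regular_result)   # False: the regular-string return is now flagged
```

With the suggestion applied, cogrouped validation uses the same flag-aware schema as `pipeline_schema`, so both directions behave like the grouped paths.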



##########
python/pyspark/sql/conversion.py:
##########
@@ -145,11 +146,26 @@ def enforce_schema(
             If False, raise an error on type mismatch instead of casting.
         safecheck : bool, default True
             If True, use safe casting (fails on overflow/truncation).
+        reorder_by_name : bool, default True
+            If True, match columns by name and reorder to the target order; any
+            missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``. Output

Review Comment:
   Heads-up: the new default `reorder_by_name=True` strictly rejects extra columns 
(raising `RESULT_COLUMN_NAMES_MISMATCH`). The old `enforce_schema` silently 
dropped them instead, because it only looked up target names via `batch.column(name)`. The 
remaining caller that relies on the default behavior, 
`ArrowStreamArrowUDTFSerializer.dump_stream` 
(`pyspark/sql/pandas/serializers.py:293`), therefore changes contract. A 
`SQL_ARROW_UDTF` that returned the target columns plus extras previously had the extras 
silently dropped, and now raises. This is probably the right contract (the old leniency was 
undocumented), but it is worth surfacing in the "user-facing change" section of the PR 
description so that any UDTF returning extras can be cleaned up before upgrading.
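   A minimal pure-Python sketch of the contract change, assuming a batch can be modeled as a `dict` of column name to values. `old_select`, `new_select` and `ColumnNamesMismatch` are hypothetical illustrations, not PySpark code. The message pieces mirror the `messageParameters` in the diff below.

```python
class ColumnNamesMismatch(Exception):
    """Hypothetical stand-in for PySparkRuntimeError(RESULT_COLUMN_NAMES_MISMATCH)."""


def old_select(batch: dict, target_names: list) -> dict:
    # Old behavior: look up each target name only. Extra columns are never
    # touched, so they are silently dropped (a missing name raises KeyError
    # here; the old code turned that into a PySparkTypeError).
    return {name: batch[name] for name in target_names}


def new_select(batch: dict, target_names: list) -> dict:
    # New default (reorder_by_name=True): any missing or extra name raises.
    missing = sorted(set(target_names) - set(batch))
    extra = sorted(set(batch) - set(target_names))
    if missing or extra:
        raise ColumnNamesMismatch(
            (f" Missing: {', '.join(missing)}." if missing else "")
            + (f" Unexpected: {', '.join(extra)}." if extra else "")
        )
    return {name: batch[name] for name in target_names}


# A UDTF result with the target columns (out of order) plus one extra column.
udtf_output = {"b": [2], "a": [1], "debug": ["x"]}
target = ["a", "b"]

print(old_select(udtf_output, target))  # {'a': [1], 'b': [2]}
try:
    new_select(udtf_output, target)
except ColumnNamesMismatch as e:
    print(f"raised:{e}")  # raised: Unexpected: debug.
```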



##########
python/pyspark/sql/conversion.py:
##########
@@ -145,11 +146,26 @@ def enforce_schema(
             If False, raise an error on type mismatch instead of casting.
         safecheck : bool, default True
             If True, use safe casting (fails on overflow/truncation).
+        reorder_by_name : bool, default True
+            If True, match columns by name and reorder to the target order; any
+            missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``. Output
+            columns are renamed to target names.
+            If False, match columns by position (ignore names) and preserve the
+            original column names in the output.
 
         Returns
         -------
-        pa.RecordBatch
-            RecordBatch with columns reordered and types coerced to match target schema.
+        pa.RecordBatch or pa.Table
+            Same container type as ``batch``, with columns matched (and possibly
+            reordered/cast) per the target schema.
+
+        Raises
+        ------
+        PySparkRuntimeError
+            ``RESULT_COLUMN_NAMES_MISMATCH`` when ``reorder_by_name=True`` and the
+            batch has missing or extra column names.
+            ``RESULT_COLUMN_TYPES_MISMATCH`` when any column's type does not match
+            the target (and either ``arrow_cast=False`` or the cast itself fails).

Review Comment:
   The `Raises` section omits `RESULT_COLUMN_SCHEMA_MISMATCH`, which the 
function also raises (positional mode, when `batch.num_columns != 
len(arrow_schema)`).
   
   ```suggestion
               the target (and either ``arrow_cast=False`` or the cast itself fails).
               ``RESULT_COLUMN_SCHEMA_MISMATCH`` when ``reorder_by_name=False`` and the
               batch has a different number of columns than the target schema.
   ```
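   To check that the `Raises` section covers every path, here is a hypothetical pure-Python model of which error class each condition maps to. `expected_error_class` is an illustration, not part of PySpark. `type_mismatch` collapses "a column's type differs and either `arrow_cast=False` or the cast fails" into a single flag.

```python
from typing import Optional, Sequence


def expected_error_class(
    batch_names: Sequence[str],
    target_names: Sequence[str],
    reorder_by_name: bool,
    type_mismatch: bool = False,
) -> Optional[str]:
    """Hypothetical model of which error class enforce_schema raises."""
    if reorder_by_name:
        # Step 1 (by name): any missing or extra name.
        if set(batch_names) != set(target_names):
            return "RESULT_COLUMN_NAMES_MISMATCH"
    elif len(batch_names) != len(target_names):
        # Step 1 (positional): column-count mismatch; names are ignored.
        return "RESULT_COLUMN_SCHEMA_MISMATCH"
    if type_mismatch:
        # Step 2: per-column type check / cast.
        return "RESULT_COLUMN_TYPES_MISMATCH"
    return None


print(expected_error_class(["a"], ["a", "b"], reorder_by_name=True))        # RESULT_COLUMN_NAMES_MISMATCH
print(expected_error_class(["x"], ["a", "b"], reorder_by_name=False))       # RESULT_COLUMN_SCHEMA_MISMATCH
print(expected_error_class(["x", "y"], ["a", "b"], reorder_by_name=False))  # None
```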



##########
python/pyspark/sql/conversion.py:
##########
@@ -160,37 +176,68 @@ def enforce_schema(
         if batch.schema.equals(arrow_schema, check_metadata=False):
             return batch
 
-        # Check if columns are in the same order (by name) as the target schema.
-        # If so, use index-based access (faster than name lookup).
-        batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
         target_names = [field.name for field in arrow_schema]
-        use_index = batch_names == target_names
 
-        coerced_arrays = []
-        for i, field in enumerate(arrow_schema):
-            try:
-                arr = batch.column(i) if use_index else batch.column(field.name)
-            except KeyError:
-                raise PySparkTypeError(
-                    f"Result column '{field.name}' does not exist in the output. "
-                    f"Expected schema: {arrow_schema}, got: {batch.schema}."
+        # Step 1: pick source columns from batch to align with target schema
+        if reorder_by_name:
+            batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
+            missing = sorted(set(target_names) - set(batch_names))
+            extra = sorted(set(batch_names) - set(target_names))
+            if missing or extra:
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_COLUMN_NAMES_MISMATCH",
+                    messageParameters={
+                        "missing": f" Missing: {', '.join(missing)}." if missing else "",
+                        "extra": f" Unexpected: {', '.join(extra)}." if extra else "",
+                    },
                 )
-            if arr.type != field.type:
-                if not arrow_cast:
-                    raise PySparkTypeError(
-                        f"Result type of column '{field.name}' does not match "
-                        f"the expected type. Expected: {field.type}, got: {arr.type}."
-                    )
+            source_columns = [batch.column(name) for name in target_names]
+            output_names = target_names
+        else:
+            # Positional: require exact column-count match, then take columns by
+            # index, preserving the batch's original column names.
+            if batch.num_columns != len(arrow_schema):

Review Comment:
   Behavior change worth noting in the PR description: under 
`assign_cols_by_name=False`, the old `verify_arrow_result` did `zip(expected, 
actual)` over the column lists. That silently truncated to the shorter list, so 
column-count mismatches in the positional grouped/cogrouped paths slipped 
through. The new strict count check is an improvement over that silent 
truncation. It is still a runtime behavior change for users running with 
`spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName=false` whose 
UDF/UDTF returned the wrong number of columns (and so was previously only 
partially validated).
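   A minimal pure-Python sketch of the truncation described above, assuming a column can be modeled as a `(name, arrow type name)` pair. `old_type_mismatches` and `new_type_mismatches` are hypothetical illustrations. The `ValueError` stands in for `RESULT_COLUMN_SCHEMA_MISMATCH`.

```python
from typing import List, Tuple

# Hypothetical simplification: a column is (name, arrow type name).
Column = Tuple[str, str]


def old_type_mismatches(expected: List[Column], actual: List[Column]) -> List[int]:
    # Models the old zip(expected, actual): iteration stops at the shorter
    # list, so a missing or extra trailing column is never compared.
    return [i for i, (e, a) in enumerate(zip(expected, actual)) if e[1] != a[1]]


def new_type_mismatches(expected: List[Column], actual: List[Column]) -> List[int]:
    # Models the new positional check: the column count must match first,
    # then types are compared position by position, ignoring names.
    if len(actual) != len(expected):
        raise ValueError(f"expected {len(expected)} columns, got {len(actual)}")
    return [i for i, (e, a) in enumerate(zip(expected, actual)) if e[1] != a[1]]


expected = [("id", "int64"), ("v", "double"), ("tag", "string")]
actual = [("x", "int64"), ("y", "double")]  # UDF dropped the last column

print(old_type_mismatches(expected, actual))  # [] (the mismatch slips through)
try:
    new_type_mismatches(expected, actual)
except ValueError as e:
    print(e)  # expected 3 columns, got 2
```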



-- 
This is an automated message from the Apache Git Service.