This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 68c0f64dddc9 [SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts
68c0f64dddc9 is described below

commit 68c0f64dddc917be7d489f67fab06fcbfe500f0d
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Tue Oct 24 13:12:11 2023 -0700

    [SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up of apache/spark#43356.
    
    Refactor the null-checking to have shortcuts.
    
    ### Why are the changes needed?
    
    The null-check can take shortcuts in some cases: when every output column is nullable and has no nested non-nullable array/struct/map fields, there is nothing to verify, so the per-row check can be skipped entirely.
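
    For illustration, here is a minimal sketch of the idea (not the code in this PR; `build_null_check` and the flat-schema handling are simplified assumptions, while the actual change in `python/pyspark/worker.py` also recurses into nested array/struct/map types): the checker is built once from the return schema, and when there is nothing to verify the builder returns `None`, so the per-row check is skipped.

    ```python
    from typing import Any, Callable, Optional

    from pyspark.sql.types import StructType


    def build_null_check(schema: StructType) -> Optional[Callable[[Any], None]]:
        # Illustrative only: handles flat schemas; the real checker also
        # recurses into array/struct/map element types.
        nullables = [field.nullable for field in schema.fields]
        if all(nullables):
            # Shortcut: every column may legally be None, so there is
            # nothing to check per row.
            return None

        def check(row: Any) -> None:
            for i, (value, nullable) in enumerate(zip(row, nullables)):
                if value is None and not nullable:
                    raise ValueError(f"Column {i} is non-nullable but got None")

        return check
    ```

    The caller builds the checker once per UDTF invocation and calls it per row only when it is not `None`.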
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    The existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43492 from ueshin/issues/SPARK-45523/nullcheck.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/worker.py | 211 +++++++++++++++++++++++++++++------------------
 1 file changed, 129 insertions(+), 82 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b1f59e1619fe..f6208032d9ac 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -24,7 +24,7 @@ import dataclasses
 import time
 from inspect import getfullargspec
 import json
-from typing import Any, Callable, Iterable, Iterator
+from typing import Any, Callable, Iterable, Iterator, Optional
 import faulthandler
 
 from pyspark.accumulators import _accumulatorRegistry
@@ -58,7 +58,6 @@ from pyspark.sql.types import (
     MapType,
     Row,
     StringType,
-    StructField,
     StructType,
     _create_row,
     _parse_datatype_json_string,
@@ -700,7 +699,7 @@ def read_udtf(pickleSer, infile, eval_type):
         )
 
     return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
-    if not type(return_type) == StructType:
+    if not isinstance(return_type, StructType):
         raise PySparkRuntimeError(
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )
@@ -845,70 +844,112 @@ def read_udtf(pickleSer, infile, eval_type):
             "the query again."
         )
 
-    # This determines which result columns have nullable types.
-    def check_nullable_column(i: int, data_type: DataType, nullable: bool) -> None:
-        if not nullable:
-            nullable_columns.add(i)
-        elif isinstance(data_type, ArrayType):
-            check_nullable_column(i, data_type.elementType, data_type.containsNull)
-        elif isinstance(data_type, StructType):
-            for subfield in data_type.fields:
-                check_nullable_column(i, subfield.dataType, subfield.nullable)
-        elif isinstance(data_type, MapType):
-            check_nullable_column(i, data_type.valueType, data_type.valueContainsNull)
-
-    nullable_columns: set[int] = set()
-    for i, field in enumerate(return_type.fields):
-        check_nullable_column(i, field.dataType, field.nullable)
-
-    # Compares each UDTF output row against the output schema for this particular UDTF call,
-    # raising an error if the two are incompatible.
-    def check_output_row_against_schema(row: Any, expected_schema: StructType) -> None:
-        for result_column_index in nullable_columns:
-
-            def check_for_none_in_non_nullable_column(
-                value: Any, data_type: DataType, nullable: bool
-            ) -> None:
-                if value is None and not nullable:
-                    raise PySparkRuntimeError(
-                        error_class="UDTF_EXEC_ERROR",
-                        message_parameters={
-                            "method_name": "eval' or 'terminate",
-                            "error": f"Column {result_column_index} within a returned row had a "
-                            + "value of None, either directly or within array/struct/map "
-                            + "subfields, but the corresponding column type was declared as "
-                            + "non-nullable; please update the UDTF to return a non-None value at "
-                            + "this location or otherwise declare the column type as nullable.",
-                        },
-                    )
-                elif (
-                    isinstance(data_type, ArrayType)
-                    and isinstance(value, list)
-                    and not data_type.containsNull
-                ):
-                    for sub_value in value:
-                        check_for_none_in_non_nullable_column(
-                            sub_value, data_type.elementType, data_type.containsNull
-                        )
-                elif isinstance(data_type, StructType) and isinstance(value, Row):
-                    for i in range(len(value)):
-                        check_for_none_in_non_nullable_column(
-                            value[i], data_type[i].dataType, data_type[i].nullable
-                        )
-                elif isinstance(data_type, MapType) and isinstance(value, dict):
-                    for map_key, map_value in value.items():
-                        check_for_none_in_non_nullable_column(
-                            map_key, data_type.keyType, nullable=False
-                        )
-                        check_for_none_in_non_nullable_column(
-                            map_value, data_type.valueType, data_type.valueContainsNull
-                        )
+    def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
+        def raise_(result_column_index):
+            raise PySparkRuntimeError(
+                error_class="UDTF_EXEC_ERROR",
+                message_parameters={
+                    "method_name": "eval' or 'terminate",
+                    "error": f"Column {result_column_index} within a returned row had a "
+                    + "value of None, either directly or within array/struct/map "
+                    + "subfields, but the corresponding column type was declared as "
+                    + "non-nullable; please update the UDTF to return a non-None value at "
+                    + "this location or otherwise declare the column type as nullable.",
+                },
+            )
 
-            field: StructField = expected_schema[result_column_index]
-            if row is not None:
-                check_for_none_in_non_nullable_column(
-                    list(row)[result_column_index], field.dataType, field.nullable
-                )
+        def checker(data_type: DataType, result_column_index: int):
+            if isinstance(data_type, ArrayType):
+                element_checker = checker(data_type.elementType, result_column_index)
+                contains_null = data_type.containsNull
+
+                if element_checker is None and contains_null:
+                    return None
+
+                def check_array(arr):
+                    if isinstance(arr, list):
+                        for e in arr:
+                            if e is None:
+                                if not contains_null:
+                                    raise_(result_column_index)
+                            elif element_checker is not None:
+                                element_checker(e)
+
+                return check_array
+
+            elif isinstance(data_type, MapType):
+                key_checker = checker(data_type.keyType, result_column_index)
+                value_checker = checker(data_type.valueType, result_column_index)
+                value_contains_null = data_type.valueContainsNull
+
+                if value_checker is None and value_contains_null:
+
+                    def check_map(map):
+                        if isinstance(map, dict):
+                            for k, v in map.items():
+                                if k is None:
+                                    raise_(result_column_index)
+                                elif key_checker is not None:
+                                    key_checker(k)
+
+                else:
+
+                    def check_map(map):
+                        if isinstance(map, dict):
+                            for k, v in map.items():
+                                if k is None:
+                                    raise_(result_column_index)
+                                elif key_checker is not None:
+                                    key_checker(k)
+                                if v is None:
+                                    if not value_contains_null:
+                                        raise_(result_column_index)
+                                elif value_checker is not None:
+                                    value_checker(v)
+
+                return check_map
+
+            elif isinstance(data_type, StructType):
+                field_checkers = [checker(f.dataType, result_column_index) for f in data_type]
+                nullables = [f.nullable for f in data_type]
+
+                if all(c is None for c in field_checkers) and all(nullables):
+                    return None
+
+                def check_struct(struct):
+                    if isinstance(struct, tuple):
+                        for value, checker, nullable in zip(struct, field_checkers, nullables):
+                            if value is None:
+                                if not nullable:
+                                    raise_(result_column_index)
+                            elif checker is not None:
+                                checker(value)
+
+                return check_struct
+
+            else:
+                return None
+
+        field_checkers = [
+            checker(f.dataType, result_column_index=i) for i, f in enumerate(return_type)
+        ]
+        nullables = [f.nullable for f in return_type]
+
+        if all(c is None for c in field_checkers) and all(nullables):
+            return None
+
+        def check(row):
+            if isinstance(row, tuple):
+                for i, (value, checker, nullable) in enumerate(zip(row, field_checkers, nullables)):
+                    if value is None:
+                        if not nullable:
+                            raise_(i)
+                    elif checker is not None:
+                        checker(value)
+
+        return check
+
+    check_output_row_against_schema = build_null_checker(return_type)
 
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
 
@@ -948,8 +989,6 @@ def read_udtf(pickleSer, infile, eval_type):
                 verify_pandas_result(
                     result, return_type, assign_cols_by_name=False, truncate_return_schema=False
                 )
-                for result_tuple in result.itertuples():
-                    check_output_row_against_schema(list(result_tuple), return_type)
                 return result
 
             # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
@@ -965,28 +1004,36 @@ def read_udtf(pickleSer, infile, eval_type):
             def check_return_value(res):
                 # Check whether the result of an arrow UDTF is iterable before
                 # using it to construct a pandas DataFrame.
-                if res is not None and not isinstance(res, Iterable):
-                    raise PySparkRuntimeError(
-                        error_class="UDTF_RETURN_NOT_ITERABLE",
-                        message_parameters={
-                            "type": type(res).__name__,
-                            "func": f.__name__,
-                        },
-                    )
+                if res is not None:
+                    if not isinstance(res, Iterable):
+                        raise PySparkRuntimeError(
+                            error_class="UDTF_RETURN_NOT_ITERABLE",
+                            message_parameters={
+                                "type": type(res).__name__,
+                                "func": f.__name__,
+                            },
+                        )
+                    if check_output_row_against_schema is not None:
+                        for row in res:
+                            if row is not None:
+                                check_output_row_against_schema(row)
+                            yield row
+                    else:
+                        yield from res
 
             def evaluate(*args: pd.Series):
                 if len(args) == 0:
                     res = func()
-                    check_return_value(res)
-                    yield verify_result(pd.DataFrame(res)), arrow_return_type
+                    yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
                 else:
                     # Create tuples from the input pandas Series, each tuple
                     # represents a row across all Series.
                     row_tuples = zip(*args)
                     for row in row_tuples:
                         res = func(*row)
-                        check_return_value(res)
-                        yield verify_result(pd.DataFrame(res)), arrow_return_type
+                        yield verify_result(
+                            pd.DataFrame(check_return_value(res))
+                        ), arrow_return_type
 
             return evaluate
 
@@ -1043,8 +1090,8 @@ def read_udtf(pickleSer, infile, eval_type):
                                 "func": f.__name__,
                             },
                         )
-
-                check_output_row_against_schema(result, return_type)
+                    if check_output_row_against_schema is not None:
+                        check_output_row_against_schema(result)
                 return toInternal(result)
 
             # Evaluate the function and return a tuple back to the executor.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
