This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 68c0f64dddc9 [SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts
68c0f64dddc9 is described below
commit 68c0f64dddc917be7d489f67fab06fcbfe500f0d
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue Oct 24 13:12:11 2023 -0700
[SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts
### What changes were proposed in this pull request?
This is a follow-up of apache/spark#43356.
Refactor the null-checking to have shortcuts.
### Why are the changes needed?
The null check can take shortcuts in some cases: when every result column is nullable and has no non-nullable nested fields, no per-row check is needed, so the checker builder can return `None` and callers skip the check entirely, as sketched below.
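For illustration only, a minimal sketch of the shortcut idea (the `build_null_checker_sketch` helper and its simplified logic are hypothetical, not the exact `worker.py` code): the builder walks the schema once and returns `None` when no column or nested field is non-nullable, so the hot per-row path can skip validation.
```python
# Hypothetical sketch of the shortcut: return None when no per-row null check is needed.
from typing import Any, Callable, Optional

from pyspark.sql.types import ArrayType, DataType, MapType, StructType


def build_null_checker_sketch(schema: StructType) -> Optional[Callable[[Any], None]]:
    def needs_check(dt: DataType, nullable: bool) -> bool:
        # A check is needed if this slot is non-nullable, or if any nested
        # array element / map value / struct field is non-nullable.
        if not nullable:
            return True
        if isinstance(dt, ArrayType):
            return needs_check(dt.elementType, dt.containsNull)
        if isinstance(dt, MapType):
            # Map keys may never be None, so maps always need a check.
            return True
        if isinstance(dt, StructType):
            return any(needs_check(f.dataType, f.nullable) for f in dt.fields)
        return False

    if not any(needs_check(f.dataType, f.nullable) for f in schema.fields):
        return None  # shortcut: every column is fully nullable, nothing to validate

    def check(row: Any) -> None:
        # Simplified: only checks top-level columns; the real code also
        # recurses into array/struct/map values.
        for value, field in zip(row, schema.fields):
            if value is None and not field.nullable:
                raise ValueError(f"None in non-nullable column {field.name!r}")

    return check
```
A caller would then do `checker = build_null_checker_sketch(return_type)` once and invoke it per row only when `checker is not None`, which mirrors how `check_output_row_against_schema` is guarded in the patch below.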
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
The existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43492 from ueshin/issues/SPARK-45523/nullcheck.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/pyspark/worker.py | 211 +++++++++++++++++++++++++++++------------------
1 file changed, 129 insertions(+), 82 deletions(-)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b1f59e1619fe..f6208032d9ac 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -24,7 +24,7 @@ import dataclasses
import time
from inspect import getfullargspec
import json
-from typing import Any, Callable, Iterable, Iterator
+from typing import Any, Callable, Iterable, Iterator, Optional
import faulthandler
from pyspark.accumulators import _accumulatorRegistry
@@ -58,7 +58,6 @@ from pyspark.sql.types import (
MapType,
Row,
StringType,
- StructField,
StructType,
_create_row,
_parse_datatype_json_string,
@@ -700,7 +699,7 @@ def read_udtf(pickleSer, infile, eval_type):
)
return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
- if not type(return_type) == StructType:
+ if not isinstance(return_type, StructType):
raise PySparkRuntimeError(
f"The return type of a UDTF must be a struct type, but got
{type(return_type)}."
)
@@ -845,70 +844,112 @@ def read_udtf(pickleSer, infile, eval_type):
"the query again."
)
- # This determines which result columns have nullable types.
- def check_nullable_column(i: int, data_type: DataType, nullable: bool) -> None:
- if not nullable:
- nullable_columns.add(i)
- elif isinstance(data_type, ArrayType):
- check_nullable_column(i, data_type.elementType, data_type.containsNull)
- elif isinstance(data_type, StructType):
- for subfield in data_type.fields:
- check_nullable_column(i, subfield.dataType, subfield.nullable)
- elif isinstance(data_type, MapType):
- check_nullable_column(i, data_type.valueType, data_type.valueContainsNull)
-
- nullable_columns: set[int] = set()
- for i, field in enumerate(return_type.fields):
- check_nullable_column(i, field.dataType, field.nullable)
-
- # Compares each UDTF output row against the output schema for this particular UDTF call,
- # raising an error if the two are incompatible.
- def check_output_row_against_schema(row: Any, expected_schema: StructType) -> None:
- for result_column_index in nullable_columns:
-
- def check_for_none_in_non_nullable_column(
- value: Any, data_type: DataType, nullable: bool
- ) -> None:
- if value is None and not nullable:
- raise PySparkRuntimeError(
- error_class="UDTF_EXEC_ERROR",
- message_parameters={
- "method_name": "eval' or 'terminate",
- "error": f"Column {result_column_index} within a
returned row had a "
- + "value of None, either directly or within
array/struct/map "
- + "subfields, but the corresponding column type
was declared as "
- + "non-nullable; please update the UDTF to return
a non-None value at "
- + "this location or otherwise declare the column
type as nullable.",
- },
- )
- elif (
- isinstance(data_type, ArrayType)
- and isinstance(value, list)
- and not data_type.containsNull
- ):
- for sub_value in value:
- check_for_none_in_non_nullable_column(
- sub_value, data_type.elementType, data_type.containsNull
- )
- elif isinstance(data_type, StructType) and isinstance(value, Row):
- for i in range(len(value)):
- check_for_none_in_non_nullable_column(
- value[i], data_type[i].dataType, data_type[i].nullable
- )
- elif isinstance(data_type, MapType) and isinstance(value, dict):
- for map_key, map_value in value.items():
- check_for_none_in_non_nullable_column(
- map_key, data_type.keyType, nullable=False
- )
- check_for_none_in_non_nullable_column(
- map_value, data_type.valueType, data_type.valueContainsNull
- )
+ def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
+ def raise_(result_column_index):
+ raise PySparkRuntimeError(
+ error_class="UDTF_EXEC_ERROR",
+ message_parameters={
+ "method_name": "eval' or 'terminate",
+ "error": f"Column {result_column_index} within a returned
row had a "
+ + "value of None, either directly or within
array/struct/map "
+ + "subfields, but the corresponding column type was
declared as "
+ + "non-nullable; please update the UDTF to return a
non-None value at "
+ + "this location or otherwise declare the column type as
nullable.",
+ },
+ )
- field: StructField = expected_schema[result_column_index]
- if row is not None:
- check_for_none_in_non_nullable_column(
- list(row)[result_column_index], field.dataType, field.nullable
- )
+ def checker(data_type: DataType, result_column_index: int):
+ if isinstance(data_type, ArrayType):
+ element_checker = checker(data_type.elementType, result_column_index)
+ contains_null = data_type.containsNull
+
+ if element_checker is None and contains_null:
+ return None
+
+ def check_array(arr):
+ if isinstance(arr, list):
+ for e in arr:
+ if e is None:
+ if not contains_null:
+ raise_(result_column_index)
+ elif element_checker is not None:
+ element_checker(e)
+
+ return check_array
+
+ elif isinstance(data_type, MapType):
+ key_checker = checker(data_type.keyType, result_column_index)
+ value_checker = checker(data_type.valueType, result_column_index)
+ value_contains_null = data_type.valueContainsNull
+
+ if value_checker is None and value_contains_null:
+
+ def check_map(map):
+ if isinstance(map, dict):
+ for k, v in map.items():
+ if k is None:
+ raise_(result_column_index)
+ elif key_checker is not None:
+ key_checker(k)
+
+ else:
+
+ def check_map(map):
+ if isinstance(map, dict):
+ for k, v in map.items():
+ if k is None:
+ raise_(result_column_index)
+ elif key_checker is not None:
+ key_checker(k)
+ if v is None:
+ if not value_contains_null:
+ raise_(result_column_index)
+ elif value_checker is not None:
+ value_checker(v)
+
+ return check_map
+
+ elif isinstance(data_type, StructType):
+ field_checkers = [checker(f.dataType, result_column_index) for f in data_type]
+ nullables = [f.nullable for f in data_type]
+
+ if all(c is None for c in field_checkers) and all(nullables):
+ return None
+
+ def check_struct(struct):
+ if isinstance(struct, tuple):
+ for value, checker, nullable in zip(struct, field_checkers, nullables):
+ if value is None:
+ if not nullable:
+ raise_(result_column_index)
+ elif checker is not None:
+ checker(value)
+
+ return check_struct
+
+ else:
+ return None
+
+ field_checkers = [
+ checker(f.dataType, result_column_index=i) for i, f in enumerate(return_type)
+ ]
+ nullables = [f.nullable for f in return_type]
+
+ if all(c is None for c in field_checkers) and all(nullables):
+ return None
+
+ def check(row):
+ if isinstance(row, tuple):
+ for i, (value, checker, nullable) in enumerate(zip(row, field_checkers, nullables)):
+ if value is None:
+ if not nullable:
+ raise_(i)
+ elif checker is not None:
+ checker(value)
+
+ return check
+
+ check_output_row_against_schema = build_null_checker(return_type)
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
@@ -948,8 +989,6 @@ def read_udtf(pickleSer, infile, eval_type):
verify_pandas_result(
result, return_type, assign_cols_by_name=False, truncate_return_schema=False
)
- for result_tuple in result.itertuples():
- check_output_row_against_schema(list(result_tuple), return_type)
return result
# Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
@@ -965,28 +1004,36 @@ def read_udtf(pickleSer, infile, eval_type):
def check_return_value(res):
# Check whether the result of an arrow UDTF is iterable before
# using it to construct a pandas DataFrame.
- if res is not None and not isinstance(res, Iterable):
- raise PySparkRuntimeError(
- error_class="UDTF_RETURN_NOT_ITERABLE",
- message_parameters={
- "type": type(res).__name__,
- "func": f.__name__,
- },
- )
+ if res is not None:
+ if not isinstance(res, Iterable):
+ raise PySparkRuntimeError(
+ error_class="UDTF_RETURN_NOT_ITERABLE",
+ message_parameters={
+ "type": type(res).__name__,
+ "func": f.__name__,
+ },
+ )
+ if check_output_row_against_schema is not None:
+ for row in res:
+ if row is not None:
+ check_output_row_against_schema(row)
+ yield row
+ else:
+ yield from res
def evaluate(*args: pd.Series):
if len(args) == 0:
res = func()
- check_return_value(res)
- yield verify_result(pd.DataFrame(res)), arrow_return_type
+ yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
else:
# Create tuples from the input pandas Series, each tuple
# represents a row across all Series.
row_tuples = zip(*args)
for row in row_tuples:
res = func(*row)
- check_return_value(res)
- yield verify_result(pd.DataFrame(res)), arrow_return_type
+ yield verify_result(
+ pd.DataFrame(check_return_value(res))
+ ), arrow_return_type
return evaluate
@@ -1043,8 +1090,8 @@ def read_udtf(pickleSer, infile, eval_type):
"func": f.__name__,
},
)
-
- check_output_row_against_schema(result, return_type)
+ if check_output_row_against_schema is not None:
+ check_output_row_against_schema(result)
return toInternal(result)
# Evaluate the function and return a tuple back to the executor.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]