This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 68c0f64dddc9 [SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts
68c0f64dddc9 is described below

commit 68c0f64dddc917be7d489f67fab06fcbfe500f0d
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Tue Oct 24 13:12:11 2023 -0700

    [SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts

    ### What changes were proposed in this pull request?

    This is a follow-up of apache/spark#43356.

    Refactor the null-checking to have shortcuts.

    ### Why are the changes needed?

    The null-check can have shortcuts for some cases.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    The existing tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43492 from ueshin/issues/SPARK-45523/nullcheck.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/worker.py | 211 +++++++++++++++++++++++++++++------------------
 1 file changed, 129 insertions(+), 82 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b1f59e1619fe..f6208032d9ac 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -24,7 +24,7 @@ import dataclasses
 import time
 from inspect import getfullargspec
 import json
-from typing import Any, Callable, Iterable, Iterator
+from typing import Any, Callable, Iterable, Iterator, Optional
 import faulthandler
 
 from pyspark.accumulators import _accumulatorRegistry
@@ -58,7 +58,6 @@ from pyspark.sql.types import (
     MapType,
     Row,
     StringType,
-    StructField,
     StructType,
     _create_row,
     _parse_datatype_json_string,
@@ -700,7 +699,7 @@ def read_udtf(pickleSer, infile, eval_type):
         )
 
     return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
-    if not type(return_type) == StructType:
+    if not isinstance(return_type, StructType):
         raise PySparkRuntimeError(
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )
@@ -845,70 +844,112 @@ def read_udtf(pickleSer, infile, eval_type):
                 "the query again."
             )
 
-    # This determines which result columns have nullable types.
-    def check_nullable_column(i: int, data_type: DataType, nullable: bool) -> None:
-        if not nullable:
-            nullable_columns.add(i)
-        elif isinstance(data_type, ArrayType):
-            check_nullable_column(i, data_type.elementType, data_type.containsNull)
-        elif isinstance(data_type, StructType):
-            for subfield in data_type.fields:
-                check_nullable_column(i, subfield.dataType, subfield.nullable)
-        elif isinstance(data_type, MapType):
-            check_nullable_column(i, data_type.valueType, data_type.valueContainsNull)
-
-    nullable_columns: set[int] = set()
-    for i, field in enumerate(return_type.fields):
-        check_nullable_column(i, field.dataType, field.nullable)
-
-    # Compares each UDTF output row against the output schema for this particular UDTF call,
-    # raising an error if the two are incompatible.
-    def check_output_row_against_schema(row: Any, expected_schema: StructType) -> None:
-        for result_column_index in nullable_columns:
-
-            def check_for_none_in_non_nullable_column(
-                value: Any, data_type: DataType, nullable: bool
-            ) -> None:
-                if value is None and not nullable:
-                    raise PySparkRuntimeError(
-                        error_class="UDTF_EXEC_ERROR",
-                        message_parameters={
-                            "method_name": "eval' or 'terminate",
-                            "error": f"Column {result_column_index} within a returned row had a "
-                            + "value of None, either directly or within array/struct/map "
-                            + "subfields, but the corresponding column type was declared as "
-                            + "non-nullable; please update the UDTF to return a non-None value at "
-                            + "this location or otherwise declare the column type as nullable.",
-                        },
-                    )
-                elif (
-                    isinstance(data_type, ArrayType)
-                    and isinstance(value, list)
-                    and not data_type.containsNull
-                ):
-                    for sub_value in value:
-                        check_for_none_in_non_nullable_column(
-                            sub_value, data_type.elementType, data_type.containsNull
-                        )
-                elif isinstance(data_type, StructType) and isinstance(value, Row):
-                    for i in range(len(value)):
-                        check_for_none_in_non_nullable_column(
-                            value[i], data_type[i].dataType, data_type[i].nullable
-                        )
-                elif isinstance(data_type, MapType) and isinstance(value, dict):
-                    for map_key, map_value in value.items():
-                        check_for_none_in_non_nullable_column(
-                            map_key, data_type.keyType, nullable=False
-                        )
-                        check_for_none_in_non_nullable_column(
-                            map_value, data_type.valueType, data_type.valueContainsNull
-                        )
+    def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
+        def raise_(result_column_index):
+            raise PySparkRuntimeError(
+                error_class="UDTF_EXEC_ERROR",
+                message_parameters={
+                    "method_name": "eval' or 'terminate",
+                    "error": f"Column {result_column_index} within a returned row had a "
+                    + "value of None, either directly or within array/struct/map "
+                    + "subfields, but the corresponding column type was declared as "
+                    + "non-nullable; please update the UDTF to return a non-None value at "
+                    + "this location or otherwise declare the column type as nullable.",
+                },
+            )
 
-            field: StructField = expected_schema[result_column_index]
-            if row is not None:
-                check_for_none_in_non_nullable_column(
-                    list(row)[result_column_index], field.dataType, field.nullable
-                )
+        def checker(data_type: DataType, result_column_index: int):
+            if isinstance(data_type, ArrayType):
+                element_checker = checker(data_type.elementType, result_column_index)
+                contains_null = data_type.containsNull
+
+                if element_checker is None and contains_null:
+                    return None
+
+                def check_array(arr):
+                    if isinstance(arr, list):
+                        for e in arr:
+                            if e is None:
+                                if not contains_null:
+                                    raise_(result_column_index)
+                            elif element_checker is not None:
+                                element_checker(e)
+
+                return check_array
+
+            elif isinstance(data_type, MapType):
+                key_checker = checker(data_type.keyType, result_column_index)
+                value_checker = checker(data_type.valueType, result_column_index)
+                value_contains_null = data_type.valueContainsNull
+
+                if value_checker is None and value_contains_null:
+
+                    def check_map(map):
+                        if isinstance(map, dict):
+                            for k, v in map.items():
+                                if k is None:
+                                    raise_(result_column_index)
+                                elif key_checker is not None:
+                                    key_checker(k)
+
+                else:
+
+                    def check_map(map):
+                        if isinstance(map, dict):
+                            for k, v in map.items():
+                                if k is None:
+                                    raise_(result_column_index)
+                                elif key_checker is not None:
+                                    key_checker(k)
+                                if v is None:
+                                    if not value_contains_null:
+                                        raise_(result_column_index)
+                                elif value_checker is not None:
+                                    value_checker(v)
+
+                return check_map
+
+            elif isinstance(data_type, StructType):
+                field_checkers = [checker(f.dataType, result_column_index) for f in data_type]
+                nullables = [f.nullable for f in data_type]
+
+                if all(c is None for c in field_checkers) and all(nullables):
+                    return None
+
+                def check_struct(struct):
+                    if isinstance(struct, tuple):
+                        for value, checker, nullable in zip(struct, field_checkers, nullables):
+                            if value is None:
+                                if not nullable:
+                                    raise_(result_column_index)
+                            elif checker is not None:
+                                checker(value)
+
+                return check_struct
+
+            else:
+                return None
+
+        field_checkers = [
+            checker(f.dataType, result_column_index=i) for i, f in enumerate(return_type)
+        ]
+        nullables = [f.nullable for f in return_type]
+
+        if all(c is None for c in field_checkers) and all(nullables):
+            return None
+
+        def check(row):
+            if isinstance(row, tuple):
+                for i, (value, checker, nullable) in enumerate(
+                    zip(row, field_checkers, nullables)
+                ):
+                    if value is None:
+                        if not nullable:
+                            raise_(i)
+                    elif checker is not None:
+                        checker(value)
+
+        return check
+
+    check_output_row_against_schema = build_null_checker(return_type)
 
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
@@ -948,8 +989,6 @@ def read_udtf(pickleSer, infile, eval_type):
             verify_pandas_result(
                 result, return_type, assign_cols_by_name=False, truncate_return_schema=False
             )
-            for result_tuple in result.itertuples():
-                check_output_row_against_schema(list(result_tuple), return_type)
             return result
 
         # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
@@ -965,28 +1004,36 @@ def read_udtf(pickleSer, infile, eval_type):
         def check_return_value(res):
             # Check whether the result of an arrow UDTF is iterable before
             # using it to construct a pandas DataFrame.
-            if res is not None and not isinstance(res, Iterable):
-                raise PySparkRuntimeError(
-                    error_class="UDTF_RETURN_NOT_ITERABLE",
-                    message_parameters={
-                        "type": type(res).__name__,
-                        "func": f.__name__,
-                    },
-                )
+            if res is not None:
+                if not isinstance(res, Iterable):
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_RETURN_NOT_ITERABLE",
+                        message_parameters={
+                            "type": type(res).__name__,
+                            "func": f.__name__,
+                        },
+                    )
+                if check_output_row_against_schema is not None:
+                    for row in res:
+                        if row is not None:
+                            check_output_row_against_schema(row)
+                        yield row
+                else:
+                    yield from res
 
         def evaluate(*args: pd.Series):
             if len(args) == 0:
                 res = func()
-                check_return_value(res)
-                yield verify_result(pd.DataFrame(res)), arrow_return_type
+                yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
             else:
                 # Create tuples from the input pandas Series, each tuple
                 # represents a row across all Series.
                 row_tuples = zip(*args)
                 for row in row_tuples:
                     res = func(*row)
-                    check_return_value(res)
-                    yield verify_result(pd.DataFrame(res)), arrow_return_type
+                    yield verify_result(
+                        pd.DataFrame(check_return_value(res))
+                    ), arrow_return_type
 
         return evaluate
 
@@ -1043,8 +1090,8 @@ def read_udtf(pickleSer, infile, eval_type):
                         "func": f.__name__,
                     },
                 )
-
-            check_output_row_against_schema(result, return_type)
+            if check_output_row_against_schema is not None:
+                check_output_row_against_schema(result)
             return toInternal(result)
 
         # Evaluate the function and return a tuple back to the executor.

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
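
The shortcut idea in this patch, distilled: build_null_checker walks the declared result type once and compiles either None (when every level is nullable, so the per-row null-check can be skipped entirely) or a small closure to run against each value. Below is a minimal standalone sketch of that pattern, not the worker.py code itself; the build_checker name and the plain ValueError are illustrative stand-ins (the real patch raises PySparkRuntimeError with error class UDTF_EXEC_ERROR and also handles StructType and MapType):

    from typing import Any, Callable, Optional

    from pyspark.sql.types import ArrayType, DataType, IntegerType


    def build_checker(data_type: DataType, nullable: bool) -> Optional[Callable[[Any], None]]:
        # One-time walk over the schema: decide up front whether any per-row work is needed.
        if isinstance(data_type, ArrayType):
            element_checker = build_checker(data_type.elementType, data_type.containsNull)
            if nullable and element_checker is None:
                return None  # shortcut: no level below this one can violate nullability

            def check_array(value: Any) -> None:
                if value is None:
                    if not nullable:
                        raise ValueError("None in a non-nullable array column")
                elif element_checker is not None:
                    for element in value:
                        element_checker(element)

            return check_array

        if not nullable:

            def check_leaf(value: Any) -> None:
                if value is None:
                    raise ValueError("None in a non-nullable column")

            return check_leaf

        return None  # shortcut: a nullable leaf never needs a per-row check


    # A fully nullable column compiles to no checker at all, so rows are never scanned;
    # a non-nullable element type compiles to a real closure.
    assert build_checker(ArrayType(IntegerType(), containsNull=True), nullable=True) is None
    checker = build_checker(ArrayType(IntegerType(), containsNull=False), nullable=True)
    assert checker is not None
    checker([1, 2, 3])  # passes; checker([1, None]) would raise ValueError

This is why the hot paths in the diff only test `check_output_row_against_schema is not None` per row instead of re-walking the schema for every returned row, as the previous check_output_row_against_schema did.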