This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new a4dcda1  [SPARK-36350][PYTHON] Move some logic related to F.nanvl to DataTypeOps
a4dcda1 is described below

commit a4dcda179448cc7632f5022e26be3cafb3d82505
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Fri Jul 30 11:19:49 2021 -0700

    [SPARK-36350][PYTHON] Move some logic related to F.nanvl to DataTypeOps
    
    ### What changes were proposed in this pull request?
    
    Move some logic related to `F.nanvl` to `DataTypeOps`.
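
    For context, the logic being moved exists because Spark's `count` treats NaN as a valid value while pandas excludes it, so NaN is first mapped to null via `F.nanvl`. A minimal, illustrative sketch of that gap (the toy data and session setup below are not part of this patch):

    ```python
    # Illustrative only -- shows the NaN-vs-null counting gap this logic papers over.
    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame([(1.0,), (float("nan"),), (None,)], "x double")

    # Spark's count() skips null but includes NaN.
    sdf.select(F.count("x")).show()  # 2

    # Mapping NaN to null first makes count() agree with pandas, which excludes NaN.
    sdf.select(F.count(F.nanvl("x", F.lit(None)))).show()  # 1
    ```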
    
    ### Why are the changes needed?
    
    There are several places that branch by `FloatType` or `DoubleType` to use `F.nanvl`, but `DataTypeOps` should handle it.
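
    As a rough before/after sketch (simplified; the helper names `nan_to_null_before`/`nan_to_null_after` are illustrative, not from the patch), the duplicated branching collapses into a single `DataTypeOps` call:

    ```python
    # Illustrative sketch, not the exact patch code.
    import pyspark.pandas as ps
    from pyspark.sql import Column, functions as F
    from pyspark.sql.types import DoubleType, FloatType


    def nan_to_null_before(psser: ps.Series) -> Column:
        # Old pattern, repeated at several call sites.
        if isinstance(psser.spark.data_type, (FloatType, DoubleType)):
            return F.nanvl(psser.spark.column, F.lit(None))
        return psser.spark.column


    def nan_to_null_after(psser: ps.Series) -> Column:
        # New pattern: the type-specific handling lives in DataTypeOps.nan_to_null.
        return psser._dtype_op.nan_to_null(psser).spark.column


    psser = ps.Series([1.0, float("nan"), None])
    nan_to_null_before(psser)  # Column with NaN mapped to null
    nan_to_null_after(psser)   # same result, but the branch is owned by DataTypeOps
    ```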
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #33582 from ueshin/issues/SPARK-36350/nan_to_null.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
    (cherry picked from commit 895e3f5e2aff46f5c4eed8ac9ddf4dbfb16ef5fd)
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/pandas/data_type_ops/base.py    |  3 ++
 python/pyspark/pandas/data_type_ops/num_ops.py | 11 +++++
 python/pyspark/pandas/frame.py                 | 50 +++++++------------
 python/pyspark/pandas/generic.py               | 66 ++++++++++++++++----------
 python/pyspark/pandas/groupby.py               | 66 +++++++++++++-------------
 python/pyspark/pandas/series.py                | 22 +++------
 6 files changed, 112 insertions(+), 106 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 7eb2a95..743b2c5 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -366,5 +366,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
             ),
         )
 
+    def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
+        return index_ops.copy()
+
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         raise TypeError("astype can not be applied to %s." % self.pretty_name)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index a7987bc..f84c1af 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -326,6 +326,14 @@ class FractionalOps(NumericOps):
             ),
         )
 
+    def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
+        # Special handle floating point types because Spark's count treats nan as a valid value,
+        # whereas pandas count doesn't include nan.
+        return index_ops._with_new_scol(
+            F.nanvl(index_ops.spark.column, SF.lit(None)),
+            field=index_ops._internal.data_fields[0].copy(nullable=True),
+        )
+
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
 
@@ -385,6 +393,9 @@ class DecimalOps(FractionalOps):
             ),
         )
 
+    def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
+        return index_ops.copy()
+
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         # TODO(SPARK-36230): check index_ops.hasnans after fixing SPARK-36230
         dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index de675f1f..4a737b6 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -642,7 +642,7 @@ class DataFrame(Frame, Generic[T]):
 
     def _reduce_for_stat_function(
         self,
-        sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
+        sfun: Callable[["Series"], Column],
         name: str,
         axis: Optional[Axis] = None,
         numeric_only: bool = True,
@@ -664,7 +664,6 @@ class DataFrame(Frame, Generic[T]):
             is mainly for pandas compatibility. Only 'DataFrame.count' uses this parameter
             currently.
         """
-        from inspect import signature
         from pyspark.pandas.series import Series, first_series
 
         axis = validate_axis(axis)
@@ -673,29 +672,19 @@ class DataFrame(Frame, Generic[T]):
 
             exprs = [SF.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)]
             new_column_labels = []
-            num_args = len(signature(sfun).parameters)
             for label in self._internal.column_labels:
-                spark_column = self._internal.spark_column_for(label)
-                spark_type = self._internal.spark_type_for(label)
+                psser = self._psser_for(label)
 
-                is_numeric_or_boolean = isinstance(spark_type, (NumericType, BooleanType))
+                is_numeric_or_boolean = isinstance(
+                    psser.spark.data_type, (NumericType, BooleanType)
+                )
                 keep_column = not numeric_only or is_numeric_or_boolean
 
                 if keep_column:
-                    if num_args == 1:
-                        # Only pass in the column if sfun accepts only one arg
-                        scol = cast(Callable[[Column], Column], sfun)(spark_column)
-                    else:  # must be 2
-                        assert num_args == 2
-                        # Pass in both the column and its data type if sfun accepts two args
-                        scol = cast(Callable[[Column, DataType], Column], sfun)(
-                            spark_column, spark_type
-                        )
+                    scol = sfun(psser)
 
                     if min_count > 0:
-                        scol = F.when(
-                            Frame._count_expr(spark_column, spark_type) >= min_count, scol
-                        )
+                        scol = F.when(Frame._count_expr(psser) >= min_count, scol)
 
                     exprs.append(scol.alias(name_like_string(label)))
                     new_column_labels.append(label)
@@ -8485,16 +8474,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         exprs = []
         column_labels = []
         for label in self._internal.column_labels:
-            scol = self._internal.spark_column_for(label)
-            spark_type = self._internal.spark_type_for(label)
-            # TODO(SPARK-36350): Make this work with DataTypeOps.
-            if isinstance(spark_type, (FloatType, DoubleType)):
-                exprs.append(
-                    F.nanvl(scol, SF.lit(None)).alias(self._internal.spark_column_name_for(label))
-                )
-                column_labels.append(label)
-            elif isinstance(spark_type, NumericType):
-                exprs.append(scol)
+            psser = self._psser_for(label)
+            if isinstance(psser.spark.data_type, NumericType):
+                exprs.append(psser._dtype_op.nan_to_null(psser).spark.column)
                 column_labels.append(label)
 
         if len(exprs) == 0:
@@ -10813,7 +10795,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             if v < 0.0 or v > 1.0:
                 raise ValueError("percentiles should all be in the interval [0, 1].")
 
-        def quantile(spark_column: Column, spark_type: DataType) -> Column:
+        def quantile(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, (BooleanType, NumericType)):
                 return F.percentile_approx(spark_column.cast(DoubleType()), qq, accuracy)
             else:
@@ -10839,13 +10823,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             for label, column in zip(
                 self._internal.column_labels, self._internal.data_spark_column_names
             ):
-                spark_type = self._internal.spark_type_for(label)
+                psser = self._psser_for(label)
 
-                is_numeric_or_boolean = isinstance(spark_type, (NumericType, BooleanType))
+                is_numeric_or_boolean = isinstance(
+                    psser.spark.data_type, (NumericType, BooleanType)
+                )
                 keep_column = not numeric_only or is_numeric_or_boolean
 
                 if keep_column:
-                    percentile_col = quantile(self._internal.spark_column_for(label), spark_type)
+                    percentile_col = quantile(psser)
                     percentile_cols.append(percentile_col.alias(column))
                     percentile_col_names.append(column)
                     column_labels.append(label)
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 7ec9b05..cdd8f67 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -44,9 +44,7 @@ from pandas.api.types import is_list_like
 from pyspark.sql import Column, functions as F
 from pyspark.sql.types import (
     BooleanType,
-    DataType,
     DoubleType,
-    FloatType,
     IntegralType,
     LongType,
     NumericType,
@@ -114,7 +112,7 @@ class Frame(object, metaclass=ABCMeta):
     @abstractmethod
     def _reduce_for_stat_function(
         self,
-        sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
+        sfun: Callable[["Series"], Column],
         name: str,
         axis: Optional[Axis] = None,
         numeric_only: bool = True,
@@ -1204,7 +1202,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def mean(spark_column: Column, spark_type: DataType) -> Column:
+        def mean(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1289,7 +1289,9 @@ class Frame(object, metaclass=ABCMeta):
         elif numeric_only is True and axis == 1:
             numeric_only = None
 
-        def sum(spark_column: Column, spark_type: DataType) -> Column:
+        def sum(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1373,7 +1375,9 @@ class Frame(object, metaclass=ABCMeta):
         elif numeric_only is True and axis == 1:
             numeric_only = None
 
-        def prod(spark_column: Column, spark_type: DataType) -> Column:
+        def prod(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 scol = F.min(F.coalesce(spark_column, 
SF.lit(True))).cast(LongType())
             elif isinstance(spark_type, NumericType):
@@ -1444,7 +1448,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def skew(spark_column: Column, spark_type: DataType) -> Column:
+        def skew(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1501,7 +1507,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def kurtosis(spark_column: Column, spark_type: DataType) -> Column:
+        def kurtosis(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1570,7 +1578,10 @@ class Frame(object, metaclass=ABCMeta):
             numeric_only = None
 
         return self._reduce_for_stat_function(
-            F.min, name="min", axis=axis, numeric_only=numeric_only
+            lambda psser: F.min(psser.spark.column),
+            name="min",
+            axis=axis,
+            numeric_only=numeric_only,
         )
 
     def max(
@@ -1625,7 +1636,10 @@ class Frame(object, metaclass=ABCMeta):
             numeric_only = None
 
         return self._reduce_for_stat_function(
-            F.max, name="max", axis=axis, numeric_only=numeric_only
+            lambda psser: F.max(psser.spark.column),
+            name="max",
+            axis=axis,
+            numeric_only=numeric_only,
         )
 
     def count(
@@ -1763,7 +1777,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def std(spark_column: Column, spark_type: DataType) -> Column:
+        def std(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1842,7 +1858,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def var(spark_column: Column, spark_type: DataType) -> Column:
+        def var(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1955,7 +1973,9 @@ class Frame(object, metaclass=ABCMeta):
                 "accuracy must be an integer; however, got [%s]" % 
type(accuracy).__name__
             )
 
-        def median(spark_column: Column, spark_type: DataType) -> Column:
+        def median(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, (BooleanType, NumericType)):
                 return F.percentile_approx(spark_column.cast(DoubleType()), 0.5, accuracy)
             else:
@@ -2037,7 +2057,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def std(spark_column: Column, spark_type: DataType) -> Column:
+        def std(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -2051,10 +2073,8 @@ class Frame(object, metaclass=ABCMeta):
             else:
                 return F.stddev_samp(spark_column)
 
-        def sem(spark_column: Column, spark_type: DataType) -> Column:
-            return std(spark_column, spark_type) / pow(
-                Frame._count_expr(spark_column, spark_type), 0.5
-            )
+        def sem(psser: "Series") -> Column:
+            return std(psser) / pow(Frame._count_expr(psser), 0.5)
 
         return self._reduce_for_stat_function(
             sem, name="sem", numeric_only=numeric_only, axis=axis, ddof=ddof
@@ -3180,14 +3200,8 @@ class Frame(object, metaclass=ABCMeta):
         )
 
     @staticmethod
-    def _count_expr(spark_column: Column, spark_type: DataType) -> Column:
-        # Special handle floating point types because Spark's count treats nan as a valid value,
-        # whereas pandas count doesn't include nan.
-        # TODO(SPARK-36350): Make this work with DataTypeOps.
-        if isinstance(spark_type, (FloatType, DoubleType)):
-            return F.count(F.nanvl(spark_column, SF.lit(None)))
-        else:
-            return F.count(spark_column)
+    def _count_expr(psser: "Series") -> Column:
+        return F.count(psser._dtype_op.nan_to_null(psser).spark.column)
 
 
 def _test() -> None:
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 9356be8..7798e9c 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -2524,40 +2524,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
     def _reduce_for_stat_function(
         self, sfun: Callable[[Column], Column], only_numeric: bool
     ) -> FrameLike:
-        agg_columns = self._agg_columns
-        agg_columns_scols = self._agg_columns_scols
-
         groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
         groupkey_scols = [s.alias(name) for s, name in zip(self._groupkeys_scols, groupkey_names)]
 
-        sdf = self._psdf._internal.spark_frame.select(groupkey_scols + agg_columns_scols)
+        agg_columns = [
+            psser
+            for psser in self._agg_columns
+            if isinstance(psser.spark.data_type, NumericType) or not only_numeric
+        ]
 
-        data_columns = []
-        column_labels = []
-        if len(agg_columns) > 0:
-            stat_exprs = []
-            for psser in agg_columns:
-                spark_type = psser.spark.data_type
-                name = psser._internal.data_spark_column_names[0]
-                label = psser._column_label
-                scol = scol_for(sdf, name)
-                # TODO: we should have a function that takes dataframes and converts the numeric
-                # types. Converting the NaNs is used in a few places, it should be in utils.
-                # Special handle floating point types because Spark's count treats nan as a valid
-                # value, whereas pandas count doesn't include nan.
-
-                # TODO(SPARK-36350): Make this work with DataTypeOps.
-                if isinstance(spark_type, (FloatType, DoubleType)):
-                    stat_exprs.append(sfun(F.nanvl(scol, SF.lit(None))).alias(name))
-                    data_columns.append(name)
-                    column_labels.append(label)
-                elif isinstance(spark_type, NumericType) or not only_numeric:
-                    stat_exprs.append(sfun(scol).alias(name))
-                    data_columns.append(name)
-                    column_labels.append(label)
-            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
-        else:
-            sdf = sdf.select(*groupkey_names).distinct()
+        sdf = self._psdf._internal.spark_frame.select(
+            *groupkey_scols, *[psser.spark.column for psser in agg_columns]
+        )
 
         internal = InternalFrame(
             spark_frame=sdf,
@@ -2567,12 +2545,36 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 psser._internal.data_fields[0].copy(name=name)
                 for psser, name in zip(self._groupkeys, groupkey_names)
             ],
-            column_labels=column_labels,
-            data_spark_columns=[scol_for(sdf, col) for col in data_columns],
+            data_spark_columns=[
+                scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns
+            ],
+            column_labels=[psser._column_label for psser in agg_columns],
+            data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
             column_label_names=self._psdf._internal.column_label_names,
         )
         psdf = DataFrame(internal)  # type: DataFrame
 
+        if len(psdf._internal.column_labels) > 0:
+            stat_exprs = []
+            for label in psdf._internal.column_labels:
+                psser = psdf._psser_for(label)
+                stat_exprs.append(
+                    sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias(
+                        psser._internal.data_spark_column_names[0]
+                    )
+                )
+            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+        else:
+            sdf = sdf.select(*groupkey_names).distinct()
+
+        internal = internal.copy(
+            spark_frame=sdf,
+            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
+            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
+            data_fields=None,
+        )
+        psdf = DataFrame(internal)
+
         if self._dropna:
             psdf = DataFrame(
                 psdf._internal.with_new_sdf(
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 2d309b6..bf1727d 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -54,7 +54,6 @@ from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
 from pyspark.sql.types import (
     ArrayType,
     BooleanType,
-    DataType,
     DecimalType,
     DoubleType,
     FloatType,
@@ -3453,7 +3452,9 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             if q_float < 0.0 or q_float > 1.0:
                 raise ValueError("percentiles should all be in the interval [0, 1].")
 
-            def quantile(spark_column: Column, spark_type: DataType) -> Column:
+            def quantile(psser: Series) -> Column:
+                spark_type = psser.spark.data_type
+                spark_column = psser.spark.column
                 if isinstance(spark_type, (BooleanType, NumericType)):
                     return F.percentile_approx(spark_column.cast(DoubleType()), q_float, accuracy)
                 else:
@@ -6186,7 +6187,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
 
     def _reduce_for_stat_function(
         self,
-        sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
+        sfun: Callable[["Series"], Column],
         name: str_type,
         axis: Optional[Axis] = None,
         numeric_only: bool = True,
@@ -6202,26 +6203,15 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         axis : used only for sanity check because series only support index axis.
         numeric_only : not used by this implementation, but passed down by stats functions
         """
-        from inspect import signature
-
         axis = validate_axis(axis)
         if axis == 1:
             raise ValueError("Series does not support columns axis.")
-        num_args = len(signature(sfun).parameters)
-        spark_column = self.spark.column
-        spark_type = self.spark.data_type
 
-        if num_args == 1:
-            # Only pass in the column if sfun accepts only one arg
-            scol = cast(Callable[[Column], Column], sfun)(spark_column)
-        else:  # must be 2
-            assert num_args == 2
-            # Pass in both the column and its data type if sfun accepts two args
-            scol = cast(Callable[[Column, DataType], Column], sfun)(spark_column, spark_type)
+        scol = sfun(self)
 
         min_count = kwargs.get("min_count", 0)
         if min_count > 0:
-            scol = F.when(Frame._count_expr(spark_column, spark_type) >= min_count, scol)
+            scol = F.when(Frame._count_expr(self) >= min_count, scol)
 
         result = unpack_scalar(self._internal.spark_frame.select(scol))
         return result if result is not None else np.nan

