This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 07ed82b  [SPARK-36333][PYTHON] Reuse isnull where the null check is needed
07ed82b is described below

commit 07ed82be0bf7d24a41516e831a066c4c99da4efc
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Jul 29 15:33:11 2021 -0700

    [SPARK-36333][PYTHON] Reuse isnull where the null check is needed
    
    ### What changes were proposed in this pull request?
    
    Reuse `IndexOpsMixin.isnull()` where the null check is needed.
    
    ### Why are the changes needed?
    
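    There are some places where we can reuse `IndexOpsMixin.isnull()` instead of building the null check directly from the Spark `Column` (e.g. `scol.isNull() | F.isnan(scol)` for float/double columns), so the type-specific missing-value logic lives in one place.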
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #33562 from ueshin/issues/SPARK-36333/reuse_isnull.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/pandas/base.py    | 16 +++-------------
 python/pyspark/pandas/frame.py   | 12 ++++++------
 python/pyspark/pandas/generic.py |  1 +
 python/pyspark/pandas/groupby.py |  4 +++-
 python/pyspark/pandas/series.py  | 12 ++++++------
 5 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index 832d7e8..58f6c19 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -27,11 +27,7 @@ import numpy as np
 import pandas as pd  # noqa: F401
 from pandas.api.types import is_list_like, CategoricalDtype
 from pyspark.sql import functions as F, Column, Window
-from pyspark.sql.types import (
-    DoubleType,
-    FloatType,
-    LongType,
-)
+from pyspark.sql.types import LongType
 
 from pyspark import pandas as ps  # For running doctests and reference 
resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, 
SeriesOrIndex
@@ -1622,15 +1618,9 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
         if len(kvs) == 0:  # uniques are all missing values
             new_scol = SF.lit(na_sentinel_code)
         else:
-            scol = self.spark.column
-            if isinstance(self.spark.data_type, (FloatType, DoubleType)):
-                cond = scol.isNull() | F.isnan(scol)
-            else:
-                cond = scol.isNull()
             map_scol = F.create_map(*kvs)
-
-            null_scol = F.when(cond, SF.lit(na_sentinel_code))
-            new_scol = null_scol.otherwise(map_scol[scol])
+            null_scol = F.when(self.isnull().spark.column, 
SF.lit(na_sentinel_code))
+            new_scol = null_scol.otherwise(map_scol[self.spark.column])
 
         codes = 
self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0]))
 
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index faacd7a..af8b5ad 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -5313,13 +5313,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
                 internal = internal.with_filter(cond)
 
+            psdf = DataFrame(internal)
+
             null_counts = []
             for label in internal.column_labels:
-                scol = internal.spark_column_for(label)
-                if isinstance(internal.spark_type_for(label), (FloatType, 
DoubleType)):
-                    cond = scol.isNull() | F.isnan(scol)
-                else:
-                    cond = scol.isNull()
+                psser = psdf._psser_for(label)
+                cond = psser.isnull().spark.column
                 null_counts.append(
                     F.sum(F.when(~cond, 
1).otherwise(0)).alias(name_like_string(label))
                 )
@@ -8477,7 +8476,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         for label in self._internal.column_labels:
             scol = self._internal.spark_column_for(label)
             spark_type = self._internal.spark_type_for(label)
-            if isinstance(spark_type, DoubleType) or isinstance(spark_type, 
FloatType):
+            # TODO(SPARK-36350): Make this work with DataTypeOps.
+            if isinstance(spark_type, (FloatType, DoubleType)):
                 exprs.append(
                     F.nanvl(scol, 
SF.lit(None)).alias(self._internal.spark_column_name_for(label))
                 )
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 30fc008..6ed83d0 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -3180,6 +3180,7 @@ class Frame(object, metaclass=ABCMeta):
     def _count_expr(spark_column: Column, spark_type: DataType) -> Column:
         # Special handle floating point types because Spark's count treats nan 
as a valid value,
         # whereas pandas count doesn't include nan.
+        # TODO(SPARK-36350): Make this work with DataTypeOps.
         if isinstance(spark_type, (FloatType, DoubleType)):
             return F.count(F.nanvl(spark_column, SF.lit(None)))
         else:
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index c91fcd7..9356be8 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -2545,7 +2545,9 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 # types. Converting the NaNs is used in a few places, it 
should be in utils.
                 # Special handle floating point types because Spark's count 
treats nan as a valid
                 # value, whereas pandas count doesn't include nan.
-                if isinstance(spark_type, DoubleType) or 
isinstance(spark_type, FloatType):
+
+                # TODO(SPARK-36350): Make this work with DataTypeOps.
+                if isinstance(spark_type, (FloatType, DoubleType)):
                     stat_exprs.append(sfun(F.nanvl(scol, 
SF.lit(None))).alias(name))
                     data_columns.append(name)
                     column_labels.append(label)
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 442fbae..70f6a96 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -1915,12 +1915,12 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
 
         scol = self.spark.column
 
-        if isinstance(self.spark.data_type, (FloatType, DoubleType)):
-            cond = scol.isNull() | F.isnan(scol)
-        else:
-            if not self.spark.nullable:
-                return self._psdf.copy()._psser_for(self._column_label)
-            cond = scol.isNull()
+        if not self.spark.nullable and not isinstance(
+            self.spark.data_type, (FloatType, DoubleType)
+        ):
+            return self._psdf.copy()._psser_for(self._column_label)
+
+        cond = self.isnull().spark.column
 
         if value is not None:
             if not isinstance(value, (float, int, str, bool)):

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to