This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5046b8c04ca [SPARK-38937][PYTHON] interpolate support param `limit_direction`
5046b8c04ca is described below

commit 5046b8c04cadca6605dd34b98b31850b643dfe45
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Apr 25 11:46:34 2022 +0900

    [SPARK-38937][PYTHON] interpolate support param `limit_direction`
    
    ### What changes were proposed in this pull request?
    `interpolate` now supports the `limit_direction` parameter.
    
    ### Why are the changes needed?
    `limit_direction` is already supported by pandas' `interpolate`, so the pandas API on Spark should support it as well.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `interpolate` (on both Series and DataFrame) now accepts the new parameter `limit_direction`.
    
    ### How was this patch tested?
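    Added unit tests that compare against pandas for `limit_direction` values None, 'forward', 'backward' and 'both' on Series and DataFrame, and that check an invalid `limit_direction` raises a ValueError.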
    
    Closes #36246 from zhengruifeng/linear_interpolate_support_limit_direction.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../pandas_on_spark/supported_pandas_api.rst       |  7 +--
 python/pyspark/pandas/frame.py                     | 18 +++++--
 python/pyspark/pandas/generic.py                   | 13 +++--
 python/pyspark/pandas/series.py                    | 63 ++++++++++++++++++----
 .../pyspark/pandas/tests/test_generic_functions.py | 33 ++++++------
 5 files changed, 97 insertions(+), 37 deletions(-)

diff --git 
a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst 
b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
index a975d4ec8cc..937f7a3f179 100644
--- a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
+++ b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
@@ -221,9 +221,7 @@ Supported DataFrame APIs
 | :func:`insert`                             | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
 | :func:`interpolate`                        | P           | ``axis``, 
``inplace``,               |
-|                                            |             | 
``limit_direction``, ``limit_area``, |
-|                                            |             | ``downcast``      
                   |
-|                                            |             |                   
                   |
+|                                            |             | ``limit_area``, 
``downcast``         |
 
+--------------------------------------------+-------------+--------------------------------------+
 | :func:`isin`                               | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
@@ -877,8 +875,7 @@ Supported Series APIs
 | infer_objects                   | N                 |                        
                   |
 
+---------------------------------+-------------------+-------------------------------------------+
 | :func:`interpolate`             | P                 | ``axis``, ``inplace``, 
                   |
-|                                 |                   | ``limit_direction``, 
``limit_area``,      |
-|                                 |                   | ``downcast``           
                   |
+|                                 |                   | ``limit_area``, 
``downcast``              |
 
+---------------------------------+-------------------+-------------------------------------------+
 | :func:`is_monotonic`            | Y                 |                        
                   |
 
+---------------------------------+-------------------+-------------------------------------------+
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index a78aaa66f08..c09fe029bd6 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -5500,11 +5500,20 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return psdf
 
-    def interpolate(self, method: Optional[str] = None, limit: Optional[int] = 
None) -> "DataFrame":
-        if (method is not None) and (method not in ["linear"]):
+    def interpolate(
+        self,
+        method: str = "linear",
+        limit: Optional[int] = None,
+        limit_direction: Optional[str] = None,
+    ) -> "DataFrame":
+        if method not in ["linear"]:
             raise NotImplementedError("interpolate currently works only for 
method='linear'")
         if (limit is not None) and (not limit > 0):
             raise ValueError("limit must be > 0.")
+        if (limit_direction is not None) and (
+            limit_direction not in ["forward", "backward", "both"]
+        ):
+            raise ValueError("invalid limit_direction: 
'{}'".format(limit_direction))
 
         numeric_col_names = []
         for label in self._internal.column_labels:
@@ -5514,7 +5523,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         psdf = self[numeric_col_names]
         return psdf._apply_series_op(
-            lambda psser: psser._interpolate(method=method, limit=limit), 
should_resolve=True
+            lambda psser: psser._interpolate(
+                method=method, limit=limit, limit_direction=limit_direction
+            ),
+            should_resolve=True,
         )
 
     def replace(
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 21c880373ad..1ce4671d696 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -3253,16 +3253,17 @@ class Frame(object, metaclass=ABCMeta):
 
     pad = ffill
 
-    # TODO: add 'axis', 'inplace', 'limit_direction', 'limit_area', 'downcast'
+    # TODO: add 'axis', 'inplace', 'limit_area', 'downcast'
     def interpolate(
         self: FrameLike,
-        method: Optional[str] = None,
+        method: str = "linear",
         limit: Optional[int] = None,
+        limit_direction: Optional[str] = None,
     ) -> FrameLike:
         """
         Fill NaN values using an interpolation method.
 
-        .. note:: the current implementation of rank uses Spark's Window 
without
+        .. note:: the current implementation of interpolate uses Spark's 
Window without
             specifying partition specification. This leads to move all data 
into
             single partition in single machine and could cause serious
             performance degradation. Avoid this method against very large 
dataset.
@@ -3281,6 +3282,10 @@ class Frame(object, metaclass=ABCMeta):
             Maximum number of consecutive NaNs to fill. Must be greater than
             0.
 
+        limit_direction : str, default None
+            Consecutive NaNs will be filled in this direction.
+            One of {{'forward', 'backward', 'both'}}.
+
         Returns
         -------
         Series or DataFrame or None
@@ -3335,7 +3340,7 @@ class Frame(object, metaclass=ABCMeta):
         2  2.0  3.0 -3.0   9.0
         3  2.0  4.0 -4.0  16.0
         """
-        return self.interpolate(method=method, limit=limit)
+        return self.interpolate(method=method, limit=limit, 
limit_direction=limit_direction)
 
     @property
     def at(self) -> AtIndexer:
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index ea3426d5a54..ef0208b3bbd 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -2169,18 +2169,28 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             )
         )._psser_for(self._column_label)
 
-    def interpolate(self, method: Optional[str] = None, limit: Optional[int] = 
None) -> "Series":
-        return self._interpolate(method=method, limit=limit)
+    def interpolate(
+        self,
+        method: str = "linear",
+        limit: Optional[int] = None,
+        limit_direction: Optional[str] = None,
+    ) -> "Series":
+        return self._interpolate(method=method, limit=limit, 
limit_direction=limit_direction)
 
     def _interpolate(
         self,
-        method: Optional[str] = None,
+        method: str = "linear",
         limit: Optional[int] = None,
+        limit_direction: Optional[str] = None,
     ) -> "Series":
-        if (method is not None) and (method not in ["linear"]):
+        if method not in ["linear"]:
             raise NotImplementedError("interpolate currently works only for 
method='linear'")
         if (limit is not None) and (not limit > 0):
             raise ValueError("limit must be > 0.")
+        if (limit_direction is not None) and (
+            limit_direction not in ["forward", "backward", "both"]
+        ):
+            raise ValueError("invalid limit_direction: 
'{}'".format(limit_direction))
 
         if not self.spark.nullable and not isinstance(
             self.spark.data_type, (FloatType, DoubleType)
@@ -2209,15 +2219,50 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         ) * null_index_forward + last_non_null_forward
 
         fill_cond = ~F.isnull(last_non_null_backward) & 
~F.isnull(last_non_null_forward)
-        pad_cond = F.isnull(last_non_null_backward) & 
~F.isnull(last_non_null_forward)
-        if limit is not None:
-            fill_cond = fill_cond & (null_index_forward <= F.lit(limit))
-            pad_cond = pad_cond & (null_index_forward <= F.lit(limit))
+
+        pad_head = SF.lit(None)
+        pad_head_cond = SF.lit(False)
+        pad_tail = SF.lit(None)
+        pad_tail_cond = SF.lit(False)
+
+        # inputs  -> NaN, NaN, 1.0, NaN, NaN, NaN, 5.0, NaN, NaN
+        if limit_direction is None or limit_direction == "forward":
+            # outputs -> NaN, NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
+            pad_tail = last_non_null_forward
+            pad_tail_cond = F.isnull(last_non_null_backward) & 
~F.isnull(last_non_null_forward)
+            if limit is not None:
+                # outputs (limit=1) -> NaN, NaN, 1.0, 2.0, NaN, NaN, 5.0, 5.0, 
NaN
+                fill_cond = fill_cond & (null_index_forward <= F.lit(limit))
+                pad_tail_cond = pad_tail_cond & (null_index_forward <= 
F.lit(limit))
+
+        elif limit_direction == "backward":
+            # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, NaN, NaN
+            pad_head = last_non_null_backward
+            pad_head_cond = ~F.isnull(last_non_null_backward) & 
F.isnull(last_non_null_forward)
+            if limit is not None:
+                # outputs (limit=1) -> NaN, 1.0, 1.0, NaN, NaN, 4.0, 5.0, NaN, 
NaN
+                fill_cond = fill_cond & (null_index_backward <= F.lit(limit))
+                pad_head_cond = pad_head_cond & (null_index_backward <= 
F.lit(limit))
+
+        else:
+            # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
+            pad_head = last_non_null_backward
+            pad_head_cond = ~F.isnull(last_non_null_backward) & 
F.isnull(last_non_null_forward)
+            pad_tail = last_non_null_forward
+            pad_tail_cond = F.isnull(last_non_null_backward) & 
~F.isnull(last_non_null_forward)
+            if limit is not None:
+                # outputs (limit=1) -> NaN, 1.0, 1.0, 2.0, NaN, 4.0, 5.0, 5.0, 
NaN
+                fill_cond = fill_cond & (
+                    (null_index_forward <= F.lit(limit)) | 
(null_index_backward <= F.lit(limit))
+                )
+                pad_head_cond = pad_head_cond & (null_index_backward <= 
F.lit(limit))
+                pad_tail_cond = pad_tail_cond & (null_index_forward <= 
F.lit(limit))
 
         cond = self.isnull().spark.column
         scol = (
             F.when(cond & fill_cond, fill)
-            .when(cond & pad_cond, last_non_null_forward)
+            .when(cond & pad_head_cond, pad_head)
+            .when(cond & pad_tail_cond, pad_tail)
             .otherwise(scol)
         )
 
diff --git a/python/pyspark/pandas/tests/test_generic_functions.py 
b/python/pyspark/pandas/tests/test_generic_functions.py
index e1c804e0550..3e4db6c86bc 100644
--- a/python/pyspark/pandas/tests/test_generic_functions.py
+++ b/python/pyspark/pandas/tests/test_generic_functions.py
@@ -33,17 +33,18 @@ class GenericFunctionsTest(PandasOnSparkTestCase, 
TestUtils):
         with self.assertRaisesRegex(ValueError, "limit must be > 0"):
             psdf.interpolate(limit=0)
 
-    def _test_series_interpolate(self, pser):
-        psser = ps.from_pandas(pser)
-        self.assert_eq(psser.interpolate(), pser.interpolate())
-        for l1 in range(1, 5):
-            self.assert_eq(psser.interpolate(limit=l1), 
pser.interpolate(limit=l1))
-
-    def _test_dataframe_interpolate(self, pdf):
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(psdf.interpolate(), pdf.interpolate())
-        for l2 in range(1, 5):
-            self.assert_eq(psdf.interpolate(limit=l2), 
pdf.interpolate(limit=l2))
+        with self.assertRaisesRegex(ValueError, "invalid limit_direction"):
+            psdf.interpolate(limit_direction="jump")
+
+    def _test_interpolate(self, pobj):
+        psobj = ps.from_pandas(pobj)
+        self.assert_eq(psobj.interpolate(), pobj.interpolate())
+        for limit in range(1, 5):
+            for limit_direction in [None, "forward", "backward", "both"]:
+                self.assert_eq(
+                    psobj.interpolate(limit=limit, 
limit_direction=limit_direction),
+                    pobj.interpolate(limit=limit, 
limit_direction=limit_direction),
+                )
 
     def test_interpolate(self):
         pser = pd.Series(
@@ -54,7 +55,7 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
             ],
             name="a",
         )
-        self._test_series_interpolate(pser)
+        self._test_interpolate(pser)
 
         pser = pd.Series(
             [
@@ -64,7 +65,7 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
             ],
             name="a",
         )
-        self._test_series_interpolate(pser)
+        self._test_interpolate(pser)
 
         pser = pd.Series(
             [
@@ -84,7 +85,7 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
             ],
             name="a",
         )
-        self._test_series_interpolate(pser)
+        self._test_interpolate(pser)
 
         pdf = pd.DataFrame(
             [
@@ -96,7 +97,7 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
             ],
             columns=list("abc"),
         )
-        self._test_dataframe_interpolate(pdf)
+        self._test_interpolate(pdf)
 
         pdf = pd.DataFrame(
             [
@@ -108,7 +109,7 @@ class GenericFunctionsTest(PandasOnSparkTestCase, 
TestUtils):
             ],
             columns=list("abcde"),
         )
-        self._test_dataframe_interpolate(pdf)
+        self._test_interpolate(pdf)
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to