This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e05959e1900 [SPARK-44717][PYTHON][PS] Respect TimestampNTZ in resampling
e05959e1900 is described below

commit e05959e1900cc687f61b794da47a1516d9baf66b
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Wed Aug 9 11:03:40 2023 +0900

    [SPARK-44717][PYTHON][PS] Respect TimestampNTZ in resampling
    
    ### What changes were proposed in this pull request?
    
    This PR makes resampling in the pandas API on Spark respect the `TimestampNTZ` type.
    
    ### Why are the changes needed?
    
    Resampling still treats the timestamps as `TIMESTAMP_LTZ` even when `spark.sql.timestampType` is set to `TIMESTAMP_NTZ`, which is unexpected.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This fixes a bug so that, with `TimestampNTZType`, end users get exactly the same behaviour as pandas: pandas does not take the local timezone (including its DST transitions) into account. We might need to follow the same behaviour for `TimestampType` as well, but this PR does not address that case because it might be controversial.
    
    ### How was this patch tested?
    
    Unit tests were added.
    
    Closes #42392 from HyukjinKwon/SPARK-44717.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/frame.py                     |  4 +-
 python/pyspark/pandas/resample.py                  | 43 +++++++++++++-------
 .../pandas/tests/connect/test_parity_resample.py   | 12 +++++-
 python/pyspark/pandas/tests/test_resample.py       | 47 ++++++++++++++++++++++
 4 files changed, 88 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 72d4a88b692..65c43eb7cf4 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -13155,7 +13155,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         if on is None and not isinstance(self.index, DatetimeIndex):
             raise NotImplementedError("resample currently works only for 
DatetimeIndex")
-        if on is not None and not isinstance(as_spark_type(on.dtype), 
TimestampType):
+        if on is not None and not isinstance(
+            as_spark_type(on.dtype), (TimestampType, TimestampNTZType)
+        ):
             raise NotImplementedError("`on` currently works only for 
TimestampType")
 
         agg_columns: List[ps.Series] = []
diff --git a/python/pyspark/pandas/resample.py 
b/python/pyspark/pandas/resample.py
index c6c6019c07e..30f8c9d3169 100644
--- a/python/pyspark/pandas/resample.py
+++ b/python/pyspark/pandas/resample.py
@@ -46,7 +46,8 @@ from pyspark.sql import Column, functions as F
 from pyspark.sql.types import (
     NumericType,
     StructField,
-    TimestampType,
+    TimestampNTZType,
+    DataType,
 )
 
 from pyspark import pandas as ps  # For running doctests and reference 
resolution in PyCharm.
@@ -130,6 +131,13 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
         else:
             return self._resamplekey.spark.column
 
+    @property
+    def _resamplekey_type(self) -> DataType:
+        if self._resamplekey is None:
+            return self._psdf.index.spark.data_type
+        else:
+            return self._resamplekey.spark.data_type
+
     @property
     def _agg_columns_scols(self) -> List[Column]:
         return [s.spark.column for s in self._agg_columns]
@@ -154,7 +162,8 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
             col = col._jc if isinstance(col, Column) else F.lit(col)._jc
             return sql_utils.makeInterval(unit, col)
 
-    def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
+    def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
+        key_type = self._resamplekey_type
         origin_scol = F.lit(origin)
         (rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n"))
         left_closed, right_closed = (self._closed == "left", self._closed == 
"right")
@@ -188,7 +197,7 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
                     F.year(ts_scol) - (mod - n)
                 )
 
-            return F.to_timestamp(
+            ret = F.to_timestamp(
                 F.make_date(
                     F.when(edge_cond, edge_label).otherwise(non_edge_label), 
F.lit(12), F.lit(31)
                 )
@@ -227,7 +236,7 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
                     truncated_ts_scol - self.get_make_interval("MONTH", mod - 
n)
                 )
 
-            return F.to_timestamp(
+            ret = F.to_timestamp(
                 F.last_day(F.when(edge_cond, 
edge_label).otherwise(non_edge_label))
             )
 
@@ -242,15 +251,15 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
                 )
 
                 if left_closed and left_labeled:
-                    return F.date_trunc("DAY", ts_scol)
+                    ret = F.date_trunc("DAY", ts_scol)
                 elif left_closed and right_labeled:
-                    return F.date_trunc("DAY", F.date_add(ts_scol, 1))
+                    ret = F.date_trunc("DAY", F.date_add(ts_scol, 1))
                 elif right_closed and left_labeled:
-                    return F.when(edge_cond, F.date_trunc("DAY", 
F.date_sub(ts_scol, 1))).otherwise(
+                    ret = F.when(edge_cond, F.date_trunc("DAY", 
F.date_sub(ts_scol, 1))).otherwise(
                         F.date_trunc("DAY", ts_scol)
                     )
                 else:
-                    return F.when(edge_cond, F.date_trunc("DAY", 
ts_scol)).otherwise(
+                    ret = F.when(edge_cond, F.date_trunc("DAY", 
ts_scol)).otherwise(
                         F.date_trunc("DAY", F.date_add(ts_scol, 1))
                     )
 
@@ -272,13 +281,15 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
                 else:
                     non_edge_label = F.date_sub(truncated_ts_scol, mod - n)
 
-                return F.when(edge_cond, edge_label).otherwise(non_edge_label)
+                ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)
 
         elif rule_code in ["H", "T", "S"]:
             unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"}
             unit_str = unit_mapping[rule_code]
 
             truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
+            if isinstance(key_type, TimestampNTZType):
+                truncated_ts_scol = F.to_timestamp_ntz(truncated_ts_scol)
             diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol)
             mod = F.lit(0) if n == 1 else (diff % F.lit(n))
 
@@ -307,11 +318,16 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
                     truncated_ts_scol + self.get_make_interval(unit_str, n),
                 ).otherwise(truncated_ts_scol - 
self.get_make_interval(unit_str, mod - n))
 
-            return F.when(edge_cond, edge_label).otherwise(non_edge_label)
+            ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)
 
         else:
             raise ValueError("Got the unexpected unit {}".format(rule_code))
 
+        if isinstance(key_type, TimestampNTZType):
+            return F.to_timestamp_ntz(ret)
+        else:
+            return ret
+
     def _downsample(self, f: str) -> DataFrame:
         """
         Downsample the defined function.
@@ -374,12 +390,9 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
         bin_col_label = verify_temp_column_name(self._psdf, bin_col_name)
         bin_col_field = InternalField(
             dtype=np.dtype("datetime64[ns]"),
-            struct_field=StructField(bin_col_name, TimestampType(), True),
-        )
-        bin_scol = self._bin_time_stamp(
-            ts_origin,
-            self._resamplekey_scol,
+            struct_field=StructField(bin_col_name, self._resamplekey_type, 
True),
         )
+        bin_scol = self._bin_timestamp(ts_origin, self._resamplekey_scol)
 
         agg_columns = [
             psser for psser in self._agg_columns if 
(isinstance(psser.spark.data_type, NumericType))
diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py 
b/python/pyspark/pandas/tests/connect/test_parity_resample.py
index e5957cc9b4a..d5c901f113a 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_resample.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py
@@ -16,17 +16,25 @@
 #
 import unittest
 
-from pyspark.pandas.tests.test_resample import ResampleTestsMixin
+from pyspark.pandas.tests.test_resample import ResampleTestsMixin, 
ResampleWithTimezoneMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 
 
-class ResampleTestsParityMixin(
+class ResampleParityTests(
     ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, 
ReusedConnectTestCase
 ):
     pass
 
 
+class ResampleWithTimezoneTests(
+    ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, 
ReusedConnectTestCase
+):
+    @unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python 
Spark Connect client")
+    def test_series_resample_with_timezone(self):
+        super().test_series_resample_with_timezone()
+
+
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.test_parity_resample import *  # noqa: 
F401
 
diff --git a/python/pyspark/pandas/tests/test_resample.py 
b/python/pyspark/pandas/tests/test_resample.py
index 0650fc40448..40614025907 100644
--- a/python/pyspark/pandas/tests/test_resample.py
+++ b/python/pyspark/pandas/tests/test_resample.py
@@ -19,6 +19,8 @@
 import unittest
 import inspect
 import datetime
+import os
+
 import numpy as np
 import pandas as pd
 
@@ -283,10 +285,55 @@ class ResampleTestsMixin:
         )
 
 
+class ResampleWithTimezoneMixin:
+    timezone = None
+
+    @classmethod
+    def setUpClass(cls):
+        cls.timezone = os.environ.get("TZ", None)
+        os.environ["TZ"] = "America/New_York"
+        super(ResampleWithTimezoneMixin, cls).setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ResampleWithTimezoneMixin, cls).tearDownClass()
+        if cls.timezone is not None:
+            os.environ["TZ"] = cls.timezone
+
+    @property
+    def pdf(self):
+        np.random.seed(22)
+        index = pd.date_range(start="2011-01-02", end="2022-05-01", freq="1D")
+        return pd.DataFrame(np.random.rand(len(index), 2), index=index, 
columns=list("AB"))
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_series_resample_with_timezone(self):
+        with self.sql_conf(
+            {
+                "spark.sql.session.timeZone": "Asia/Seoul",
+                "spark.sql.timestampType": "TIMESTAMP_NTZ",
+            }
+        ):
+            p_resample = self.pdf.resample(rule="1001H", closed="right", 
label="right")
+            ps_resample = self.psdf.resample(rule="1001H", closed="right", 
label="right")
+            self.assert_eq(
+                p_resample.sum().sort_index(),
+                ps_resample.sum().sort_index(),
+                almost=True,
+            )
+
+
 class ResampleTests(ResampleTestsMixin, PandasOnSparkTestCase, TestUtils):
     pass
 
 
+class ResampleWithTimezoneTests(ResampleWithTimezoneMixin, 
PandasOnSparkTestCase, TestUtils):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_resample import *  # noqa: F401
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to