This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e500c78dea16 [SPARK-55843][PS] Handle the unit of datetime64 and
timedelta64 dtypes
e500c78dea16 is described below
commit e500c78dea1603a3df7f6aa7b522980ccff3c0c4
Author: Takuya Ueshin <[email protected]>
AuthorDate: Wed Mar 4 17:50:40 2026 -0800
[SPARK-55843][PS] Handle the unit of datetime64 and timedelta64 dtypes
### What changes were proposed in this pull request?
Handles the unit of `datetime64` and `timedelta64` dtypes.
### Why are the changes needed?
In pandas 3, the units of `datetime64` and `timedelta64` dtypes are handled
more strictly.
### Does this PR introduce _any_ user-facing change?
Yes, it will behave more like pandas 3.
### How was this patch tested?
The existing tests should pass.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54629 from ueshin/issues/SPARK-55843/datetime.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
python/pyspark/pandas/datetimes.py | 32 ++++++++--
python/pyspark/pandas/indexes/datetimes.py | 5 +-
python/pyspark/pandas/indexes/timedelta.py | 5 +-
python/pyspark/pandas/namespace.py | 32 +++++++---
.../pyspark/pandas/tests/indexes/test_datetime.py | 73 ++++++++++++----------
python/pyspark/pandas/tests/series/test_stat.py | 2 +-
python/pyspark/pandas/tests/test_namespace.py | 8 +--
python/pyspark/pandas/tests/test_typedef.py | 7 ++-
python/pyspark/pandas/typedef/typehints.py | 45 +++++++++----
python/pyspark/sql/pandas/types.py | 17 ++++-
10 files changed, 153 insertions(+), 73 deletions(-)
diff --git a/python/pyspark/pandas/datetimes.py
b/python/pyspark/pandas/datetimes.py
index 2e0f1c1c953b..cc07ed278233 100644
--- a/python/pyspark/pandas/datetimes.py
+++ b/python/pyspark/pandas/datetimes.py
@@ -21,14 +21,16 @@ Date/Time related functions on pandas-on-Spark Series
from typing import Any, Optional, Union, no_type_check
import numpy as np
-import pandas as pd # noqa: F401
+import pandas as pd
from pandas.tseries.offsets import DateOffset
import pyspark.pandas as ps
+from pyspark.loose_version import LooseVersion
import pyspark.sql.functions as F
from pyspark.sql.types import DateType, TimestampType, TimestampNTZType,
IntegerType
from pyspark.pandas import DataFrame
from pyspark.pandas.config import option_context
+from pyspark.pandas._typing import Dtype
class DatetimeMethods:
@@ -603,8 +605,13 @@ class DatetimeMethods:
2 2012-03-31
dtype: datetime64[ns]
"""
+ ret_dtype: Union[type, Dtype]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ ret_dtype = np.datetime64
+ else:
+ ret_dtype = self._data.dtype
- def pandas_normalize(s) -> ps.Series[np.datetime64]: # type:
ignore[no-untyped-def]
+ def pandas_normalize(s) -> ps.Series[ret_dtype]: # type:
ignore[no-untyped-def, valid-type]
return s.dt.normalize()
return self._data.pandas_on_spark.transform_batch(pandas_normalize)
@@ -706,8 +713,13 @@ class DatetimeMethods:
2 2018-01-01 12:00:00
dtype: datetime64[ns]
"""
+ ret_dtype: Union[type, Dtype]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ ret_dtype = np.datetime64
+ else:
+ ret_dtype = self._data.dtype
- def pandas_round(s) -> ps.Series[np.datetime64]: # type:
ignore[no-untyped-def]
+ def pandas_round(s) -> ps.Series[ret_dtype]: # type:
ignore[no-untyped-def, valid-type]
return s.dt.round(freq, *args, **kwargs)
return self._data.pandas_on_spark.transform_batch(pandas_round)
@@ -761,8 +773,13 @@ class DatetimeMethods:
2 2018-01-01 12:00:00
dtype: datetime64[ns]
"""
+ ret_dtype: Union[type, Dtype]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ ret_dtype = np.datetime64
+ else:
+ ret_dtype = self._data.dtype
- def pandas_floor(s) -> ps.Series[np.datetime64]: # type:
ignore[no-untyped-def]
+ def pandas_floor(s) -> ps.Series[ret_dtype]: # type:
ignore[no-untyped-def, valid-type]
return s.dt.floor(freq, *args, **kwargs)
return self._data.pandas_on_spark.transform_batch(pandas_floor)
@@ -816,8 +833,13 @@ class DatetimeMethods:
2 2018-01-01 13:00:00
dtype: datetime64[ns]
"""
+ ret_dtype: Union[type, Dtype]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ ret_dtype = np.datetime64
+ else:
+ ret_dtype = self._data.dtype
- def pandas_ceil(s) -> ps.Series[np.datetime64]: # type:
ignore[no-untyped-def]
+ def pandas_ceil(s) -> ps.Series[ret_dtype]: # type:
ignore[no-untyped-def, valid-type]
return s.dt.ceil(freq, *args, **kwargs)
return self._data.pandas_on_spark.transform_batch(pandas_ceil)
diff --git a/python/pyspark/pandas/indexes/datetimes.py
b/python/pyspark/pandas/indexes/datetimes.py
index 6d7b723d0f11..51fec78e89dc 100644
--- a/python/pyspark/pandas/indexes/datetimes.py
+++ b/python/pyspark/pandas/indexes/datetimes.py
@@ -160,8 +160,9 @@ class DatetimeIndex(Index):
raise TypeError("Index.name must be a hashable type")
if isinstance(data, (Series, Index)):
- if dtype is None:
- dtype = "datetime64[ns]"
+ if LooseVersion(pd.__version__) < "3.0.0":
+ if dtype is None:
+ dtype = "datetime64[ns]"
return cast(DatetimeIndex, Index(data, dtype=dtype, copy=copy,
name=name))
return cast(DatetimeIndex, ps.from_pandas(pd.DatetimeIndex(**kwargs)))
diff --git a/python/pyspark/pandas/indexes/timedelta.py
b/python/pyspark/pandas/indexes/timedelta.py
index 112d2bda0688..762bc90f46c3 100644
--- a/python/pyspark/pandas/indexes/timedelta.py
+++ b/python/pyspark/pandas/indexes/timedelta.py
@@ -112,8 +112,9 @@ class TimedeltaIndex(Index):
raise TypeError("Index.name must be a hashable type")
if isinstance(data, (Series, Index)):
- if dtype is None:
- dtype = "timedelta64[ns]"
+ if LooseVersion(pd.__version__) < "3.0.0":
+ if dtype is None:
+ dtype = "timedelta64[ns]"
return cast(TimedeltaIndex, Index(data, dtype=dtype, copy=copy,
name=name))
kwargs = dict(
diff --git a/python/pyspark/pandas/namespace.py
b/python/pyspark/pandas/namespace.py
index 0a1f2413e01f..9490e3702523 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -1766,9 +1766,16 @@ def to_datetime(
"The 'infer_datetime_format' keyword is not supported in
pandas 3.0.0 and later."
)
+ ret_type: type
+ if LooseVersion(pd.__version__) < "3.0.0":
+ ret_type = Series[np.datetime64]
+ else:
+ # The unit is unpredictable.
+ ret_type = None
+
def pandas_to_datetime(
pser_or_pdf: Union[pd.DataFrame, pd.Series], cols: Optional[List[str]]
= None
- ) -> Series[np.datetime64]:
+ ) -> ret_type:
if isinstance(pser_or_pdf, pd.DataFrame):
pser_or_pdf = pser_or_pdf[cols]
return pd.to_datetime(pser_or_pdf, **kwargs)
@@ -2029,17 +2036,22 @@ def to_timedelta(
TimedeltaIndex(['0 days', '1 days', '2 days', '3 days', '4 days'],
dtype='timedelta64[ns]', freq=None)
"""
-
- def pandas_to_timedelta(pser: pd.Series) -> np.timedelta64:
- return pd.to_timedelta(
- arg=pser,
- unit=unit,
- errors=errors,
- )
-
if isinstance(arg, Series):
- return arg.transform(pandas_to_timedelta)
+ ret_dtype: Union[type, Dtype]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ ret_dtype = np.timedelta64
+ else:
+ # The unit is unpredictable.
+ ret_dtype = None
+
+ def pandas_to_timedelta(pser: pd.Series) -> ret_dtype:
+ return pd.to_timedelta(
+ arg=pser,
+ unit=unit,
+ errors=errors,
+ )
+ return arg.transform(pandas_to_timedelta)
else:
return pd.to_timedelta(
arg=arg,
diff --git a/python/pyspark/pandas/tests/indexes/test_datetime.py
b/python/pyspark/pandas/tests/indexes/test_datetime.py
index 41a4a862fd17..b3a594973b97 100644
--- a/python/pyspark/pandas/tests/indexes/test_datetime.py
+++ b/python/pyspark/pandas/tests/indexes/test_datetime.py
@@ -71,55 +71,60 @@ class
DatetimeIndexTestsMixin(DatetimeIndexTestingFuncMixin):
ps.DatetimeIndex(["2004-01-01", "2002-12-31", "2000-04-01"]).all()
def test_day_name(self):
- for psidx, pidx in self.idx_pairs:
- self.assert_eq(psidx.day_name(), pidx.day_name())
+ for i, (psidx, pidx) in enumerate(self.idx_pairs):
+ with self.subTest(i=i):
+ self.assert_eq(psidx.day_name(), pidx.day_name())
def test_month_name(self):
- for psidx, pidx in self.idx_pairs:
- self.assert_eq(psidx.month_name(), pidx.month_name())
+ for i, (psidx, pidx) in enumerate(self.idx_pairs):
+ with self.subTest(i=i):
+ self.assert_eq(psidx.month_name(), pidx.month_name())
def test_normalize(self):
- for psidx, pidx in self.idx_pairs:
- self.assert_eq(psidx.normalize(), pidx.normalize())
+ for i, (psidx, pidx) in enumerate(self.idx_pairs):
+ with self.subTest(i=i):
+ self.assert_eq(psidx.normalize(), pidx.normalize())
def test_strftime(self):
- for psidx, pidx in self.idx_pairs:
- self.assert_eq(
- psidx.strftime(date_format="%B %d, %Y"),
pidx.strftime(date_format="%B %d, %Y")
- )
+ for i, (psidx, pidx) in enumerate(self.idx_pairs):
+ with self.subTest(i=i):
+ self.assert_eq(
+ psidx.strftime(date_format="%B %d, %Y"),
pidx.strftime(date_format="%B %d, %Y")
+ )
def test_arithmetic_op_exceptions(self):
- for psidx, pidx in self.idx_pairs:
- py_datetime = pidx.to_pydatetime()
- for other in [1, 0.1, psidx,
psidx.to_series().reset_index(drop=True), py_datetime]:
- expected_err_msg = "Addition can not be applied to datetimes."
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx + other)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
other + psidx)
+ for i, (psidx, pidx) in enumerate(self.idx_pairs):
+ with self.subTest(i=i):
+ py_datetime = pidx.to_pydatetime()
+ for other in [1, 0.1, psidx,
psidx.to_series().reset_index(drop=True), py_datetime]:
+ expected_err_msg = "Addition can not be applied to
datetimes."
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: psidx + other)
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: other + psidx)
- expected_err_msg = "Multiplication can not be applied to
datetimes."
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx * other)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
other * psidx)
+ expected_err_msg = "Multiplication can not be applied to
datetimes."
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: psidx * other)
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: other * psidx)
- expected_err_msg = "True division can not be applied to
datetimes."
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx / other)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
other / psidx)
+ expected_err_msg = "True division can not be applied to
datetimes."
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: psidx / other)
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: other / psidx)
- expected_err_msg = "Floor division can not be applied to
datetimes."
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx // other)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
other // psidx)
+ expected_err_msg = "Floor division can not be applied to
datetimes."
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: psidx // other)
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: other // psidx)
- expected_err_msg = "Modulo can not be applied to datetimes."
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx % other)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
other % psidx)
+ expected_err_msg = "Modulo can not be applied to
datetimes."
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: psidx % other)
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: other % psidx)
- expected_err_msg = "Datetime subtraction can only be applied to
datetime series."
+ expected_err_msg = "Datetime subtraction can only be applied
to datetime series."
- for other in [1, 0.1]:
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx - other)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
other - psidx)
+ for other in [1, 0.1]:
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: psidx - other)
+ self.assertRaisesRegex(TypeError, expected_err_msg,
lambda: other - psidx)
- self.assertRaisesRegex(TypeError, expected_err_msg, lambda: psidx
- other)
- self.assertRaises(NotImplementedError, lambda: py_datetime - psidx)
+ self.assertRaisesRegex(TypeError, expected_err_msg, lambda:
psidx - other)
+ self.assertRaises(NotImplementedError, lambda: py_datetime -
psidx)
class DatetimeIndexTests(
diff --git a/python/pyspark/pandas/tests/series/test_stat.py
b/python/pyspark/pandas/tests/series/test_stat.py
index c94c01583b20..7c7c1754c042 100644
--- a/python/pyspark/pandas/tests/series/test_stat.py
+++ b/python/pyspark/pandas/tests/series/test_stat.py
@@ -553,7 +553,7 @@ class SeriesStatMixin:
):
ps.Series(["a", "b", "c"]).prod()
with self.assertRaisesRegex(
- TypeError, "Could not convert datetime64\\[ns\\] \\(timestamp.*\\)
to numeric"
+ TypeError, r"Could not convert datetime64\[[nu]s\] \(timestamp.*\)
to numeric"
):
ps.Series([pd.Timestamp("2016-01-01") for _ in range(3)]).prod()
with self.assertRaisesRegex(NotImplementedError, "Series does not
support columns axis."):
diff --git a/python/pyspark/pandas/tests/test_namespace.py
b/python/pyspark/pandas/tests/test_namespace.py
index f68a637723f7..1945ca9d70f6 100644
--- a/python/pyspark/pandas/tests/test_namespace.py
+++ b/python/pyspark/pandas/tests/test_namespace.py
@@ -273,12 +273,12 @@ class NamespaceTestsMixin:
pd.to_timedelta(np.arange(5), unit="s"),
)
self.assert_eq(
- ps.to_timedelta(ps.Series([1, 2]), unit="d"),
- pd.to_timedelta(pd.Series([1, 2]), unit="d"),
+ ps.to_timedelta(ps.Series([1, 2]), unit="D"),
+ pd.to_timedelta(pd.Series([1, 2]), unit="D"),
)
self.assert_eq(
- ps.to_timedelta(pd.Series([1, 2]), unit="d"),
- pd.to_timedelta(pd.Series([1, 2]), unit="d"),
+ ps.to_timedelta(pd.Series([1, 2]), unit="D"),
+ pd.to_timedelta(pd.Series([1, 2]), unit="D"),
)
def test_timedelta_range(self):
diff --git a/python/pyspark/pandas/tests/test_typedef.py
b/python/pyspark/pandas/tests/test_typedef.py
index 7c9d2a600037..e07551ab3a4e 100644
--- a/python/pyspark/pandas/tests/test_typedef.py
+++ b/python/pyspark/pandas/tests/test_typedef.py
@@ -361,7 +361,12 @@ class TypeHintTestsMixin:
bool: (np.bool_, BooleanType()),
# datetime
np.datetime64: (np.datetime64, TimestampType()),
- datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
+ datetime.datetime: (
+ np.dtype("datetime64[ns]")
+ if LooseVersion(pd.__version__) < LooseVersion("3.0.0")
+ else np.dtype("datetime64[us]"),
+ TimestampType(),
+ ),
# DateType
datetime.date: (np.dtype("object"), DateType()),
# DecimalType
diff --git a/python/pyspark/pandas/typedef/typehints.py
b/python/pyspark/pandas/typedef/typehints.py
index 99249906089e..ce5f0971693f 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -300,9 +300,15 @@ def spark_type_to_pandas_dtype(
):
return np.dtype("object")
elif isinstance(spark_type, types.DayTimeIntervalType):
- return np.dtype("timedelta64[ns]")
+ if LooseVersion(pd.__version__) < "3.0.0":
+ return np.dtype("timedelta64[ns]")
+ else:
+ return np.dtype("timedelta64[us]")
elif isinstance(spark_type, (types.TimestampType, types.TimestampNTZType)):
- return np.dtype("datetime64[ns]")
+ if LooseVersion(pd.__version__) < "3.0.0":
+ return np.dtype("datetime64[ns]")
+ else:
+ return np.dtype("datetime64[us]")
else:
from pyspark.pandas.utils import default_session
@@ -620,7 +626,7 @@ def infer_return_type(f: Callable) -> Union[SeriesType,
DataFrameType, ScalarTyp
if hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, SeriesType):
tpe = tpe.__args__[0]
- if issubclass(tpe, NameTypeHolder):
+ if isinstance(tpe, type) and issubclass(tpe, NameTypeHolder):
tpe = tpe.tpe
dtype, spark_type = pandas_on_spark_type(tpe)
return SeriesType(dtype, spark_type)
@@ -712,7 +718,10 @@ def create_type_for_series_type(param: Any) ->
Type[SeriesType]:
new_class = type(NameTypeHolder.short_name, (NameTypeHolder,), {})
new_class.tpe = param # type: ignore[assignment]
else:
- new_class = param.type if isinstance(param, np.dtype) else param
+ if LooseVersion(pd.__version__) < "3.0.0":
+ new_class = param.type if isinstance(param, np.dtype) else param
+ else:
+ new_class = param
return SeriesType[new_class] # type: ignore[valid-type]
@@ -872,11 +881,18 @@ def _new_type_holders(
holder_clazz.short_name, (holder_clazz,), {}
)
new_param.name = param.start
- if isinstance(param.stop, ExtensionDtype):
- new_param.tpe = param.stop # type: ignore[assignment]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ if isinstance(param.stop, ExtensionDtype):
+ new_param.tpe = param.stop # type: ignore[assignment]
+ else:
+ # When the given argument is a numpy's dtype instance.
+ new_param.tpe = (
+ param.stop.type # type: ignore[assignment]
+ if isinstance(param.stop, np.dtype)
+ else param.stop
+ )
else:
- # When the given argument is a numpy's dtype instance.
- new_param.tpe = param.stop.type if isinstance(param.stop,
np.dtype) else param.stop # type: ignore[assignment]
+ new_param.tpe = param.stop
new_params.append(new_param)
return tuple(new_params)
elif is_unnamed_params:
@@ -886,10 +902,17 @@ def _new_type_holders(
new_type: Type[Union[NameTypeHolder, IndexNameTypeHolder]] = type(
holder_clazz.short_name, (holder_clazz,), {}
)
- if isinstance(param, ExtensionDtype):
- new_type.tpe = param # type: ignore[assignment]
+ if LooseVersion(pd.__version__) < "3.0.0":
+ if isinstance(param, ExtensionDtype):
+ new_type.tpe = param # type: ignore[assignment]
+ else:
+ new_type.tpe = (
+ param.type # type: ignore[assignment]
+ if isinstance(param, np.dtype)
+ else param
+ )
else:
- new_type.tpe = param.type if isinstance(param, np.dtype) else
param # type: ignore[assignment]
+ new_type.tpe = param
new_types.append(new_type)
return tuple(new_types)
else:
diff --git a/python/pyspark/sql/pandas/types.py
b/python/pyspark/sql/pandas/types.py
index 3b3bacc34db3..1c4ea9193b8e 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -27,6 +27,7 @@ from decimal import Decimal
from typing import Any, Callable, Dict, Iterable, List, Optional, Union,
TYPE_CHECKING
from pyspark.errors import PySparkTypeError, UnsupportedOperationException,
PySparkValueError
+from pyspark.loose_version import LooseVersion
from pyspark.sql.types import (
cast,
BooleanType,
@@ -863,6 +864,7 @@ def _to_corrected_pandas_type(dt: DataType) ->
Optional[Any]:
inferred incorrectly.
"""
import numpy as np
+ import pandas as pd
if type(dt) == ByteType:
return np.int8
@@ -879,11 +881,20 @@ def _to_corrected_pandas_type(dt: DataType) ->
Optional[Any]:
elif type(dt) == BooleanType:
return bool
elif type(dt) == TimestampType:
- return np.dtype("datetime64[ns]")
+ if LooseVersion(pd.__version__) < "3.0.0":
+ return np.dtype("datetime64[ns]")
+ else:
+ return np.dtype("datetime64[us]")
elif type(dt) == TimestampNTZType:
- return np.dtype("datetime64[ns]")
+ if LooseVersion(pd.__version__) < "3.0.0":
+ return np.dtype("datetime64[ns]")
+ else:
+ return np.dtype("datetime64[us]")
elif type(dt) == DayTimeIntervalType:
- return np.dtype("timedelta64[ns]")
+ if LooseVersion(pd.__version__) < "3.0.0":
+ return np.dtype("timedelta64[ns]")
+ else:
+ return np.dtype("timedelta64[us]")
else:
return None
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]