This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8c401be [SPARK-35901][PYTHON] Refine type hints in pyspark.pandas.window
8c401be is described below
commit 8c401beb806267d4c23aeb27ab8898dcc3a0f98d
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Jun 28 12:23:32 2021 +0900
[SPARK-35901][PYTHON] Refine type hints in pyspark.pandas.window
### What changes were proposed in this pull request?
Refines type hints in `pyspark.pandas.window`.
Also, some refactoring is included to clean up the type hierarchy of
`Rolling` and `Expanding`.
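For reference, the reworked class hierarchy in `python/pyspark/pandas/window.py` (see the diff below) looks roughly like the following condensed sketch; bodies are elided and `Frame` is a stand-in for `pyspark.pandas.generic.Frame`:

    from abc import ABCMeta, abstractmethod
    from typing import Any, Callable, Generic, TypeVar

    class Frame:  # stand-in for pyspark.pandas.generic.Frame
        pass

    T_Frame = TypeVar("T_Frame", bound=Frame)

    class RollingAndExpanding(Generic[T_Frame], metaclass=ABCMeta):
        # Shared window-spec and min_periods handling lives here; methods
        # such as sum()/min()/max()/mean()/std()/var() return T_Frame.
        @abstractmethod
        def _apply_as_series_or_frame(self, func: Callable[[Any], Any]) -> T_Frame:
            ...

    class RollingLike(RollingAndExpanding[T_Frame]):
        """Rolling window spec and count(), shared by both subclasses."""

    class Rolling(RollingLike[T_Frame]):
        """Bound directly to a DataFrame or Series."""

    class RollingGroupby(RollingLike[T_Frame]):
        """Bound to a GroupBy; no longer a subclass of Rolling."""

    class ExpandingLike(RollingAndExpanding[T_Frame]):
        """Expanding window spec and count(), shared by both subclasses."""

    class Expanding(ExpandingLike[T_Frame]):
        """Bound directly to a DataFrame or Series."""

    class ExpandingGroupby(ExpandingLike[T_Frame]):
        """Bound to a GroupBy; no longer a subclass of Expanding."""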
### Why are the changes needed?
We can use stricter type hints for the functions in `pyspark.pandas.window` by making the window classes generic over the frame type.
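As a minimal illustration (simplified stand-in classes, not the actual pyspark.pandas implementation), binding a TypeVar to the common frame base lets type checkers infer the concrete return type instead of Union["Series", "DataFrame"]:

    from typing import Generic, TypeVar

    class Frame: ...
    class DataFrame(Frame): ...
    class Series(Frame): ...

    T_Frame = TypeVar("T_Frame", bound=Frame)

    class Rolling(Generic[T_Frame]):
        def __init__(self, obj: T_Frame) -> None:
            self._obj = obj

        def sum(self) -> T_Frame:
            # The real method applies a Spark window function; returning the
            # bound object is enough to demonstrate the type relationship.
            return self._obj

    def rolling(obj: T_Frame) -> "Rolling[T_Frame]":
        return Rolling(obj)

    result = rolling(DataFrame()).sum()  # inferred as DataFrame, not a Union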
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #33097 from ueshin/issues/SPARK-35901/window.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/frame.py | 14 +++
python/pyspark/pandas/generic.py | 18 +--
python/pyspark/pandas/groupby.py | 22 ++--
python/pyspark/pandas/series.py | 14 +++
python/pyspark/pandas/window.py | 249 ++++++++++++++++++---------------------
5 files changed, 166 insertions(+), 151 deletions(-)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 7f26346..6b6301a 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -11676,6 +11676,20 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
"""
        return DataFrame(pd.DataFrame.from_dict(data, orient=orient, dtype=dtype, columns=columns))
+ # Override the `groupby` to specify the actual return type annotation.
+ def groupby(
+ self,
+ by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+ axis: Union[int, str] = 0,
+ as_index: bool = True,
+ dropna: bool = True,
+ ) -> "DataFrameGroupBy":
+ return cast(
+            "DataFrameGroupBy", super().groupby(by=by, axis=axis, as_index=as_index, dropna=dropna)
+ )
+
+ groupby.__doc__ = Frame.groupby.__doc__
+
def _build_groupby(
self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
) -> "DataFrameGroupBy":
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 3a33295..a0c3f23 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -67,13 +67,13 @@ from pyspark.pandas.utils import (
validate_axis,
SPARK_CONF_ARROW_ENABLED,
)
-from pyspark.pandas.window import Rolling, Expanding
if TYPE_CHECKING:
from pyspark.pandas.frame import DataFrame # noqa: F401 (SPARK-34943)
from pyspark.pandas.indexes.base import Index # noqa: F401 (SPARK-34943)
from pyspark.pandas.groupby import GroupBy # noqa: F401 (SPARK-34943)
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.window import Rolling, Expanding  # noqa: F401 (SPARK-34943)
T_Frame = TypeVar("T_Frame", bound="Frame")
@@ -2508,7 +2508,9 @@ class Frame(object, metaclass=ABCMeta):
return tuple(last_valid_row)
# TODO: 'center', 'win_type', 'on', 'axis' parameter should be implemented.
-    def rolling(self, window: int, min_periods: Optional[int] = None) -> Rolling:
+ def rolling(
+ self: T_Frame, window: int, min_periods: Optional[int] = None
+ ) -> "Rolling[T_Frame]":
"""
Provide rolling transformations.
@@ -2533,13 +2535,13 @@ class Frame(object, metaclass=ABCMeta):
-------
a Window sub-classed for the particular operation
"""
- return Rolling(
-            cast(Union["Series", "DataFrame"], self), window=window, min_periods=min_periods
- )
+ from pyspark.pandas.window import Rolling
+
+ return Rolling(self, window=window, min_periods=min_periods)
# TODO: 'center' and 'axis' parameter should be implemented.
# 'axis' implementation, refer https://github.com/pyspark.pandas/pull/607
- def expanding(self, min_periods: int = 1) -> Expanding:
+ def expanding(self: T_Frame, min_periods: int = 1) -> "Expanding[T_Frame]":
"""
Provide expanding transformations.
@@ -2557,7 +2559,9 @@ class Frame(object, metaclass=ABCMeta):
-------
a Window sub-classed for the particular operation
"""
-        return Expanding(cast(Union["Series", "DataFrame"], self), min_periods=min_periods)
+ from pyspark.pandas.window import Expanding
+
+ return Expanding(self, min_periods=min_periods)
def get(self, key: Any, default: Optional[Any] = None) -> Any:
"""
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 860540e..1620c8c 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -41,6 +41,7 @@ from typing import (
TypeVar,
Union,
cast,
+ TYPE_CHECKING,
)
import pandas as pd
@@ -85,9 +86,12 @@ from pyspark.pandas.utils import (
verify_temp_column_name,
)
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
-from pyspark.pandas.window import RollingGroupby, ExpandingGroupby
from pyspark.pandas.exceptions import DataError
+if TYPE_CHECKING:
+    from pyspark.pandas.window import RollingGroupby, ExpandingGroupby  # noqa: F401 (SPARK-34943)
+
+
# to keep it the same as pandas
NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
@@ -2320,7 +2324,7 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
        return self._reduce_for_stat_function(stat_function, only_numeric=False)
- def rolling(self, window: int, min_periods: Optional[int] = None) ->
RollingGroupby:
+ def rolling(self, window: int, min_periods: Optional[int] = None) ->
"RollingGroupby[T_Frame]":
"""
        Return a rolling grouper, providing rolling
functionality per group.
@@ -2345,11 +2349,11 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
Series.groupby
DataFrame.groupby
"""
- return RollingGroupby(
-            cast(Union[SeriesGroupBy, DataFrameGroupBy], self), window, min_periods=min_periods
- )
+ from pyspark.pandas.window import RollingGroupby
- def expanding(self, min_periods: int = 1) -> ExpandingGroupby:
+ return RollingGroupby(self, window, min_periods=min_periods)
+
+ def expanding(self, min_periods: int = 1) -> "ExpandingGroupby[T_Frame]":
"""
Return an expanding grouper, providing expanding
functionality per group.
@@ -2369,9 +2373,9 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
Series.groupby
DataFrame.groupby
"""
- return ExpandingGroupby(
-            cast(Union[SeriesGroupBy, DataFrameGroupBy], self), min_periods=min_periods
- )
+ from pyspark.pandas.window import ExpandingGroupby
+
+ return ExpandingGroupby(self, min_periods=min_periods)
def get_group(self, name: Union[Any, Tuple, List[Union[Any, Tuple]]]) ->
T_Frame:
"""
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 3de7243..b4e95ad 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -6216,6 +6216,20 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
result = unpack_scalar(self._internal.spark_frame.select(scol))
return result if result is not None else np.nan
+ # Override the `groupby` to specify the actual return type annotation.
+ def groupby(
+ self,
+ by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+ axis: Union[int, str_type] = 0,
+ as_index: bool = True,
+ dropna: bool = True,
+ ) -> "SeriesGroupBy":
+ return cast(
+            "SeriesGroupBy", super().groupby(by=by, axis=axis, as_index=as_index, dropna=dropna)
+ )
+
+ groupby.__doc__ = Frame.groupby.__doc__
+
def _build_groupby(
self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
) -> "SeriesGroupBy":
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 8c9a59d..b1ee83f 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -14,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from abc import ABCMeta, abstractmethod
from functools import partial
from typing import ( # noqa: F401 (SPARK-34943)
Any,
- Union,
- TYPE_CHECKING,
Callable,
+ Generic,
List,
- cast,
Optional,
+ TypeVar,
)
from pyspark.sql import Window
@@ -42,18 +42,15 @@ from pyspark.pandas.utils import scol_for
from pyspark.sql.column import Column
from pyspark.sql.window import WindowSpec
-if TYPE_CHECKING:
- from pyspark.pandas.frame import DataFrame # noqa: F401 (SPARK-34943)
- from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
-    from pyspark.pandas.groupby import SeriesGroupBy  # noqa: F401 (SPARK-34943)
-    from pyspark.pandas.groupby import DataFrameGroupBy  # noqa: F401 (SPARK-34943)
+from pyspark.pandas.generic import Frame
+from pyspark.pandas.groupby import GroupBy
-class RollingAndExpanding(object):
- def __init__(
-        self, psdf_or_psser: Union["Series", "DataFrame"], window: WindowSpec, min_periods: int
- ):
- self._psdf_or_psser = psdf_or_psser
+T_Frame = TypeVar("T_Frame", bound=Frame)
+
+
+class RollingAndExpanding(Generic[T_Frame], metaclass=ABCMeta):
+ def __init__(self, window: WindowSpec, min_periods: int):
self._window = window
# This unbounded Window is later used to handle 'min_periods' for now.
        self._unbounded_window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
@@ -61,28 +58,20 @@ class RollingAndExpanding(object):
)
self._min_periods = min_periods
- def _apply_as_series_or_frame(
- self, func: Callable[[Column], Column]
- ) -> Union["Series", "DataFrame"]:
+ @abstractmethod
+    def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> T_Frame:
"""
Wraps a function that handles Spark column in order
to support it in both pandas-on-Spark Series and DataFrame.
        Note that the given `func` name should be the same as the API's method name.
"""
- raise NotImplementedError(
- "A class that inherits this class should implement this method "
- "to handle the index and columns of output."
- )
+ pass
- def count(self) -> Union["Series", "DataFrame"]:
- def count(scol: Column) -> Column:
- return F.count(scol).over(self._window)
-
- return cast(
-            Union["Series", "DataFrame"], self._apply_as_series_or_frame(count).astype("float64")
- )
+ @abstractmethod
+ def count(self) -> T_Frame:
+ pass
- def sum(self) -> Union["Series", "DataFrame"]:
+ def sum(self) -> T_Frame:
def sum(scol: Column) -> Column:
return F.when(
F.row_number().over(self._unbounded_window) >=
self._min_periods,
@@ -91,7 +80,7 @@ class RollingAndExpanding(object):
return self._apply_as_series_or_frame(sum)
- def min(self) -> Union["Series", "DataFrame"]:
+ def min(self) -> T_Frame:
def min(scol: Column) -> Column:
return F.when(
F.row_number().over(self._unbounded_window) >=
self._min_periods,
@@ -100,7 +89,7 @@ class RollingAndExpanding(object):
return self._apply_as_series_or_frame(min)
- def max(self) -> Union["Series", "DataFrame"]:
+ def max(self) -> T_Frame:
def max(scol: Column) -> Column:
return F.when(
F.row_number().over(self._unbounded_window) >=
self._min_periods,
@@ -109,7 +98,7 @@ class RollingAndExpanding(object):
return self._apply_as_series_or_frame(max)
- def mean(self) -> Union["Series", "DataFrame"]:
+ def mean(self) -> T_Frame:
def mean(scol: Column) -> Column:
return F.when(
F.row_number().over(self._unbounded_window) >=
self._min_periods,
@@ -118,7 +107,7 @@ class RollingAndExpanding(object):
return self._apply_as_series_or_frame(mean)
- def std(self) -> Union["Series", "DataFrame"]:
+ def std(self) -> T_Frame:
def std(scol: Column) -> Column:
return F.when(
F.row_number().over(self._unbounded_window) >=
self._min_periods,
@@ -127,7 +116,7 @@ class RollingAndExpanding(object):
return self._apply_as_series_or_frame(std)
- def var(self) -> Union["Series", "DataFrame"]:
+ def var(self) -> T_Frame:
def var(scol: Column) -> Column:
return F.when(
F.row_number().over(self._unbounded_window) >=
self._min_periods,
@@ -137,15 +126,12 @@ class RollingAndExpanding(object):
return self._apply_as_series_or_frame(var)
-class Rolling(RollingAndExpanding):
+class RollingLike(RollingAndExpanding[T_Frame]):
def __init__(
self,
- psdf_or_psser: Union["Series", "DataFrame"],
window: int,
min_periods: Optional[int] = None,
):
- from pyspark.pandas import DataFrame, Series
-
if window < 0:
raise ValueError("window must be >= 0")
if (min_periods is not None) and (min_periods < 0):
@@ -155,17 +141,37 @@ class Rolling(RollingAndExpanding):
# a value.
min_periods = window
+ window_spec = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
+ Window.currentRow - (window - 1), Window.currentRow
+ )
+
+ super().__init__(window_spec, min_periods)
+
+ def count(self) -> T_Frame:
+ def count(scol: Column) -> Column:
+ return F.count(scol).over(self._window)
+
+        return self._apply_as_series_or_frame(count).astype("float64")  # type: ignore
+
+
+class Rolling(RollingLike[T_Frame]):
+ def __init__(
+ self,
+ psdf_or_psser: T_Frame,
+ window: int,
+ min_periods: Optional[int] = None,
+ ):
+ from pyspark.pandas.frame import DataFrame
+ from pyspark.pandas.series import Series
+
+ super().__init__(window, min_periods)
+
if not isinstance(psdf_or_psser, (DataFrame, Series)):
            raise TypeError(
                "psdf_or_psser must be a series or dataframe; however, got: %s" % type(psdf_or_psser)
)
-
- window_spec = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
- Window.currentRow - (window - 1), Window.currentRow
- )
-
- super().__init__(psdf_or_psser, window_spec, min_periods)
+ self._psdf_or_psser = psdf_or_psser
def __getattr__(self, item: str) -> Any:
if hasattr(MissingPandasLikeRolling, item):
@@ -176,15 +182,13 @@ class Rolling(RollingAndExpanding):
return partial(property_or_func, self)
raise AttributeError(item)
- def _apply_as_series_or_frame(
- self, func: Callable[[Column], Column]
- ) -> Union["Series", "DataFrame"]:
+    def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> T_Frame:
return self._psdf_or_psser._apply_series_op(
            lambda psser: psser._with_new_scol(func(psser.spark.column)),  # TODO: dtype?
should_resolve=True,
)
- def count(self) -> Union["Series", "DataFrame"]:
+ def count(self) -> T_Frame:
"""
The rolling count of any non-NaN observations inside the window.
@@ -233,7 +237,7 @@ class Rolling(RollingAndExpanding):
"""
return super().count()
- def sum(self) -> Union["Series", "DataFrame"]:
+ def sum(self) -> T_Frame:
"""
Calculate rolling summation of given DataFrame or Series.
@@ -311,7 +315,7 @@ class Rolling(RollingAndExpanding):
"""
return super().sum()
- def min(self) -> Union["Series", "DataFrame"]:
+ def min(self) -> T_Frame:
"""
Calculate the rolling minimum.
@@ -389,7 +393,7 @@ class Rolling(RollingAndExpanding):
"""
return super().min()
- def max(self) -> Union["Series", "DataFrame"]:
+ def max(self) -> T_Frame:
"""
Calculate the rolling maximum.
@@ -466,7 +470,7 @@ class Rolling(RollingAndExpanding):
"""
return super().max()
- def mean(self) -> Union["Series", "DataFrame"]:
+ def mean(self) -> T_Frame:
"""
Calculate the rolling mean of the values.
@@ -544,7 +548,7 @@ class Rolling(RollingAndExpanding):
"""
return super().mean()
- def std(self) -> Union["Series", "DataFrame"]:
+ def std(self) -> T_Frame:
"""
Calculate rolling standard deviation.
@@ -594,7 +598,7 @@ class Rolling(RollingAndExpanding):
"""
return super().std()
- def var(self) -> Union["Series", "DataFrame"]:
+ def var(self) -> T_Frame:
"""
Calculate unbiased rolling variance.
@@ -645,27 +649,14 @@ class Rolling(RollingAndExpanding):
return super().var()
-class RollingGroupby(Rolling):
+class RollingGroupby(RollingLike[T_Frame]):
def __init__(
self,
- groupby: Union["SeriesGroupBy", "DataFrameGroupBy"],
+ groupby: GroupBy[T_Frame],
window: int,
min_periods: Optional[int] = None,
):
- from pyspark.pandas.groupby import SeriesGroupBy
- from pyspark.pandas.groupby import DataFrameGroupBy
-
- if isinstance(groupby, SeriesGroupBy):
- psdf_or_psser = groupby._psser # type: Union[DataFrame, Series]
- elif isinstance(groupby, DataFrameGroupBy):
- psdf_or_psser = groupby._psdf
- else:
- raise TypeError(
- "groupby must be a SeriesGroupBy or DataFrameGroupBy; "
- "however, got: %s" % type(groupby)
- )
-
- super().__init__(psdf_or_psser, window, min_periods)
+ super().__init__(window, min_periods)
self._groupby = groupby
        self._window = self._window.partitionBy(*[ser.spark.column for ser in groupby._groupkeys])
@@ -682,17 +673,13 @@ class RollingGroupby(Rolling):
return partial(property_or_func, self)
raise AttributeError(item)
- def _apply_as_series_or_frame(
- self, func: Callable[[Column], Column]
- ) -> Union["Series", "DataFrame"]:
+    def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> T_Frame:
"""
Wraps a function that handles Spark column in order
to support it in both pandas-on-Spark Series and DataFrame.
        Note that the given `func` name should be the same as the API's method name.
"""
from pyspark.pandas import DataFrame
- from pyspark.pandas.series import first_series
- from pyspark.pandas.groupby import SeriesGroupBy
groupby = self._groupby
psdf = groupby._psdf
@@ -755,13 +742,9 @@ class RollingGroupby(Rolling):
data_fields=[c._internal.data_fields[0] for c in applied],
)
- ret = DataFrame(internal) # type: DataFrame
- if isinstance(groupby, SeriesGroupBy):
- return first_series(ret)
- else:
- return ret
+ return groupby._cleanup_and_return(DataFrame(internal))
- def count(self) -> Union["Series", "DataFrame"]:
+ def count(self) -> T_Frame:
"""
The rolling count of any non-NaN observations inside the window.
@@ -815,7 +798,7 @@ class RollingGroupby(Rolling):
"""
return super().count()
- def sum(self) -> Union["Series", "DataFrame"]:
+ def sum(self) -> T_Frame:
"""
The rolling summation of any non-NaN observations inside the window.
@@ -869,7 +852,7 @@ class RollingGroupby(Rolling):
"""
return super().sum()
- def min(self) -> Union["Series", "DataFrame"]:
+ def min(self) -> T_Frame:
"""
The rolling minimum of any non-NaN observations inside the window.
@@ -923,7 +906,7 @@ class RollingGroupby(Rolling):
"""
return super().min()
- def max(self) -> Union["Series", "DataFrame"]:
+ def max(self) -> T_Frame:
"""
The rolling maximum of any non-NaN observations inside the window.
@@ -977,7 +960,7 @@ class RollingGroupby(Rolling):
"""
return super().max()
- def mean(self) -> Union["Series", "DataFrame"]:
+ def mean(self) -> T_Frame:
"""
The rolling mean of any non-NaN observations inside the window.
@@ -1031,7 +1014,7 @@ class RollingGroupby(Rolling):
"""
return super().mean()
- def std(self) -> Union["Series", "DataFrame"]:
+ def std(self) -> T_Frame:
"""
Calculate rolling standard deviation.
@@ -1050,7 +1033,7 @@ class RollingGroupby(Rolling):
"""
return super().std()
- def var(self) -> Union["Series", "DataFrame"]:
+ def var(self) -> T_Frame:
"""
Calculate unbiased rolling variance.
@@ -1070,24 +1053,40 @@ class RollingGroupby(Rolling):
return super().var()
-class Expanding(RollingAndExpanding):
-    def __init__(self, psdf_or_psser: Union["Series", "DataFrame"], min_periods: int = 1):
- from pyspark.pandas import DataFrame, Series
-
+class ExpandingLike(RollingAndExpanding[T_Frame]):
+ def __init__(self, min_periods: int = 1):
if min_periods < 0:
raise ValueError("min_periods must be >= 0")
+ window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
+ Window.unboundedPreceding, Window.currentRow
+ )
+
+ super().__init__(window, min_periods)
+
+ def count(self) -> T_Frame:
+ def count(scol: Column) -> Column:
+ return F.when(
+                F.row_number().over(self._unbounded_window) >= self._min_periods,
+ F.count(scol).over(self._window),
+ ).otherwise(F.lit(None))
+
+        return self._apply_as_series_or_frame(count).astype("float64")  # type: ignore
+
+
+class Expanding(ExpandingLike[T_Frame]):
+ def __init__(self, psdf_or_psser: T_Frame, min_periods: int = 1):
+ from pyspark.pandas.frame import DataFrame
+ from pyspark.pandas.series import Series
+
+ super().__init__(min_periods)
+
if not isinstance(psdf_or_psser, (DataFrame, Series)):
            raise TypeError(
                "psdf_or_psser must be a series or dataframe; however, got: %s" % type(psdf_or_psser)
)
-
- window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
- Window.unboundedPreceding, Window.currentRow
- )
-
- super().__init__(psdf_or_psser, window, min_periods)
+ self._psdf_or_psser = psdf_or_psser
def __getattr__(self, item: str) -> Any:
if hasattr(MissingPandasLikeExpanding, item):
@@ -1104,7 +1103,7 @@ class Expanding(RollingAndExpanding):
_apply_as_series_or_frame = Rolling._apply_as_series_or_frame
- def count(self) -> Union["Series", "DataFrame"]:
+ def count(self) -> T_Frame:
"""
The expanding count of any non-NaN observations inside the window.
@@ -1143,16 +1142,9 @@ class Expanding(RollingAndExpanding):
2 2.0
3 3.0
"""
+ return super().count()
- def count(scol: Column) -> Column:
- return F.when(
-                F.row_number().over(self._unbounded_window) >= self._min_periods,
- F.count(scol).over(self._window),
- ).otherwise(F.lit(None))
-
-        return self._apply_as_series_or_frame(count).astype("float64")  # type: ignore
-
- def sum(self) -> Union["Series", "DataFrame"]:
+ def sum(self) -> T_Frame:
"""
Calculate expanding summation of given DataFrame or Series.
@@ -1214,7 +1206,7 @@ class Expanding(RollingAndExpanding):
"""
return super().sum()
- def min(self) -> Union["Series", "DataFrame"]:
+ def min(self) -> T_Frame:
"""
Calculate the expanding minimum.
@@ -1251,7 +1243,7 @@ class Expanding(RollingAndExpanding):
"""
return super().min()
- def max(self) -> Union["Series", "DataFrame"]:
+ def max(self) -> T_Frame:
"""
Calculate the expanding maximum.
@@ -1287,7 +1279,7 @@ class Expanding(RollingAndExpanding):
"""
return super().max()
- def mean(self) -> Union["Series", "DataFrame"]:
+ def mean(self) -> T_Frame:
"""
Calculate the expanding mean of the values.
@@ -1331,7 +1323,7 @@ class Expanding(RollingAndExpanding):
"""
return super().mean()
- def std(self) -> Union["Series", "DataFrame"]:
+ def std(self) -> T_Frame:
"""
Calculate expanding standard deviation.
@@ -1381,7 +1373,7 @@ class Expanding(RollingAndExpanding):
"""
return super().std()
- def var(self) -> Union["Series", "DataFrame"]:
+ def var(self) -> T_Frame:
"""
Calculate unbiased expanding variance.
@@ -1432,22 +1424,9 @@ class Expanding(RollingAndExpanding):
return super().var()
-class ExpandingGroupby(Expanding):
-    def __init__(self, groupby: Union["SeriesGroupBy", "DataFrameGroupBy"], min_periods: int = 1):
- from pyspark.pandas.groupby import SeriesGroupBy
- from pyspark.pandas.groupby import DataFrameGroupBy
-
- if isinstance(groupby, SeriesGroupBy):
- psdf_or_psser = groupby._psser # type: Union[DataFrame, Series]
- elif isinstance(groupby, DataFrameGroupBy):
- psdf_or_psser = groupby._psdf
- else:
- raise TypeError(
- "groupby must be a SeriesGroupBy or DataFrameGroupBy; "
- "however, got: %s" % type(groupby)
- )
-
- super().__init__(psdf_or_psser, min_periods)
+class ExpandingGroupby(ExpandingLike[T_Frame]):
+ def __init__(self, groupby: GroupBy[T_Frame], min_periods: int = 1):
+ super().__init__(min_periods)
self._groupby = groupby
        self._window = self._window.partitionBy(*[ser.spark.column for ser in groupby._groupkeys])
@@ -1464,9 +1443,9 @@ class ExpandingGroupby(Expanding):
return partial(property_or_func, self)
raise AttributeError(item)
-    _apply_as_series_or_frame = RollingGroupby._apply_as_series_or_frame  # type: ignore
+ _apply_as_series_or_frame = RollingGroupby._apply_as_series_or_frame
- def count(self) -> Union["Series", "DataFrame"]:
+ def count(self) -> T_Frame:
"""
The expanding count of any non-NaN observations inside the window.
@@ -1520,7 +1499,7 @@ class ExpandingGroupby(Expanding):
"""
return super().count()
- def sum(self) -> Union["Series", "DataFrame"]:
+ def sum(self) -> T_Frame:
"""
Calculate expanding summation of given DataFrame or Series.
@@ -1574,7 +1553,7 @@ class ExpandingGroupby(Expanding):
"""
return super().sum()
- def min(self) -> Union["Series", "DataFrame"]:
+ def min(self) -> T_Frame:
"""
Calculate the expanding minimum.
@@ -1628,7 +1607,7 @@ class ExpandingGroupby(Expanding):
"""
return super().min()
- def max(self) -> Union["Series", "DataFrame"]:
+ def max(self) -> T_Frame:
"""
Calculate the expanding maximum.
@@ -1681,7 +1660,7 @@ class ExpandingGroupby(Expanding):
"""
return super().max()
- def mean(self) -> Union["Series", "DataFrame"]:
+ def mean(self) -> T_Frame:
"""
Calculate the expanding mean of the values.
@@ -1735,7 +1714,7 @@ class ExpandingGroupby(Expanding):
"""
return super().mean()
- def std(self) -> Union["Series", "DataFrame"]:
+ def std(self) -> T_Frame:
"""
Calculate expanding standard deviation.
@@ -1755,7 +1734,7 @@ class ExpandingGroupby(Expanding):
"""
return super().std()
- def var(self) -> Union["Series", "DataFrame"]:
+ def var(self) -> T_Frame:
"""
Calculate unbiased expanding variance.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]