This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 484d573933c [SPARK-38785][PYTHON][SQL] Implement
ExponentialMovingWindow
484d573933c is described below
commit 484d573933cd84d65ff0d3901f26c245766f46a5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Apr 13 12:44:31 2022 +0900
[SPARK-38785][PYTHON][SQL] Implement ExponentialMovingWindow
### What changes were proposed in this pull request?
Initial implementation of `Series.ewm` and `DataFrame.ewm`.
| | supported | unsupported |
|---|---|---|
|function|mean|sum/std/var/cov/corr|
|params|com/span/halflife/alpha/min_periods|adjust/ignore_na/axis/method|
Other unsupported functionality:
- chaining with groupby: `df.groupby('s1').ewm(com=0.5).mean()`
- `DatetimeIndex`
- datasets containing NULL or NaN are not supported for now
### Why are the changes needed?
to support more pandas API
### Does this PR introduce _any_ user-facing change?
yes, new method is added:
```
In [3]: psdf.ewm(com=0.1).mean()
s1 s2
0 0.200000 2.000000
1 0.016667 1.083333
2 0.547368 2.827068
3 0.231557 1.165984
4 0.384688 0.105992
5 0.489517 0.009636
6 0.589956 0.000876
```
### How was this patch tested?
Added test suites.
Closes #36063 from zhengruifeng/impl_ewma_in_sql.
Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../source/reference/pyspark.pandas/window.rst | 9 +
python/pyspark/pandas/generic.py | 52 +++++-
python/pyspark/pandas/missing/frame.py | 1 -
python/pyspark/pandas/missing/series.py | 1 -
python/pyspark/pandas/missing/window.py | 31 ++++
python/pyspark/pandas/tests/test_ewm.py | 126 ++++++++++++++
python/pyspark/pandas/tests/test_window.py | 72 ++++++++
python/pyspark/pandas/usage_logging/__init__.py | 2 +
python/pyspark/pandas/window.py | 186 +++++++++++++++++++++
.../catalyst/expressions/windowExpressions.scala | 45 +++++
.../spark/sql/api/python/PythonSQLUtils.scala | 4 +-
11 files changed, 525 insertions(+), 4 deletions(-)
diff --git a/python/docs/source/reference/pyspark.pandas/window.rst
b/python/docs/source/reference/pyspark.pandas/window.rst
index d8d9b6858fe..550036537f2 100644
--- a/python/docs/source/reference/pyspark.pandas/window.rst
+++ b/python/docs/source/reference/pyspark.pandas/window.rst
@@ -23,6 +23,7 @@ Window
Rolling objects are returned by ``.rolling`` calls:
:func:`pandas_on_spark.DataFrame.rolling`,
:func:`pandas_on_spark.Series.rolling`, etc.
Expanding objects are returned by ``.expanding`` calls:
:func:`pandas_on_spark.DataFrame.expanding`,
:func:`pandas_on_spark.Series.expanding`, etc.
+ExponentialMoving objects are returned by ``.ewm`` calls:
:func:`pandas_on_spark.DataFrame.ewm`, :func:`pandas_on_spark.Series.ewm`, etc.
Standard moving window functions
--------------------------------
@@ -47,3 +48,11 @@ Standard expanding window functions
Expanding.min
Expanding.max
Expanding.mean
+
+Exponential moving window functions
+-----------------------------------
+
+.. autosummary::
+ :toctree: api/
+
+ ExponentialMoving.mean
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 77501268684..71f8146fe42 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -82,7 +82,7 @@ if TYPE_CHECKING:
from pyspark.pandas.indexes.base import Index
from pyspark.pandas.groupby import GroupBy
from pyspark.pandas.series import Series
- from pyspark.pandas.window import Rolling, Expanding
+ from pyspark.pandas.window import Rolling, Expanding, ExponentialMoving
bool_type = bool
@@ -2619,6 +2619,56 @@ class Frame(object, metaclass=ABCMeta):
return Expanding(self, min_periods=min_periods)
+ # TODO: 'adjust', 'ignore_na', 'axis', 'method' parameter should be
implemented.
+ def ewm(
+ self: FrameLike,
+ com: Optional[float] = None,
+ span: Optional[float] = None,
+ halflife: Optional[float] = None,
+ alpha: Optional[float] = None,
+ min_periods: Optional[int] = None,
+ ) -> "ExponentialMoving[FrameLike]":
+ """
+ Provide exponentially weighted window transformations.
+
+ .. note:: 'min_periods' in pandas-on-Spark works as a fixed window
size unlike pandas.
+ Unlike pandas, NA is also counted as the period. This might be
changed
+ in the near future.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ com : float, optional
+ Specify decay in terms of center of mass.
+ alpha = 1 / (1 + com), for com >= 0.
+
+ span : float, optional
+ Specify decay in terms of span.
+ alpha = 2 / (span + 1), for span >= 1.
+
+ halflife : float, optional
+ Specify decay in terms of half-life.
+ alpha = 1 - exp(-ln(2) / halflife), for halflife > 0.
+
+ alpha : float, optional
+ Specify smoothing factor alpha directly.
+ 0 < alpha <= 1.
+
+ min_periods : int, default None
+ Minimum number of observations in window required to have a value
+ (otherwise result is NA).
+
+ Returns
+ -------
+ a Window sub-classed for the particular operation
+ """
+ from pyspark.pandas.window import ExponentialMoving
+
+ return ExponentialMoving(
+ self, com=com, span=span, halflife=halflife, alpha=alpha,
min_periods=min_periods
+ )
+
def get(self, key: Any, default: Optional[Any] = None) -> Any:
"""
Get item from object for given key (DataFrame column, Panel slice,
diff --git a/python/pyspark/pandas/missing/frame.py
b/python/pyspark/pandas/missing/frame.py
index 775115da8e4..0b6b1683497 100644
--- a/python/pyspark/pandas/missing/frame.py
+++ b/python/pyspark/pandas/missing/frame.py
@@ -39,7 +39,6 @@ class _MissingPandasLikeDataFrame:
compare = _unsupported_function("compare")
convert_dtypes = _unsupported_function("convert_dtypes")
corrwith = _unsupported_function("corrwith")
- ewm = _unsupported_function("ewm")
infer_objects = _unsupported_function("infer_objects")
interpolate = _unsupported_function("interpolate")
mode = _unsupported_function("mode")
diff --git a/python/pyspark/pandas/missing/series.py
b/python/pyspark/pandas/missing/series.py
index cd1730f5eb9..9bb191f1c81 100644
--- a/python/pyspark/pandas/missing/series.py
+++ b/python/pyspark/pandas/missing/series.py
@@ -36,7 +36,6 @@ class MissingPandasLikeSeries:
autocorr = _unsupported_function("autocorr")
combine = _unsupported_function("combine")
convert_dtypes = _unsupported_function("convert_dtypes")
- ewm = _unsupported_function("ewm")
infer_objects = _unsupported_function("infer_objects")
interpolate = _unsupported_function("interpolate")
reorder_levels = _unsupported_function("reorder_levels")
diff --git a/python/pyspark/pandas/missing/window.py
b/python/pyspark/pandas/missing/window.py
index fb79992e042..e6ac39901ff 100644
--- a/python/pyspark/pandas/missing/window.py
+++ b/python/pyspark/pandas/missing/window.py
@@ -54,6 +54,24 @@ def _unsupported_property_rolling(property_name,
deprecated=False, reason=""):
)
+def _unsupported_function_exponential_moving(method_name, deprecated=False,
reason=""):
+ return unsupported_function(
+ class_name="pandas.core.window.ExponentialMovingWindow",
+ method_name=method_name,
+ deprecated=deprecated,
+ reason=reason,
+ )
+
+
+def _unsupported_property_exponential_moving(property_name, deprecated=False,
reason=""):
+ return unsupported_property(
+ class_name="pandas.core.window.ExponentialMovingWindow",
+ property_name=property_name,
+ deprecated=deprecated,
+ reason=reason,
+ )
+
+
class MissingPandasLikeExpanding:
agg = _unsupported_function_expanding("agg")
aggregate = _unsupported_function_expanding("aggregate")
@@ -124,3 +142,16 @@ class MissingPandasLikeRollingGroupby:
is_datetimelike = _unsupported_property_rolling("is_datetimelike")
is_freq_type = _unsupported_property_rolling("is_freq_type")
ndim = _unsupported_property_rolling("ndim")
+
+
+class MissingPandasLikeExponentialMoving:
+ sum = _unsupported_function_exponential_moving("sum")
+ var = _unsupported_function_exponential_moving("var")
+ std = _unsupported_function_exponential_moving("std")
+ cov = _unsupported_function_exponential_moving("cov")
+ corr = _unsupported_function_exponential_moving("corr")
+
+ adjust = _unsupported_property_exponential_moving("adjust")
+ ignore_na = _unsupported_property_exponential_moving("ignore_na")
+ axis = _unsupported_property_exponential_moving("axis")
+ method = _unsupported_property_exponential_moving("method")
diff --git a/python/pyspark/pandas/tests/test_ewm.py
b/python/pyspark/pandas/tests/test_ewm.py
new file mode 100644
index 00000000000..7306aad44ff
--- /dev/null
+++ b/python/pyspark/pandas/tests/test_ewm.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.window import ExponentialMoving
+
+
+class EWMTest(PandasOnSparkTestCase, TestUtils):
+ def test_ewm_error(self):
+ with self.assertRaisesRegex(
+ TypeError, "psdf_or_psser must be a series or dataframe; however,
got:.*int"
+ ):
+ ExponentialMoving(1, 2)
+
+ psdf = ps.range(10)
+
+ with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+ psdf.ewm(min_periods=-1, alpha=0.5).mean()
+
+ with self.assertRaisesRegex(ValueError, "com must be >= 0"):
+ psdf.ewm(com=-0.1).mean()
+
+ with self.assertRaisesRegex(ValueError, "span must be >= 1"):
+ psdf.ewm(span=0.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
+ psdf.ewm(halflife=0).mean()
+
+ with self.assertRaisesRegex(ValueError, "alpha must be in"):
+ psdf.ewm(alpha=1.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "Must pass one of com, span,
halflife, or alpha"):
+ psdf.ewm().mean()
+
+ with self.assertRaisesRegex(
+ ValueError, "com, span, halflife, and alpha are mutually exclusive"
+ ):
+ psdf.ewm(com=0.5, alpha=0.7).mean()
+
+ def _test_ewm_func(self, f):
+ pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a")
+ psser = ps.from_pandas(pser)
+ self.assert_eq(getattr(psser.ewm(com=0.2), f)(),
getattr(pser.ewm(com=0.2), f)())
+ self.assert_eq(
+ getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2),
f)().sum()
+ )
+ self.assert_eq(getattr(psser.ewm(span=1.7), f)(),
getattr(pser.ewm(span=1.7), f)())
+ self.assert_eq(
+ getattr(psser.ewm(span=1.7), f)().sum(),
getattr(pser.ewm(span=1.7), f)().sum()
+ )
+ self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(),
getattr(pser.ewm(halflife=0.5), f)())
+ self.assert_eq(
+ getattr(psser.ewm(halflife=0.5), f)().sum(),
getattr(pser.ewm(halflife=0.5), f)().sum()
+ )
+ self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(),
getattr(pser.ewm(alpha=0.7), f)())
+ self.assert_eq(
+ getattr(psser.ewm(alpha=0.7), f)().sum(),
getattr(pser.ewm(alpha=0.7), f)().sum()
+ )
+ self.assert_eq(
+ getattr(psser.ewm(alpha=0.7, min_periods=2), f)(),
+ getattr(pser.ewm(alpha=0.7, min_periods=2), f)(),
+ )
+ self.assert_eq(
+ getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ )
+
+ pdf = pd.DataFrame(
+ {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(getattr(psdf.ewm(com=0.2), f)(),
getattr(pdf.ewm(com=0.2), f)())
+ self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(),
getattr(pdf.ewm(com=0.2), f)().sum())
+ self.assert_eq(getattr(psdf.ewm(span=1.7), f)(),
getattr(pdf.ewm(span=1.7), f)())
+ self.assert_eq(
+ getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7),
f)().sum()
+ )
+ self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(),
getattr(pdf.ewm(halflife=0.5), f)())
+ self.assert_eq(
+ getattr(psdf.ewm(halflife=0.5), f)().sum(),
getattr(pdf.ewm(halflife=0.5), f)().sum()
+ )
+ self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(),
getattr(pdf.ewm(alpha=0.7), f)())
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7), f)().sum(),
getattr(pdf.ewm(alpha=0.7), f)().sum()
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(),
+ getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ )
+
+ def test_ewm_mean(self):
+ self._test_ewm_func("mean")
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.test_ewm import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_window.py
b/python/pyspark/pandas/tests/test_window.py
index 5f84b3e5245..974b99d8e4d 100644
--- a/python/pyspark/pandas/tests/test_window.py
+++ b/python/pyspark/pandas/tests/test_window.py
@@ -24,6 +24,7 @@ from pyspark.pandas.missing.window import (
MissingPandasLikeRolling,
MissingPandasLikeExpandingGroupby,
MissingPandasLikeRollingGroupby,
+ MissingPandasLikeExponentialMoving,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
@@ -94,6 +95,40 @@ class ExpandingRollingTest(PandasOnSparkTestCase, TestUtils):
):
getattr(psdf.a.rolling(1), name)() # Series
+ # ExponentialMoving functions
+ missing_functions = inspect.getmembers(
+ MissingPandasLikeExponentialMoving, inspect.isfunction
+ )
+ unsupported_functions = [
+ name for (name, type_) in missing_functions if type_.__name__ ==
"unsupported_function"
+ ]
+ for name in unsupported_functions:
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "method.*ExponentialMoving.*{}.*not implemented( yet\\.|\\.
.+)".format(name),
+ ):
+ getattr(psdf.ewm(com=0.5), name)() # Frame
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "method.*ExponentialMoving.*{}.*not implemented( yet\\.|\\.
.+)".format(name),
+ ):
+ getattr(psdf.a.ewm(com=0.5), name)() # Series
+
+ deprecated_functions = [
+ name for (name, type_) in missing_functions if type_.__name__ ==
"deprecated_function"
+ ]
+ for name in deprecated_functions:
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "method.*ExponentialMoving.*{}.*is deprecated".format(name),
+ ):
+ getattr(psdf.ewm(com=0.5), name)() # Frame
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "method.*ExponentialMoving.*{}.*is deprecated".format(name),
+ ):
+ getattr(psdf.a.ewm(com=0.5), name)() # Series
+
# Expanding properties
missing_properties = inspect.getmembers(
MissingPandasLikeExpanding, lambda o: isinstance(o, property)
@@ -165,6 +200,43 @@ class ExpandingRollingTest(PandasOnSparkTestCase,
TestUtils):
):
getattr(psdf.a.rolling(1), name)() # Series
+ # ExponentialMoving properties
+ missing_properties = inspect.getmembers(
+ MissingPandasLikeExponentialMoving, lambda o: isinstance(o,
property)
+ )
+ unsupported_properties = [
+ name
+ for (name, type_) in missing_properties
+ if type_.fget.__name__ == "unsupported_property"
+ ]
+ for name in unsupported_properties:
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "property.*ExponentialMoving.*{}.*not implemented( yet\\.|\\.
.+)".format(name),
+ ):
+ getattr(psdf.ewm(com=0.5), name)() # Frame
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "property.*ExponentialMoving.*{}.*not implemented( yet\\.|\\.
.+)".format(name),
+ ):
+ getattr(psdf.a.ewm(com=0.5), name)() # Series
+ deprecated_properties = [
+ name
+ for (name, type_) in missing_properties
+ if type_.fget.__name__ == "deprecated_property"
+ ]
+ for name in deprecated_properties:
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "property.*ExponentialMoving.*{}.*is deprecated".format(name),
+ ):
+ getattr(psdf.ewm(com=0.5), name)() # Frame
+ with self.assertRaisesRegex(
+ PandasNotImplementedError,
+ "property.*ExponentialMoving.*{}.*is deprecated".format(name),
+ ):
+ getattr(psdf.a.ewm(com=0.5), name)() # Series
+
def test_missing_groupby(self):
psdf = ps.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9]})
diff --git a/python/pyspark/pandas/usage_logging/__init__.py
b/python/pyspark/pandas/usage_logging/__init__.py
index 7f082623c03..54a9460f292 100644
--- a/python/pyspark/pandas/usage_logging/__init__.py
+++ b/python/pyspark/pandas/usage_logging/__init__.py
@@ -47,6 +47,7 @@ from pyspark.pandas.missing.window import (
MissingPandasLikeRolling,
MissingPandasLikeExpandingGroupby,
MissingPandasLikeRollingGroupby,
+ MissingPandasLikeExponentialMoving,
)
from pyspark.pandas.series import Series
from pyspark.pandas.spark.accessors import (
@@ -122,6 +123,7 @@ def attach(logger_module: Union[str, ModuleType]) -> None:
(pd.core.window.Rolling, MissingPandasLikeRolling),
(pd.core.window.ExpandingGroupby, MissingPandasLikeExpandingGroupby),
(pd.core.window.RollingGroupby, MissingPandasLikeRollingGroupby),
+ (pd.core.window.ExponentialMovingWindow,
MissingPandasLikeExponentialMoving),
]
_attach(logger_module, modules, classes, missings) # type:
ignore[arg-type]
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 122cde624ea..4c5ababf0c8 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -18,6 +18,9 @@ from abc import ABCMeta, abstractmethod
from functools import partial
from typing import Any, Callable, Generic, List, Optional
+import numpy as np
+
+from pyspark import SparkContext
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.pandas.missing.window import (
@@ -25,6 +28,7 @@ from pyspark.pandas.missing.window import (
MissingPandasLikeRollingGroupby,
MissingPandasLikeExpanding,
MissingPandasLikeExpandingGroupby,
+ MissingPandasLikeExponentialMoving,
)
# For running doctests and reference resolution in PyCharm.
@@ -1749,6 +1753,188 @@ class ExpandingGroupby(ExpandingLike[FrameLike]):
return super().var()
+class ExponentialMovingLike(Generic[FrameLike], metaclass=ABCMeta):
+ def __init__(
+ self,
+ window: WindowSpec,
+ com: Optional[float] = None,
+ span: Optional[float] = None,
+ halflife: Optional[float] = None,
+ alpha: Optional[float] = None,
+ min_periods: Optional[int] = None,
+ ):
+ if (min_periods is not None) and (min_periods < 0):
+ raise ValueError("min_periods must be >= 0")
+ if min_periods is None:
+ min_periods = 0
+ self._min_periods = min_periods
+
+ self._window = window
+ # This unbounded Window is later used to handle 'min_periods' for now.
+ self._unbounded_window =
Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
+ Window.unboundedPreceding, Window.currentRow
+ )
+
+ if (com is not None) and (not com >= 0):
+ raise ValueError("com must be >= 0")
+ self._com = com
+
+ if (span is not None) and (not span >= 1):
+ raise ValueError("span must be >= 1")
+ self._span = span
+
+ if (halflife is not None) and (not halflife > 0):
+ raise ValueError("halflife must be > 0")
+ self._halflife = halflife
+
+ if (alpha is not None) and (not 0 < alpha <= 1):
+ raise ValueError("alpha must be in (0, 1]")
+ self._alpha = alpha
+
+ def _compute_unified_alpha(self) -> float:
+ unified_alpha = np.nan
+ opt_count = 0
+
+ if self._com is not None:
+ unified_alpha = 1.0 / (1 + self._com)
+ opt_count += 1
+ if self._span is not None:
+ unified_alpha = 2.0 / (1 + self._span)
+ opt_count += 1
+ if self._halflife is not None:
+ unified_alpha = 1.0 - np.exp(-np.log(2) / self._halflife)
+ opt_count += 1
+ if self._alpha is not None:
+ unified_alpha = self._alpha
+ opt_count += 1
+
+ if opt_count == 0:
+ raise ValueError("Must pass one of com, span, halflife, or alpha")
+ if opt_count != 1:
+ raise ValueError("com, span, halflife, and alpha are mutually
exclusive")
+
+ return unified_alpha
+
+ @abstractmethod
+ def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) ->
FrameLike:
+ """
+ Wraps a function that handles Spark column in order
+ to support it in both pandas-on-Spark Series and DataFrame.
+ Note that the given `func` name should be same as the API's method
name.
+ """
+ pass
+
+ def mean(self) -> FrameLike:
+ unified_alpha = self._compute_unified_alpha()
+
+ def mean(scol: Column) -> Column:
+ jf = SparkContext._active_spark_context._jvm.PythonSQLUtils.ewm
+ return F.when(
+ F.row_number().over(self._unbounded_window) >=
self._min_periods,
+ Column(jf(scol._jc, unified_alpha)).over(self._window),
+ ).otherwise(SF.lit(None))
+
+ return self._apply_as_series_or_frame(mean)
+
+
+class ExponentialMoving(ExponentialMovingLike[FrameLike]):
+ def __init__(
+ self,
+ psdf_or_psser: FrameLike,
+ com: Optional[float] = None,
+ span: Optional[float] = None,
+ halflife: Optional[float] = None,
+ alpha: Optional[float] = None,
+ min_periods: Optional[int] = None,
+ ):
+ from pyspark.pandas.frame import DataFrame
+ from pyspark.pandas.series import Series
+
+ if not isinstance(psdf_or_psser, (DataFrame, Series)):
+ raise TypeError(
+ "psdf_or_psser must be a series or dataframe; however, got: %s"
+ % type(psdf_or_psser)
+ )
+ self._psdf_or_psser = psdf_or_psser
+
+ window_spec = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
+ Window.unboundedPreceding, Window.currentRow
+ )
+
+ super().__init__(window_spec, com, span, halflife, alpha, min_periods)
+
+ def __getattr__(self, item: str) -> Any:
+ if hasattr(MissingPandasLikeExponentialMoving, item):
+ property_or_func = getattr(MissingPandasLikeExponentialMoving,
item)
+ if isinstance(property_or_func, property):
+ return property_or_func.fget(self)
+ else:
+ return partial(property_or_func, self)
+ raise AttributeError(item)
+
+ _apply_as_series_or_frame = Rolling._apply_as_series_or_frame
+
+ def mean(self) -> FrameLike:
+ """
+ Calculate an online exponentially weighted mean.
+
+ Notes
+ -----
+ There are behavior differences between pandas-on-Spark and pandas.
+
+ * the data should not contain NaNs. pandas-on-Spark will return an
error.
+ * the current implementation of this API uses Spark's Window without
+ specifying partition specification. This leads to move all data into
+ single partition in single machine and could cause serious
+ performance degradation. Avoid this method against very large
dataset.
+
+ Returns
+ -------
+ Series or DataFrame
+ Returned object type is determined by the caller of the
exponentially
+ calculation.
+
+ See Also
+ --------
+ Series.expanding : Calling object with Series data.
+ DataFrame.expanding : Calling object with DataFrames.
+ Series.mean : Equivalent method for Series.
+ DataFrame.mean : Equivalent method for DataFrame.
+
+ Examples
+ --------
+ The below examples will show computing exponentially weighted moving
average.
+
+ >>> df = ps.DataFrame({'s1': [.2, .0, .6, .2, .4, .5, .6], 's2': [2,
1, 3, 1, 0, 0, 0]})
+ >>> df.ewm(com=0.1).mean()
+ s1 s2
+ 0 0.200000 2.000000
+ 1 0.016667 1.083333
+ 2 0.547368 2.827068
+ 3 0.231557 1.165984
+ 4 0.384688 0.105992
+ 5 0.489517 0.009636
+ 6 0.589956 0.000876
+
+ >>> df.s2.ewm(halflife=1.5, min_periods=3).mean()
+ 0 NaN
+ 1 NaN
+ 2 2.182572
+ 3 1.663174
+ 4 0.979949
+ 5 0.593155
+ 6 0.364668
+ Name: s2, dtype: float64
+ """
+ return super().mean()
+
+ # TODO: when add 'adjust' and 'ignore_na' parameter, should add to here
too.
+ def __repr__(self) -> str:
+ return "ExponentialMoving [com={}, span={}, halflife={}, alpha={},
min_periods={}]".format(
+ self._com, self._span, self._halflife, self._alpha,
self._min_periods
+ )
+
+
def _test() -> None:
import os
import doctest
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index c701d10b00b..47a620aa16f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -1014,3 +1014,48 @@ case class PercentRank(children: Seq[Expression])
extends RankLike with SizeBase
override protected def withNewChildrenInternal(newChildren:
IndexedSeq[Expression]): PercentRank =
copy(children = newChildren)
}
+
+/**
+ * Exponential Weighted Moment. This expression is dedicated only for Pandas
API on Spark.
+ * An exponentially weighted window is similar to an expanding window but with
each prior point
+ * being exponentially weighted down relative to the current point.
+ * See
https://pandas.pydata.org/docs/user_guide/window.html#exponentially-weighted-window
+ * for details.
+ * Currently, only weighted moving average is supported. In general, it is
calculated as
+ * y_t = \frac{\sum_{i=0}^t w_i x_{t-i}}{\sum_{i=0}^t w_i},
+ * where x_t is the input, y_t is the result and the w_i are the weights.
+ */
+case class EWM(input: Expression, alpha: Double)
+ extends AggregateWindowFunction with UnaryLike[Expression] {
+ assert(0 < alpha && alpha <= 1)
+
+ override def dataType: DataType = DoubleType
+
+ private val numerator = AttributeReference("numerator", DoubleType, nullable
= false)()
+ private val denominator = AttributeReference("denominator", DoubleType,
nullable = false)()
+ override def aggBufferAttributes: Seq[AttributeReference] = numerator ::
denominator :: Nil
+
+ override val initialValues: Seq[Expression] = Seq(Literal(0.0), Literal(0.0))
+
+ override val updateExpressions: Seq[Expression] = {
+ val beta = Literal(1.0 - alpha)
+ val casted = input.cast(DoubleType)
+ // TODO: after adding param ignore_na, we can remove this check
+ val error = RaiseError(Literal("Input values must not be null or
NaN")).cast(DoubleType)
+ val validated = If(IsNull(casted) || IsNaN(casted), error, casted)
+ Seq(
+ /* numerator = */ numerator * beta + validated,
+ /* denominator = */ denominator * beta + Literal(1.0)
+ )
+ }
+
+ override val evaluateExpression: Expression = numerator / denominator
+
+ override def prettyName: String = "ewm"
+
+ override def sql: String = s"$prettyName(${input.sql})"
+
+ override def child: Expression = input
+
+ override protected def withNewChildInternal(newChild: Expression): EWM =
copy(input = newChild)
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 0a4eb051095..f71ed2818fc 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.expressions.{CastTimestampNTZToLong,
ExpressionInfo, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.{CastTimestampNTZToLong, EWM,
ExpressionInfo, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -94,6 +94,8 @@ private[sql] object PythonSQLUtils extends Logging {
}
def castTimestampNTZToLong(c: Column): Column =
Column(CastTimestampNTZToLong(c.expr))
+
+ def ewm(e: Column, alpha: Double): Column = Column(EWM(e.expr, alpha))
}
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]