This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2893cd304a9b [SPARK-46391][PS][TESTS] Reorganize `ExpandingParityTests`
2893cd304a9b is described below
commit 2893cd304a9b7a6782727da58204331a9083cdaf
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 14 18:41:32 2023 +0800
[SPARK-46391][PS][TESTS] Reorganize `ExpandingParityTests`
### What changes were proposed in this pull request?
Reorganize `ExpandingParityTests`
### Why are the changes needed?
to make the test more consistent with pandas
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44332 from zhengruifeng/ps_test_expanding.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 12 ++-
.../connect/{ => window}/test_parity_expanding.py | 10 +-
.../test_parity_expanding_adv.py} | 12 ++-
.../test_parity_expanding_error.py} | 12 ++-
.../test_parity_groupby_expanding.py} | 12 ++-
.../test_parity_groupby_expanding_adv.py} | 12 ++-
.../pyspark/pandas/tests/window/test_expanding.py | 96 +++++++++++++++++++
.../test_expanding_adv.py} | 33 +++++--
.../test_expanding_error.py} | 28 ++++--
.../test_groupby_expanding.py} | 103 ++-------------------
.../test_groupby_expanding_adv.py} | 35 +++++--
11 files changed, 220 insertions(+), 145 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index fee9198dff42..22fdde139d28 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -727,7 +727,11 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.test_dataframe_conversion",
"pyspark.pandas.tests.test_dataframe_spark_io",
"pyspark.pandas.tests.test_default_index",
- "pyspark.pandas.tests.test_expanding",
+ "pyspark.pandas.tests.window.test_expanding",
+ "pyspark.pandas.tests.window.test_expanding_adv",
+ "pyspark.pandas.tests.window.test_expanding_error",
+ "pyspark.pandas.tests.window.test_groupby_expanding",
+ "pyspark.pandas.tests.window.test_groupby_expanding_adv",
"pyspark.pandas.tests.test_extension",
"pyspark.pandas.tests.window.test_ewm_error",
"pyspark.pandas.tests.window.test_ewm_mean",
@@ -1135,7 +1139,11 @@ pyspark_pandas_connect_part2 = Module(
"pyspark.pandas.tests.connect.window.test_parity_groupby_rolling",
"pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv",
"pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count",
- "pyspark.pandas.tests.connect.test_parity_expanding",
+ "pyspark.pandas.tests.connect.window.test_parity_expanding",
+ "pyspark.pandas.tests.connect.window.test_parity_expanding_adv",
+ "pyspark.pandas.tests.connect.window.test_parity_expanding_error",
+ "pyspark.pandas.tests.connect.window.test_parity_groupby_expanding",
+ "pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
"pyspark.pandas.tests.connect.computation.test_parity_missing_data",
"pyspark.pandas.tests.connect.groupby.test_parity_index",
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/connect/window/test_parity_expanding.py
similarity index 79%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_expanding.py
index 7f8b1a3cac2f..ac83a1c3b34c 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
+from pyspark.pandas.tests.window.test_expanding import ExpandingMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+ ExpandingMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ from pyspark.pandas.tests.connect.window.test_parity_expanding import * #
noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py
similarity index 77%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py
index 7f8b1a3cac2f..0baec678bede 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
+from pyspark.pandas.tests.window.test_expanding_adv import ExpandingAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class ExpandingAdvParityTests(
+ ExpandingAdvMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ from pyspark.pandas.tests.connect.window.test_parity_expanding_adv import
* # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py
similarity index 76%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py
index 7f8b1a3cac2f..a8531a02799c 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
+from pyspark.pandas.tests.window.test_expanding_error import
ExpandingErrorMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class ExpandingErrorParityTests(
+ ExpandingErrorMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ from pyspark.pandas.tests.connect.window.test_parity_expanding_error
import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py
similarity index 76%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py
index 7f8b1a3cac2f..356bc5298264 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
+from pyspark.pandas.tests.window.test_groupby_expanding import
GroupByExpandingMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class GroupByExpandingParityTests(
+ GroupByExpandingMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding
import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py
similarity index 75%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py
index 7f8b1a3cac2f..b743e335b154 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
+from pyspark.pandas.tests.window.test_groupby_expanding_adv import
GroupByExpandingAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class GroupByExpandingAdvParityTests(
+ GroupByExpandingAdvMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv
import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/window/test_expanding.py
b/python/pyspark/pandas/tests/window/test_expanding.py
new file mode 100644
index 000000000000..ebe54ff21719
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_expanding.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class ExpandingTestingFuncMixin:
+ def _test_expanding_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
+ pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
+ psser = ps.from_pandas(pser)
+ self.assert_eq(ps_func(psser.expanding(2)),
pd_func(pser.expanding(2)), almost=True)
+ self.assert_eq(ps_func(psser.expanding(2)),
pd_func(pser.expanding(2)), almost=True)
+
+ # Multiindex
+ pser = pd.Series(
+ [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a",
"y"), ("b", "z")])
+ )
+ psser = ps.from_pandas(pser)
+ self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)))
+
+ pdf = pd.DataFrame(
+ {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
+ self.assert_eq(ps_func(psdf.expanding(2)).sum(),
pd_func(pdf.expanding(2)).sum())
+
+ # Multiindex column
+ columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
+ pdf.columns = columns
+ psdf.columns = columns
+ self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
+
+
+class ExpandingMixin(ExpandingTestingFuncMixin):
+ def test_expanding_repr(self):
+ self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding
[min_periods=5]")
+
+ def test_expanding_count(self):
+ self._test_expanding_func("count")
+
+ def test_expanding_min(self):
+ self._test_expanding_func("min")
+
+ def test_expanding_max(self):
+ self._test_expanding_func("max")
+
+ def test_expanding_mean(self):
+ self._test_expanding_func("mean")
+
+ def test_expanding_sum(self):
+ self._test_expanding_func("sum")
+
+
+class ExpandingTests(
+ ExpandingMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.window.test_expanding import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/window/test_expanding_adv.py
similarity index 55%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/window/test_expanding_adv.py
index 7f8b1a3cac2f..e537f1ecfbc0 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/window/test_expanding_adv.py
@@ -14,24 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.window.test_expanding import
ExpandingTestingFuncMixin
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class ExpandingAdvMixin(ExpandingTestingFuncMixin):
+ def test_expanding_quantile(self):
+ self._test_expanding_func(lambda x: x.quantile(0.5), lambda x:
x.quantile(0.5, "lower"))
+
+ def test_expanding_std(self):
+ self._test_expanding_func("std")
+
+ def test_expanding_var(self):
+ self._test_expanding_func("var")
+
+ def test_expanding_skew(self):
+ self._test_expanding_func("skew")
+
+ def test_expanding_kurt(self):
+ self._test_expanding_func("kurt")
+
+
+class ExpandingAdvTests(
+ ExpandingAdvMixin,
+ PandasOnSparkTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ import unittest
+ from pyspark.pandas.tests.window.test_expanding_adv import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/window/test_expanding_error.py
similarity index 60%
copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py
copy to python/pyspark/pandas/tests/window/test_expanding_error.py
index 7f8b1a3cac2f..fa888f5f1696 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/window/test_expanding_error.py
@@ -14,24 +14,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+import pyspark.pandas as ps
+from pyspark.pandas.window import Expanding
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class ExpandingErrorMixin:
+ def test_expanding_error(self):
+ with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+ ps.range(10).expanding(-1)
+
+ with self.assertRaisesRegex(
+ TypeError, "psdf_or_psser must be a series or dataframe; however,
got:.*int"
+ ):
+ Expanding(1, 2)
+
+
+class ExpandingErrorTests(
+ ExpandingErrorMixin,
+ PandasOnSparkTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ import unittest
+ from pyspark.pandas.tests.window.test_expanding_error import * # noqa:
F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_expanding.py
b/python/pyspark/pandas/tests/window/test_groupby_expanding.py
similarity index 56%
rename from python/pyspark/pandas/tests/test_expanding.py
rename to python/pyspark/pandas/tests/window/test_groupby_expanding.py
index 5166f8132665..44fecd7e58eb 100644
--- a/python/pyspark/pandas/tests/test_expanding.py
+++ b/python/pyspark/pandas/tests/window/test_groupby_expanding.py
@@ -19,85 +19,10 @@ import numpy as np
import pandas as pd
import pyspark.pandas as ps
-from pyspark.pandas.window import Expanding
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class ExpandingTestsMixin:
- def _test_expanding_func(self, ps_func, pd_func=None):
- if not pd_func:
- pd_func = ps_func
- if isinstance(pd_func, str):
- pd_func = self.convert_str_to_lambda(pd_func)
- if isinstance(ps_func, str):
- ps_func = self.convert_str_to_lambda(ps_func)
- pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
- psser = ps.from_pandas(pser)
- self.assert_eq(ps_func(psser.expanding(2)),
pd_func(pser.expanding(2)), almost=True)
- self.assert_eq(ps_func(psser.expanding(2)),
pd_func(pser.expanding(2)), almost=True)
-
- # Multiindex
- pser = pd.Series(
- [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a",
"y"), ("b", "z")])
- )
- psser = ps.from_pandas(pser)
- self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)))
-
- pdf = pd.DataFrame(
- {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
- )
- psdf = ps.from_pandas(pdf)
- self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
- self.assert_eq(ps_func(psdf.expanding(2)).sum(),
pd_func(pdf.expanding(2)).sum())
-
- # Multiindex column
- columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
- pdf.columns = columns
- psdf.columns = columns
- self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
-
- def test_expanding_error(self):
- with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
- ps.range(10).expanding(-1)
-
- with self.assertRaisesRegex(
- TypeError, "psdf_or_psser must be a series or dataframe; however,
got:.*int"
- ):
- Expanding(1, 2)
-
- def test_expanding_repr(self):
- self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding
[min_periods=5]")
-
- def test_expanding_count(self):
- self._test_expanding_func("count")
-
- def test_expanding_min(self):
- self._test_expanding_func("min")
-
- def test_expanding_max(self):
- self._test_expanding_func("max")
-
- def test_expanding_mean(self):
- self._test_expanding_func("mean")
-
- def test_expanding_quantile(self):
- self._test_expanding_func(lambda x: x.quantile(0.5), lambda x:
x.quantile(0.5, "lower"))
-
- def test_expanding_sum(self):
- self._test_expanding_func("sum")
-
- def test_expanding_std(self):
- self._test_expanding_func("std")
-
- def test_expanding_var(self):
- self._test_expanding_func("var")
-
- def test_expanding_skew(self):
- self._test_expanding_func("skew")
-
- def test_expanding_kurt(self):
- self._test_expanding_func("kurt")
-
+class GroupByExpandingTestingFuncMixin:
def _test_groupby_expanding_func(self, ps_func, pd_func=None):
if not pd_func:
pd_func = ps_func
@@ -172,6 +97,8 @@ class ExpandingTestsMixin:
pd_func(pdf.groupby([("a", "x"), ("a",
"y")]).expanding(2)).sort_index(),
)
+
+class GroupByExpandingMixin(GroupByExpandingTestingFuncMixin):
def test_groupby_expanding_count(self):
self._test_groupby_expanding_func("count")
@@ -184,34 +111,20 @@ class ExpandingTestsMixin:
def test_groupby_expanding_mean(self):
self._test_groupby_expanding_func("mean")
- def test_groupby_expanding_quantile(self):
- self._test_groupby_expanding_func(
- lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")
- )
-
def test_groupby_expanding_sum(self):
self._test_groupby_expanding_func("sum")
- def test_groupby_expanding_std(self):
- self._test_groupby_expanding_func("std")
-
- def test_groupby_expanding_var(self):
- self._test_groupby_expanding_func("var")
-
- def test_groupby_expanding_skew(self):
- self._test_groupby_expanding_func("skew")
-
- def test_groupby_expanding_kurt(self):
- self._test_groupby_expanding_func("kurt")
-
-class ExpandingTests(ExpandingTestsMixin, PandasOnSparkTestCase, TestUtils):
+class GroupByExpandingTests(
+ GroupByExpandingMixin,
+ PandasOnSparkTestCase,
+):
pass
if __name__ == "__main__":
import unittest
- from pyspark.pandas.tests.test_expanding import * # noqa: F401
+ from pyspark.pandas.tests.window.test_groupby_expanding import * # noqa:
F401
try:
import xmlrunner
diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
b/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py
similarity index 50%
rename from python/pyspark/pandas/tests/connect/test_parity_expanding.py
rename to python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py
index 7f8b1a3cac2f..22cb03dc0ff3 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py
+++ b/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py
@@ -14,24 +14,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.tests.window.test_groupby_expanding import
GroupByExpandingTestingFuncMixin
-class ExpandingParityTests(
- ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
+class GroupByExpandingAdvMixin(GroupByExpandingTestingFuncMixin):
+ def test_groupby_expanding_quantile(self):
+ self._test_groupby_expanding_func(
+ lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")
+ )
+
+ def test_groupby_expanding_std(self):
+ self._test_groupby_expanding_func("std")
+
+ def test_groupby_expanding_var(self):
+ self._test_groupby_expanding_func("var")
+
+ def test_groupby_expanding_skew(self):
+ self._test_groupby_expanding_func("skew")
+
+ def test_groupby_expanding_kurt(self):
+ self._test_groupby_expanding_func("kurt")
+
+
+class GroupByExpandingAdvTests(
+ GroupByExpandingAdvMixin,
+ PandasOnSparkTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa:
F401
+ import unittest
+ from pyspark.pandas.tests.window.test_groupby_expanding_adv import * #
noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
except ImportError:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]