This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c10b2c0fbaeb [SPARK-46471][PS][TESTS] Reorganize
`OpsOnDiffFramesEnabledTests`: Factor out `test_arithmetic_*`
c10b2c0fbaeb is described below
commit c10b2c0fbaeb5a497599bb77c11577e78266904a
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 21 16:53:58 2023 +0800
[SPARK-46471][PS][TESTS] Reorganize `OpsOnDiffFramesEnabledTests`: Factor
out `test_arithmetic_*`
### What changes were proposed in this pull request?
Factor out `test_arithmetic_*` from `OpsOnDiffFramesEnabledTests`
### Why are the changes needed?
`OpsOnDiffFramesEnabledTests` and its parity test are slow:
```
Starting test(python3.9):
pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames (temp output:
/__w/spark/spark/python/target/6b1d192e-052f-42d4-9023-04df84120fce/python3.9__pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames__gycsek91.log)
Finished test(python3.9):
pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames (740s)
```
Break it into smaller test modules so that they can run in parallel.
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44435 from zhengruifeng/ps_test_diff_ops_arithmetic.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 6 +
.../diff_frames_ops/test_parity_arithmetic.py | 41 ++++++
.../diff_frames_ops/test_parity_arithmetic_ext.py | 41 ++++++
.../test_parity_arithmetic_ext_float.py | 41 ++++++
.../tests/diff_frames_ops/test_arithmetic.py | 156 +++++++++++++++++++++
.../tests/diff_frames_ops/test_arithmetic_ext.py | 99 +++++++++++++
.../diff_frames_ops/test_arithmetic_ext_float.py | 99 +++++++++++++
.../pandas/tests/test_ops_on_diff_frames.py | 89 ------------
8 files changed, 483 insertions(+), 89 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index cbd3b35c0015..7a5ac426dc7c 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -864,6 +864,9 @@ pyspark_pandas_slow = Module(
"pyspark.pandas.tests.test_indexing",
"pyspark.pandas.tests.test_ops_on_diff_frames",
"pyspark.pandas.tests.diff_frames_ops.test_align",
+ "pyspark.pandas.tests.diff_frames_ops.test_arithmetic",
+ "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext",
+ "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext_float",
"pyspark.pandas.tests.diff_frames_ops.test_basic_slow",
"pyspark.pandas.tests.diff_frames_ops.test_cov",
"pyspark.pandas.tests.diff_frames_ops.test_corrwith",
@@ -1223,6 +1226,9 @@ pyspark_pandas_connect_part3 = Module(
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_property",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_round",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames",
+ "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic",
+
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext",
+
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext_float",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby_aggregate",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby_apply",
diff --git
a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic.py
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic.py
new file mode 100644
index 000000000000..669d6ace2404
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic import
ArithmeticMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticParityTests(ArithmeticMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+    """Spark Connect parity run of the ArithmeticMixin tests (via ReusedConnectTestCase)."""
+
+
+if __name__ == "__main__":
+    # Re-import this module's contents under its package-qualified name, as every
+    # runnable test module in this package does, before handing over to unittest.
+    from pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic import *  # noqa
+
+    try:
+        import xmlrunner  # type: ignore[import]
+    except ImportError:
+        # No XML reporting available: fall back to unittest's default runner.
+        runner = None
+    else:
+        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    unittest.main(testRunner=runner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_ext.py
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_ext.py
new file mode 100644
index 000000000000..16a93d1f15b7
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_ext.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext import
ArithmeticExtMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticExtParityTests(ArithmeticExtMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+    """Spark Connect parity run of the ArithmeticExtMixin (Int64) tests."""
+
+
+if __name__ == "__main__":
+    # Re-import this module's contents under its package-qualified name, as every
+    # runnable test module in this package does, before handing over to unittest.
+    from pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext import *  # noqa
+
+    try:
+        import xmlrunner  # type: ignore[import]
+    except ImportError:
+        # No XML reporting available: fall back to unittest's default runner.
+        runner = None
+    else:
+        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    unittest.main(testRunner=runner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_ext_float.py
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_ext_float.py
new file mode 100644
index 000000000000..75e8bc3ae0df
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_ext_float.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext_float import
ArithmeticExtFloatMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticExtFloatParityTests(
+    ArithmeticExtFloatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
+):
+    """Spark Connect parity run of the ArithmeticExtFloatMixin (Float64) tests."""
+
+
+if __name__ == "__main__":
+    # Re-import this module's contents under its package-qualified name, as every
+    # runnable test module in this package does, before handing over to unittest.
+    from pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext_float import *  # noqa
+
+    try:
+        import xmlrunner  # type: ignore[import]
+    except ImportError:
+        # No XML reporting available: fall back to unittest's default runner.
+        runner = None
+    else:
+        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    unittest.main(testRunner=runner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic.py
new file mode 100644
index 000000000000..8af0e80c6e60
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic.py
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+class ArithmeticTestingFuncMixin:
+    """Helpers comparing pandas-on-Spark arithmetic results against pandas.
+
+    Defines no tests itself. The host class must provide ``assert_eq`` (as
+    ``PandasOnSparkTestCase`` / ``PandasOnSparkTestUtils`` do). Operands come from different
+    frames, so ``compute.ops_on_diff_frames`` must be enabled by the calling test class.
+    """
+
+    def _test_arithmetic_frame(self, pdf1, pdf2, *, check_extension):
+        """Compare Series and DataFrame arithmetic between ``pdf1`` and ``pdf2``.
+
+        Both frames need columns ``a`` and ``b``. ``check_extension`` is accepted for API
+        compatibility with existing callers; the assertions do not depend on it.
+        """
+        # Work on copies: the multi-index-column checks below reassign ``columns`` and must
+        # not leak that change back into the caller's DataFrames.
+        pdf1 = pdf1.copy()
+        pdf2 = pdf2.copy()
+        psdf1 = ps.from_pandas(pdf1)
+        psdf2 = ps.from_pandas(pdf2)
+
+        # Series
+        self.assert_eq((psdf1.a - psdf2.b).sort_index(), (pdf1.a - pdf2.b).sort_index())
+
+        self.assert_eq((psdf1.a * psdf2.a).sort_index(), (pdf1.a * pdf2.a).sort_index())
+
+        # The former extension/non-extension branches asserted exactly the same thing.
+        self.assert_eq(
+            (psdf1["a"] / psdf2["a"]).sort_index(), (pdf1["a"] / pdf2["a"]).sort_index()
+        )
+
+        # DataFrame
+        self.assert_eq((psdf1 + psdf2).sort_index(), (pdf1 + pdf2).sort_index())
+
+        # Multi-index columns
+        columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")])
+        psdf1.columns = columns
+        psdf2.columns = columns
+        pdf1.columns = columns
+        pdf2.columns = columns
+
+        # Series, selected by tuple key and by chained level access
+        self.assert_eq(
+            (psdf1[("x", "a")] - psdf2[("x", "b")]).sort_index(),
+            (pdf1[("x", "a")] - pdf2[("x", "b")]).sort_index(),
+        )
+
+        self.assert_eq(
+            (psdf1[("x", "a")] - psdf2["x"]["b"]).sort_index(),
+            (pdf1[("x", "a")] - pdf2["x"]["b"]).sort_index(),
+        )
+
+        self.assert_eq(
+            (psdf1["x"]["a"] - psdf2[("x", "b")]).sort_index(),
+            (pdf1["x"]["a"] - pdf2[("x", "b")]).sort_index(),
+        )
+
+        # DataFrame
+        self.assert_eq((psdf1 + psdf2).sort_index(), (pdf1 + pdf2).sort_index())
+
+    def _test_arithmetic_series(self, pser1, pser2, *, check_extension):
+        """Compare ``+``, ``-``, ``*`` and ``/`` between two Series against pandas.
+
+        Callers in this package pass MultiIndex-ed Series. ``check_extension`` is accepted
+        for API compatibility; the assertions do not depend on it.
+        """
+        psser1 = ps.from_pandas(pser1)
+        psser2 = ps.from_pandas(pser2)
+
+        self.assert_eq((psser1 + psser2).sort_index(), (pser1 + pser2).sort_index())
+
+        self.assert_eq((psser1 - psser2).sort_index(), (pser1 - pser2).sort_index())
+
+        self.assert_eq((psser1 * psser2).sort_index(), (pser1 * pser2).sort_index())
+
+        self.assert_eq((psser1 / psser2).sort_index(), (pser1 / pser2).sort_index())
+
+
+class ArithmeticMixin(ArithmeticTestingFuncMixin):
+    """Arithmetic between objects of different frames using plain (non-extension) dtypes."""
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # Permit arithmetic between objects that come from different DataFrames/Series.
+        set_option("compute.ops_on_diff_frames", True)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Put the option back to its default once this class's tests are done.
+        reset_option("compute.ops_on_diff_frames")
+        super().tearDownClass()
+
+    @property
+    def pdf1(self):
+        # Index skips 2, 4, 7 and extends to 11, so it only partially overlaps pdf2's.
+        data = {
+            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+            "b": [4, 5, 6, 3, 2, 1, 0, 0, 0],
+        }
+        return pd.DataFrame(data, index=[0, 1, 3, 5, 6, 8, 9, 10, 11])
+
+    @property
+    def pdf2(self):
+        data = {
+            "a": [9, 8, 7, 6, 5, 4, 3, 2, 1],
+            "b": [0, 0, 0, 4, 5, 6, 1, 2, 3],
+        }
+        return pd.DataFrame(data, index=list(range(9)))
+
+    @property
+    def pser1(self):
+        # Includes "koala" / "power" labels that do not appear in pser2.
+        levels = [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", "power"]]
+        codes = [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]]
+        values = [45, 200, 1.2, 30, 250, 1.5, 320, 1]
+        return pd.Series(values, index=pd.MultiIndex(levels=levels, codes=codes))
+
+    @property
+    def pser2(self):
+        levels = [["lama", "cow", "falcon"], ["speed", "weight", "length"]]
+        codes = [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]]
+        values = [-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3]
+        return pd.Series(values, index=pd.MultiIndex(levels=levels, codes=codes))
+
+    def test_arithmetic(self):
+        left_frame, right_frame = self.pdf1, self.pdf2
+        self._test_arithmetic_frame(left_frame, right_frame, check_extension=False)
+        left_series, right_series = self.pser1, self.pser2
+        self._test_arithmetic_series(left_series, right_series, check_extension=False)
+
+
+class ArithmeticTests(ArithmeticMixin, PandasOnSparkTestCase, SQLTestUtils):
+    """Runs the ArithmeticMixin tests on PandasOnSparkTestCase (non-Connect)."""
+
+
+if __name__ == "__main__":
+    import unittest
+
+    # Re-import this module's contents under its package-qualified name, as every
+    # runnable test module in this package does, before handing over to unittest.
+    from pyspark.pandas.tests.diff_frames_ops.test_arithmetic import *  # noqa: F401
+
+    try:
+        import xmlrunner
+    except ImportError:
+        # No XML reporting available: fall back to unittest's default runner.
+        runner = None
+    else:
+        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    unittest.main(testRunner=runner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_ext.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_ext.py
new file mode 100644
index 000000000000..9d06e74ba24e
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_ext.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_dtypes_available
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic import ArithmeticTestingFuncMixin
+
+
+class ArithmeticExtMixin(ArithmeticTestingFuncMixin):
+    """Arithmetic between different frames using the nullable ``Int64`` extension dtype.
+
+    Built on the helper-only ``ArithmeticTestingFuncMixin`` (not ``ArithmeticMixin``) so the
+    plain-dtype ``test_arithmetic`` is not inherited and re-run by every suite using this mixin.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # Permit arithmetic between objects that come from different DataFrames/Series.
+        set_option("compute.ops_on_diff_frames", True)
+
+    @classmethod
+    def tearDownClass(cls):
+        reset_option("compute.ops_on_diff_frames")
+        super().tearDownClass()
+
+    @property
+    def pdf1(self):
+        return pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
+            index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+        )
+
+    @property
+    def pdf2(self):
+        return pd.DataFrame(
+            {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2, 3]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pser1(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", "power"]],
+            [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+    @property
+    def pser2(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+        )
+        return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], index=midx)
+
+    @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
+    def test_arithmetic_extension_dtypes(self):
+        self._test_arithmetic_frame(
+            self.pdf1.astype("Int64"), self.pdf2.astype("Int64"), check_extension=True
+        )
+        # The Series fixtures hold non-integral floats (e.g. 1.2), so they are truncated
+        # with astype(int) before conversion to the nullable Int64 dtype.
+        self._test_arithmetic_series(
+            self.pser1.astype(int).astype("Int64"),
+            self.pser2.astype(int).astype("Int64"),
+            check_extension=True,
+        )
+
+
+class ArithmeticExtTests(ArithmeticExtMixin, PandasOnSparkTestCase, SQLTestUtils):
+    """Runs the Int64 extension-dtype arithmetic tests on PandasOnSparkTestCase (non-Connect)."""
+
+
+if __name__ == "__main__":
+    # ``unittest`` is already imported at module level. Re-import this module's contents
+    # under its package-qualified name, as every runnable test module here does.
+    from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext import *  # noqa: F401
+
+    try:
+        import xmlrunner
+    except ImportError:
+        # No XML reporting available: fall back to unittest's default runner.
+        runner = None
+    else:
+        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    unittest.main(testRunner=runner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_ext_float.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_ext_float.py
new file mode 100644
index 000000000000..2d21bd37bc07
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_ext_float.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_float_dtypes_available
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic import ArithmeticTestingFuncMixin
+
+
+class ArithmeticExtFloatMixin(ArithmeticTestingFuncMixin):
+    """Arithmetic between different frames using the nullable ``Float64`` extension dtype.
+
+    Built on the helper-only ``ArithmeticTestingFuncMixin`` (not ``ArithmeticMixin``) so the
+    plain-dtype ``test_arithmetic`` is not inherited and re-run by every suite using this mixin.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # Permit arithmetic between objects that come from different DataFrames/Series.
+        set_option("compute.ops_on_diff_frames", True)
+
+    @classmethod
+    def tearDownClass(cls):
+        reset_option("compute.ops_on_diff_frames")
+        super().tearDownClass()
+
+    @property
+    def pdf1(self):
+        return pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
+            index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+        )
+
+    @property
+    def pdf2(self):
+        return pd.DataFrame(
+            {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2, 3]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pser1(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", "power"]],
+            [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+    @property
+    def pser2(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+        )
+        return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], index=midx)
+
+    @unittest.skipIf(
+        not extension_float_dtypes_available, "pandas extension float dtypes are not available"
+    )
+    def test_arithmetic_extension_float_dtypes(self):
+        self._test_arithmetic_frame(
+            self.pdf1.astype("Float64"), self.pdf2.astype("Float64"), check_extension=True
+        )
+        self._test_arithmetic_series(
+            self.pser1.astype("Float64"), self.pser2.astype("Float64"), check_extension=True
+        )
+
+
+class ArithmeticExtFloatTests(ArithmeticExtFloatMixin, PandasOnSparkTestCase, SQLTestUtils):
+    """Runs the Float64 extension-dtype arithmetic tests on PandasOnSparkTestCase (non-Connect)."""
+
+
+if __name__ == "__main__":
+    # ``unittest`` is already imported at module level. Re-import this module's contents
+    # under its package-qualified name, as every runnable test module here does.
+    from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext_float import *  # noqa: F401
+
+    try:
+        import xmlrunner
+    except ImportError:
+        # No XML reporting available: fall back to unittest's default runner.
+        runner = None
+    else:
+        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    unittest.main(testRunner=runner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index 3fe1e5370597..1b9b7cd940ae 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -168,95 +168,6 @@ class OpsOnDiffFramesEnabledTestsMixin:
{"b": [1, 2, 3]}
).set_index("b")
- def test_arithmetic(self):
- self._test_arithmetic_frame(self.pdf1, self.pdf2,
check_extension=False)
- self._test_arithmetic_series(self.pser1, self.pser2,
check_extension=False)
-
- @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes
are not available")
- def test_arithmetic_extension_dtypes(self):
- self._test_arithmetic_frame(
- self.pdf1.astype("Int64"), self.pdf2.astype("Int64"),
check_extension=True
- )
- self._test_arithmetic_series(
- self.pser1.astype(int).astype("Int64"),
- self.pser2.astype(int).astype("Int64"),
- check_extension=True,
- )
-
- @unittest.skipIf(
- not extension_float_dtypes_available, "pandas extension float dtypes
are not available"
- )
- def test_arithmetic_extension_float_dtypes(self):
- self._test_arithmetic_frame(
- self.pdf1.astype("Float64"), self.pdf2.astype("Float64"),
check_extension=True
- )
- self._test_arithmetic_series(
- self.pser1.astype("Float64"), self.pser2.astype("Float64"),
check_extension=True
- )
-
- def _test_arithmetic_frame(self, pdf1, pdf2, *, check_extension):
- psdf1 = ps.from_pandas(pdf1)
- psdf2 = ps.from_pandas(pdf2)
-
- # Series
- self.assert_eq((psdf1.a - psdf2.b).sort_index(), (pdf1.a -
pdf2.b).sort_index())
-
- self.assert_eq((psdf1.a * psdf2.a).sort_index(), (pdf1.a *
pdf2.a).sort_index())
-
- if check_extension and not extension_float_dtypes_available:
- self.assert_eq(
- (psdf1["a"] / psdf2["a"]).sort_index(), (pdf1["a"] /
pdf2["a"]).sort_index()
- )
- else:
- self.assert_eq(
- (psdf1["a"] / psdf2["a"]).sort_index(), (pdf1["a"] /
pdf2["a"]).sort_index()
- )
-
- # DataFrame
- self.assert_eq((psdf1 + psdf2).sort_index(), (pdf1 +
pdf2).sort_index())
-
- # Multi-index columns
- columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")])
- psdf1.columns = columns
- psdf2.columns = columns
- pdf1.columns = columns
- pdf2.columns = columns
-
- # Series
- self.assert_eq(
- (psdf1[("x", "a")] - psdf2[("x", "b")]).sort_index(),
- (pdf1[("x", "a")] - pdf2[("x", "b")]).sort_index(),
- )
-
- self.assert_eq(
- (psdf1[("x", "a")] - psdf2["x"]["b"]).sort_index(),
- (pdf1[("x", "a")] - pdf2["x"]["b"]).sort_index(),
- )
-
- self.assert_eq(
- (psdf1["x"]["a"] - psdf2[("x", "b")]).sort_index(),
- (pdf1["x"]["a"] - pdf2[("x", "b")]).sort_index(),
- )
-
- # DataFrame
- self.assert_eq((psdf1 + psdf2).sort_index(), (pdf1 +
pdf2).sort_index())
-
- def _test_arithmetic_series(self, pser1, pser2, *, check_extension):
- psser1 = ps.from_pandas(pser1)
- psser2 = ps.from_pandas(pser2)
-
- # MultiIndex Series
- self.assert_eq((psser1 + psser2).sort_index(), (pser1 +
pser2).sort_index())
-
- self.assert_eq((psser1 - psser2).sort_index(), (pser1 -
pser2).sort_index())
-
- self.assert_eq((psser1 * psser2).sort_index(), (pser1 *
pser2).sort_index())
-
- if check_extension and not extension_float_dtypes_available:
- self.assert_eq((psser1 / psser2).sort_index(), (pser1 /
pser2).sort_index())
- else:
- self.assert_eq((psser1 / psser2).sort_index(), (pser1 /
pser2).sort_index())
-
def test_arithmetic_chain(self):
self._test_arithmetic_chain_frame(self.pdf1, self.pdf2, self.pdf3,
check_extension=False)
self._test_arithmetic_chain_series(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]