This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new db162e5a139 [SPARK-46471][PS][TESTS][FOLLOWUPS] Reorganize
`OpsOnDiffFramesEnabledTests`: Factor out `test_arithmetic_chain_*`
db162e5a139 is described below
commit db162e5a139d355264ce5c538687efa66e62c8c4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 22 08:57:52 2023 +0900
[SPARK-46471][PS][TESTS][FOLLOWUPS] Reorganize
`OpsOnDiffFramesEnabledTests`: Factor out `test_arithmetic_chain_*`
### What changes were proposed in this pull request?
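Factor out the `test_arithmetic_chain_*` tests from `OpsOnDiffFramesEnabledTests` into dedicated modules under `diff_frames_ops`, and add matching Spark Connect parity tests.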
### Why are the changes needed?
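To improve testing parallelism: the tests are split into smaller modules that are registered separately in `dev/sparktestsupport/modules.py`.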
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44443 from zhengruifeng/ps_test_diff_ops_arithmetic_chain.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 6 +
.../test_parity_arithmetic_chain.py | 41 +++++
.../test_parity_arithmetic_chain_ext.py | 41 +++++
.../test_parity_arithmetic_chain_ext_float.py | 43 +++++
.../tests/diff_frames_ops/test_arithmetic_chain.py | 189 +++++++++++++++++++++
.../diff_frames_ops/test_arithmetic_chain_ext.py | 120 +++++++++++++
.../test_arithmetic_chain_ext_float.py | 122 +++++++++++++
.../pandas/tests/test_ops_on_diff_frames.py | 117 -------------
8 files changed, 562 insertions(+), 117 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 3e0b364ca84..47db204e2fa 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -867,6 +867,9 @@ pyspark_pandas_slow = Module(
"pyspark.pandas.tests.diff_frames_ops.test_arithmetic",
"pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext",
"pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext_float",
+ "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain",
+ "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext",
+ "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float",
"pyspark.pandas.tests.diff_frames_ops.test_basic_slow",
"pyspark.pandas.tests.diff_frames_ops.test_cov",
"pyspark.pandas.tests.diff_frames_ops.test_corrwith",
@@ -1229,6 +1232,9 @@ pyspark_pandas_connect_part3 = Module(
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext_float",
+ "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain",
+ "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext",
+ "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext_float",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby_aggregate",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby_apply",
diff --git
a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain.py
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain.py
new file mode 100644
index 00000000000..d24a4a41d0b
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import
ArithmeticChainMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticChainParityTests(
+ ArithmeticChainMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from
pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain
import * # noqa
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext.py
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext.py
new file mode 100644
index 00000000000..590abf5b0d2
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext import
ArithmeticChainExtMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticChainExtParityTests(
+ ArithmeticChainExtMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from
pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext
import * # noqa
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext_float.py
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext_float.py
new file mode 100644
index 00000000000..2bfd23d3f34
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext_float.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float
import (
+ ArithmeticChainExtFloatMixin,
+)
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticChainExtFloatParityTests(
+ ArithmeticChainExtFloatMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from
pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext_float
import * # noqa
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain.py
new file mode 100644
index 00000000000..fef695dbb98
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain.py
@@ -0,0 +1,189 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_float_dtypes_available
+
+
+class ArithmeticChainTestingFuncMixin:
+ def _test_arithmetic_chain_frame(self, pdf1, pdf2, pdf3, *,
check_extension):
+ psdf1 = ps.from_pandas(pdf1)
+ psdf2 = ps.from_pandas(pdf2)
+ psdf3 = ps.from_pandas(pdf3)
+
+ # Series
+ self.assert_eq(
+ (psdf1.a - psdf2.b - psdf3.c).sort_index(), (pdf1.a - pdf2.b -
pdf3.c).sort_index()
+ )
+
+ self.assert_eq(
+ (psdf1.a * (psdf2.a * psdf3.c)).sort_index(), (pdf1.a * (pdf2.a *
pdf3.c)).sort_index()
+ )
+
+ if check_extension and not extension_float_dtypes_available:
+ self.assert_eq(
+ (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
+ (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
+ )
+ else:
+ self.assert_eq(
+ (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
+ (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
+ )
+
+ # DataFrame
+ self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 -
pdf3).sort_index())
+
+ # Multi-index columns
+ columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")])
+ psdf1.columns = columns
+ psdf2.columns = columns
+ pdf1.columns = columns
+ pdf2.columns = columns
+ columns = pd.MultiIndex.from_tuples([("x", "b"), ("y", "c")])
+ psdf3.columns = columns
+ pdf3.columns = columns
+
+ # Series
+ self.assert_eq(
+ (psdf1[("x", "a")] - psdf2[("x", "b")] - psdf3[("y",
"c")]).sort_index(),
+ (pdf1[("x", "a")] - pdf2[("x", "b")] - pdf3[("y",
"c")]).sort_index(),
+ )
+
+ self.assert_eq(
+ (psdf1[("x", "a")] * (psdf2[("x", "b")] * psdf3[("y",
"c")])).sort_index(),
+ (pdf1[("x", "a")] * (pdf2[("x", "b")] * pdf3[("y",
"c")])).sort_index(),
+ )
+
+ # DataFrame
+ self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 -
pdf3).sort_index())
+
+ def _test_arithmetic_chain_series(self, pser1, pser2, pser3, *,
check_extension):
+ psser1 = ps.from_pandas(pser1)
+ psser2 = ps.from_pandas(pser2)
+ psser3 = ps.from_pandas(pser3)
+
+ # MultiIndex Series
+ self.assert_eq(
+ (psser1 + psser2 - psser3).sort_index(), (pser1 + pser2 -
pser3).sort_index()
+ )
+
+ self.assert_eq(
+ (psser1 * psser2 * psser3).sort_index(), (pser1 * pser2 *
pser3).sort_index()
+ )
+
+ if check_extension and not extension_float_dtypes_available:
+ self.assert_eq(
+ (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 /
pser3).sort_index()
+ )
+ else:
+ self.assert_eq(
+ (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 /
pser3).sort_index()
+ )
+
+ self.assert_eq(
+ (psser1 + psser2 * psser3).sort_index(), (pser1 + pser2 *
pser3).sort_index()
+ )
+
+
+class ArithmeticChainMixin(ArithmeticChainTestingFuncMixin):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ set_option("compute.ops_on_diff_frames", True)
+
+ @classmethod
+ def tearDownClass(cls):
+ reset_option("compute.ops_on_diff_frames")
+ super().tearDownClass()
+
+ @property
+ def pdf1(self):
+ return pd.DataFrame(
+ {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0,
0]},
+ index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+ )
+
+ @property
+ def pdf2(self):
+ return pd.DataFrame(
+ {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2,
3]},
+ index=list(range(9)),
+ )
+
+ @property
+ def pdf3(self):
+ return pd.DataFrame(
+ {"b": [1, 1, 1, 1, 1, 1, 1, 1, 1], "c": [1, 1, 1, 1, 1, 1, 1, 1,
1]},
+ index=list(range(9)),
+ )
+
+ @property
+ def pser1(self):
+ midx = pd.MultiIndex(
+ [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length",
"power"]],
+ [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+ )
+ return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+ @property
+ def pser2(self):
+ midx = pd.MultiIndex(
+ [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+ [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+ )
+ return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3],
index=midx)
+
+ @property
+ def pser3(self):
+ midx = pd.MultiIndex(
+ [["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
+ [[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
+ )
+ return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
+
+ def test_arithmetic_chain(self):
+ self._test_arithmetic_chain_frame(self.pdf1, self.pdf2, self.pdf3,
check_extension=False)
+ self._test_arithmetic_chain_series(
+ self.pser1, self.pser2, self.pser3, check_extension=False
+ )
+
+
+class ArithmeticChainTests(
+ ArithmeticChainMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import *
# noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext.py
new file mode 100644
index 00000000000..781800e6e59
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_dtypes_available
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import (
+ ArithmeticChainTestingFuncMixin,
+)
+
+
+class ArithmeticChainExtMixin(ArithmeticChainTestingFuncMixin):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ set_option("compute.ops_on_diff_frames", True)
+
+ @classmethod
+ def tearDownClass(cls):
+ reset_option("compute.ops_on_diff_frames")
+ super().tearDownClass()
+
+ @property
+ def pdf1(self):
+ return pd.DataFrame(
+ {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0,
0]},
+ index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+ )
+
+ @property
+ def pdf2(self):
+ return pd.DataFrame(
+ {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2,
3]},
+ index=list(range(9)),
+ )
+
+ @property
+ def pdf3(self):
+ return pd.DataFrame(
+ {"b": [1, 1, 1, 1, 1, 1, 1, 1, 1], "c": [1, 1, 1, 1, 1, 1, 1, 1,
1]},
+ index=list(range(9)),
+ )
+
+ @property
+ def pser1(self):
+ midx = pd.MultiIndex(
+ [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length",
"power"]],
+ [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+ )
+ return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+ @property
+ def pser2(self):
+ midx = pd.MultiIndex(
+ [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+ [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+ )
+ return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3],
index=midx)
+
+ @property
+ def pser3(self):
+ midx = pd.MultiIndex(
+ [["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
+ [[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
+ )
+ return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
+
+ @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes
are not available")
+ def test_arithmetic_chain_extension_dtypes(self):
+ self._test_arithmetic_chain_frame(
+ self.pdf1.astype("Int64"),
+ self.pdf2.astype("Int64"),
+ self.pdf3.astype("Int64"),
+ check_extension=True,
+ )
+ self._test_arithmetic_chain_series(
+ self.pser1.astype(int).astype("Int64"),
+ self.pser2.astype(int).astype("Int64"),
+ self.pser3.astype(int).astype("Int64"),
+ check_extension=True,
+ )
+
+
+class ArithmeticChainExtTests(
+ ArithmeticChainExtMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext import
* # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext_float.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext_float.py
new file mode 100644
index 00000000000..e4b974709b0
--- /dev/null
+++
b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext_float.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_float_dtypes_available
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import (
+ ArithmeticChainTestingFuncMixin,
+)
+
+
+class ArithmeticChainExtFloatMixin(ArithmeticChainTestingFuncMixin):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ set_option("compute.ops_on_diff_frames", True)
+
+ @classmethod
+ def tearDownClass(cls):
+ reset_option("compute.ops_on_diff_frames")
+ super().tearDownClass()
+
+ @property
+ def pdf1(self):
+ return pd.DataFrame(
+ {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0,
0]},
+ index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+ )
+
+ @property
+ def pdf2(self):
+ return pd.DataFrame(
+ {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2,
3]},
+ index=list(range(9)),
+ )
+
+ @property
+ def pdf3(self):
+ return pd.DataFrame(
+ {"b": [1, 1, 1, 1, 1, 1, 1, 1, 1], "c": [1, 1, 1, 1, 1, 1, 1, 1,
1]},
+ index=list(range(9)),
+ )
+
+ @property
+ def pser1(self):
+ midx = pd.MultiIndex(
+ [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length",
"power"]],
+ [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+ )
+ return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+ @property
+ def pser2(self):
+ midx = pd.MultiIndex(
+ [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+ [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+ )
+ return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3],
index=midx)
+
+ @property
+ def pser3(self):
+ midx = pd.MultiIndex(
+ [["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
+ [[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
+ )
+ return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
+
+ @unittest.skipIf(
+ not extension_float_dtypes_available, "pandas extension float dtypes
are not available"
+ )
+ def test_arithmetic_chain_extension_float_dtypes(self):
+ self._test_arithmetic_chain_frame(
+ self.pdf1.astype("Float64"),
+ self.pdf2.astype("Float64"),
+ self.pdf3.astype("Float64"),
+ check_extension=True,
+ )
+ self._test_arithmetic_chain_series(
+ self.pser1.astype("Float64"),
+ self.pser2.astype("Float64"),
+ self.pser3.astype("Float64"),
+ check_extension=True,
+ )
+
+
+class ArithmeticChainExtFloatTests(
+ ArithmeticChainExtFloatMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float
import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index 1b9b7cd940a..016908f0a9d 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -168,123 +168,6 @@ class OpsOnDiffFramesEnabledTestsMixin:
{"b": [1, 2, 3]}
).set_index("b")
- def test_arithmetic_chain(self):
- self._test_arithmetic_chain_frame(self.pdf1, self.pdf2, self.pdf3,
check_extension=False)
- self._test_arithmetic_chain_series(
- self.pser1, self.pser2, self.pser3, check_extension=False
- )
-
- @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes
are not available")
- def test_arithmetic_chain_extension_dtypes(self):
- self._test_arithmetic_chain_frame(
- self.pdf1.astype("Int64"),
- self.pdf2.astype("Int64"),
- self.pdf3.astype("Int64"),
- check_extension=True,
- )
- self._test_arithmetic_chain_series(
- self.pser1.astype(int).astype("Int64"),
- self.pser2.astype(int).astype("Int64"),
- self.pser3.astype(int).astype("Int64"),
- check_extension=True,
- )
-
- @unittest.skipIf(
- not extension_float_dtypes_available, "pandas extension float dtypes
are not available"
- )
- def test_arithmetic_chain_extension_float_dtypes(self):
- self._test_arithmetic_chain_frame(
- self.pdf1.astype("Float64"),
- self.pdf2.astype("Float64"),
- self.pdf3.astype("Float64"),
- check_extension=True,
- )
- self._test_arithmetic_chain_series(
- self.pser1.astype("Float64"),
- self.pser2.astype("Float64"),
- self.pser3.astype("Float64"),
- check_extension=True,
- )
-
- def _test_arithmetic_chain_frame(self, pdf1, pdf2, pdf3, *,
check_extension):
- psdf1 = ps.from_pandas(pdf1)
- psdf2 = ps.from_pandas(pdf2)
- psdf3 = ps.from_pandas(pdf3)
-
- # Series
- self.assert_eq(
- (psdf1.a - psdf2.b - psdf3.c).sort_index(), (pdf1.a - pdf2.b -
pdf3.c).sort_index()
- )
-
- self.assert_eq(
- (psdf1.a * (psdf2.a * psdf3.c)).sort_index(), (pdf1.a * (pdf2.a *
pdf3.c)).sort_index()
- )
-
- if check_extension and not extension_float_dtypes_available:
- self.assert_eq(
- (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
- (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
- )
- else:
- self.assert_eq(
- (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
- (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
- )
-
- # DataFrame
- self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 -
pdf3).sort_index())
-
- # Multi-index columns
- columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")])
- psdf1.columns = columns
- psdf2.columns = columns
- pdf1.columns = columns
- pdf2.columns = columns
- columns = pd.MultiIndex.from_tuples([("x", "b"), ("y", "c")])
- psdf3.columns = columns
- pdf3.columns = columns
-
- # Series
- self.assert_eq(
- (psdf1[("x", "a")] - psdf2[("x", "b")] - psdf3[("y",
"c")]).sort_index(),
- (pdf1[("x", "a")] - pdf2[("x", "b")] - pdf3[("y",
"c")]).sort_index(),
- )
-
- self.assert_eq(
- (psdf1[("x", "a")] * (psdf2[("x", "b")] * psdf3[("y",
"c")])).sort_index(),
- (pdf1[("x", "a")] * (pdf2[("x", "b")] * pdf3[("y",
"c")])).sort_index(),
- )
-
- # DataFrame
- self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 -
pdf3).sort_index())
-
- def _test_arithmetic_chain_series(self, pser1, pser2, pser3, *,
check_extension):
- psser1 = ps.from_pandas(pser1)
- psser2 = ps.from_pandas(pser2)
- psser3 = ps.from_pandas(pser3)
-
- # MultiIndex Series
- self.assert_eq(
- (psser1 + psser2 - psser3).sort_index(), (pser1 + pser2 -
pser3).sort_index()
- )
-
- self.assert_eq(
- (psser1 * psser2 * psser3).sort_index(), (pser1 * pser2 *
pser3).sort_index()
- )
-
- if check_extension and not extension_float_dtypes_available:
- self.assert_eq(
- (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 /
pser3).sort_index()
- )
- else:
- self.assert_eq(
- (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 /
pser3).sort_index()
- )
-
- self.assert_eq(
- (psser1 + psser2 * psser3).sort_index(), (pser1 + pser2 *
pser3).sort_index()
- )
-
def test_mod(self):
pser = pd.Series([100, None, -300, None, 500, -700])
pser_other = pd.Series([-150] * 6)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]