This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new db162e5a139 [SPARK-46471][PS][TESTS][FOLLOWUPS] Reorganize `OpsOnDiffFramesEnabledTests`: Factor out `test_arithmetic_chain_*`
db162e5a139 is described below

commit db162e5a139d355264ce5c538687efa66e62c8c4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 22 08:57:52 2023 +0900

    [SPARK-46471][PS][TESTS][FOLLOWUPS] Reorganize `OpsOnDiffFramesEnabledTests`: Factor out `test_arithmetic_chain_*`
    
    ### What changes were proposed in this pull request?
    Factor out `test_arithmetic_chain_*`
    
    ### Why are the changes needed?
    for testing parallelism
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44443 from zhengruifeng/ps_test_diff_ops_arithmetic_chain.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   6 +
 .../test_parity_arithmetic_chain.py                |  41 +++++
 .../test_parity_arithmetic_chain_ext.py            |  41 +++++
 .../test_parity_arithmetic_chain_ext_float.py      |  43 +++++
 .../tests/diff_frames_ops/test_arithmetic_chain.py | 189 +++++++++++++++++++++
 .../diff_frames_ops/test_arithmetic_chain_ext.py   | 120 +++++++++++++
 .../test_arithmetic_chain_ext_float.py             | 122 +++++++++++++
 .../pandas/tests/test_ops_on_diff_frames.py        | 117 -------------
 8 files changed, 562 insertions(+), 117 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 3e0b364ca84..47db204e2fa 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -867,6 +867,9 @@ pyspark_pandas_slow = Module(
         "pyspark.pandas.tests.diff_frames_ops.test_arithmetic",
         "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext",
         "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext_float",
+        "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain",
+        "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext",
+        "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float",
         "pyspark.pandas.tests.diff_frames_ops.test_basic_slow",
         "pyspark.pandas.tests.diff_frames_ops.test_cov",
         "pyspark.pandas.tests.diff_frames_ops.test_corrwith",
@@ -1229,6 +1232,9 @@ pyspark_pandas_connect_part3 = Module(
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext_float",
+        "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain",
+        "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext",
+        "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext_float",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby_aggregate",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_groupby_apply",
diff --git a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain.py b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain.py
new file mode 100644
index 00000000000..d24a4a41d0b
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import 
ArithmeticChainMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticChainParityTests(
+    ArithmeticChainMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from 
pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain 
import *  # noqa
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext.py b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext.py
new file mode 100644
index 00000000000..590abf5b0d2
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext import 
ArithmeticChainExtMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticChainExtParityTests(
+    ArithmeticChainExtMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from 
pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext 
import *  # noqa
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext_float.py b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext_float.py
new file mode 100644
index 00000000000..2bfd23d3f34
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_arithmetic_chain_ext_float.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float 
import (
+    ArithmeticChainExtFloatMixin,
+)
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class ArithmeticChainExtFloatParityTests(
+    ArithmeticChainExtFloatMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from 
pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext_float
 import *  # noqa
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain.py b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain.py
new file mode 100644
index 00000000000..fef695dbb98
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain.py
@@ -0,0 +1,189 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_float_dtypes_available
+
+
+class ArithmeticChainTestingFuncMixin:
+    def _test_arithmetic_chain_frame(self, pdf1, pdf2, pdf3, *, 
check_extension):
+        psdf1 = ps.from_pandas(pdf1)
+        psdf2 = ps.from_pandas(pdf2)
+        psdf3 = ps.from_pandas(pdf3)
+
+        # Series
+        self.assert_eq(
+            (psdf1.a - psdf2.b - psdf3.c).sort_index(), (pdf1.a - pdf2.b - 
pdf3.c).sort_index()
+        )
+
+        self.assert_eq(
+            (psdf1.a * (psdf2.a * psdf3.c)).sort_index(), (pdf1.a * (pdf2.a * 
pdf3.c)).sort_index()
+        )
+
+        if check_extension and not extension_float_dtypes_available:
+            self.assert_eq(
+                (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
+                (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
+            )
+        else:
+            self.assert_eq(
+                (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
+                (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
+            )
+
+        # DataFrame
+        self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 - 
pdf3).sort_index())
+
+        # Multi-index columns
+        columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")])
+        psdf1.columns = columns
+        psdf2.columns = columns
+        pdf1.columns = columns
+        pdf2.columns = columns
+        columns = pd.MultiIndex.from_tuples([("x", "b"), ("y", "c")])
+        psdf3.columns = columns
+        pdf3.columns = columns
+
+        # Series
+        self.assert_eq(
+            (psdf1[("x", "a")] - psdf2[("x", "b")] - psdf3[("y", 
"c")]).sort_index(),
+            (pdf1[("x", "a")] - pdf2[("x", "b")] - pdf3[("y", 
"c")]).sort_index(),
+        )
+
+        self.assert_eq(
+            (psdf1[("x", "a")] * (psdf2[("x", "b")] * psdf3[("y", 
"c")])).sort_index(),
+            (pdf1[("x", "a")] * (pdf2[("x", "b")] * pdf3[("y", 
"c")])).sort_index(),
+        )
+
+        # DataFrame
+        self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 - 
pdf3).sort_index())
+
+    def _test_arithmetic_chain_series(self, pser1, pser2, pser3, *, 
check_extension):
+        psser1 = ps.from_pandas(pser1)
+        psser2 = ps.from_pandas(pser2)
+        psser3 = ps.from_pandas(pser3)
+
+        # MultiIndex Series
+        self.assert_eq(
+            (psser1 + psser2 - psser3).sort_index(), (pser1 + pser2 - 
pser3).sort_index()
+        )
+
+        self.assert_eq(
+            (psser1 * psser2 * psser3).sort_index(), (pser1 * pser2 * 
pser3).sort_index()
+        )
+
+        if check_extension and not extension_float_dtypes_available:
+            self.assert_eq(
+                (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 / 
pser3).sort_index()
+            )
+        else:
+            self.assert_eq(
+                (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 / 
pser3).sort_index()
+            )
+
+        self.assert_eq(
+            (psser1 + psser2 * psser3).sort_index(), (pser1 + pser2 * 
pser3).sort_index()
+        )
+
+
+class ArithmeticChainMixin(ArithmeticChainTestingFuncMixin):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        set_option("compute.ops_on_diff_frames", True)
+
+    @classmethod
+    def tearDownClass(cls):
+        reset_option("compute.ops_on_diff_frames")
+        super().tearDownClass()
+
+    @property
+    def pdf1(self):
+        return pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 
0]},
+            index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+        )
+
+    @property
+    def pdf2(self):
+        return pd.DataFrame(
+            {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2, 
3]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pdf3(self):
+        return pd.DataFrame(
+            {"b": [1, 1, 1, 1, 1, 1, 1, 1, 1], "c": [1, 1, 1, 1, 1, 1, 1, 1, 
1]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pser1(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", 
"power"]],
+            [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+    @property
+    def pser2(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+        )
+        return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], 
index=midx)
+
+    @property
+    def pser3(self):
+        midx = pd.MultiIndex(
+            [["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
+
+    def test_arithmetic_chain(self):
+        self._test_arithmetic_chain_frame(self.pdf1, self.pdf2, self.pdf3, 
check_extension=False)
+        self._test_arithmetic_chain_series(
+            self.pser1, self.pser2, self.pser3, check_extension=False
+        )
+
+
+class ArithmeticChainTests(
+    ArithmeticChainMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import *  
# noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext.py b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext.py
new file mode 100644
index 00000000000..781800e6e59
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_dtypes_available
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import (
+    ArithmeticChainTestingFuncMixin,
+)
+
+
+class ArithmeticChainExtMixin(ArithmeticChainTestingFuncMixin):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        set_option("compute.ops_on_diff_frames", True)
+
+    @classmethod
+    def tearDownClass(cls):
+        reset_option("compute.ops_on_diff_frames")
+        super().tearDownClass()
+
+    @property
+    def pdf1(self):
+        return pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 
0]},
+            index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+        )
+
+    @property
+    def pdf2(self):
+        return pd.DataFrame(
+            {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2, 
3]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pdf3(self):
+        return pd.DataFrame(
+            {"b": [1, 1, 1, 1, 1, 1, 1, 1, 1], "c": [1, 1, 1, 1, 1, 1, 1, 1, 
1]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pser1(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", 
"power"]],
+            [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+    @property
+    def pser2(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+        )
+        return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], 
index=midx)
+
+    @property
+    def pser3(self):
+        midx = pd.MultiIndex(
+            [["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
+
+    @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes 
are not available")
+    def test_arithmetic_chain_extension_dtypes(self):
+        self._test_arithmetic_chain_frame(
+            self.pdf1.astype("Int64"),
+            self.pdf2.astype("Int64"),
+            self.pdf3.astype("Int64"),
+            check_extension=True,
+        )
+        self._test_arithmetic_chain_series(
+            self.pser1.astype(int).astype("Int64"),
+            self.pser2.astype(int).astype("Int64"),
+            self.pser3.astype(int).astype("Int64"),
+            check_extension=True,
+        )
+
+
+class ArithmeticChainExtTests(
+    ArithmeticChainExtMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext import 
*  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext_float.py b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext_float.py
new file mode 100644
index 00000000000..e4b974709b0
--- /dev/null
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic_chain_ext_float.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark.pandas.config import set_option, reset_option
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.typedef.typehints import extension_float_dtypes_available
+from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain import (
+    ArithmeticChainTestingFuncMixin,
+)
+
+
+class ArithmeticChainExtFloatMixin(ArithmeticChainTestingFuncMixin):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        set_option("compute.ops_on_diff_frames", True)
+
+    @classmethod
+    def tearDownClass(cls):
+        reset_option("compute.ops_on_diff_frames")
+        super().tearDownClass()
+
+    @property
+    def pdf1(self):
+        return pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 
0]},
+            index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
+        )
+
+    @property
+    def pdf2(self):
+        return pd.DataFrame(
+            {"a": [9, 8, 7, 6, 5, 4, 3, 2, 1], "b": [0, 0, 0, 4, 5, 6, 1, 2, 
3]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pdf3(self):
+        return pd.DataFrame(
+            {"b": [1, 1, 1, 1, 1, 1, 1, 1, 1], "c": [1, 1, 1, 1, 1, 1, 1, 1, 
1]},
+            index=list(range(9)),
+        )
+
+    @property
+    def pser1(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", 
"power"]],
+            [[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)
+
+    @property
+    def pser2(self):
+        midx = pd.MultiIndex(
+            [["lama", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
+        )
+        return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], 
index=midx)
+
+    @property
+    def pser3(self):
+        midx = pd.MultiIndex(
+            [["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
+            [[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
+        )
+        return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
+
+    @unittest.skipIf(
+        not extension_float_dtypes_available, "pandas extension float dtypes 
are not available"
+    )
+    def test_arithmetic_chain_extension_float_dtypes(self):
+        self._test_arithmetic_chain_frame(
+            self.pdf1.astype("Float64"),
+            self.pdf2.astype("Float64"),
+            self.pdf3.astype("Float64"),
+            check_extension=True,
+        )
+        self._test_arithmetic_chain_series(
+            self.pser1.astype("Float64"),
+            self.pser2.astype("Float64"),
+            self.pser3.astype("Float64"),
+            check_extension=True,
+        )
+
+
+class ArithmeticChainExtFloatTests(
+    ArithmeticChainExtFloatMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float 
import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index 1b9b7cd940a..016908f0a9d 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -168,123 +168,6 @@ class OpsOnDiffFramesEnabledTestsMixin:
                 {"b": [1, 2, 3]}
             ).set_index("b")
 
-    def test_arithmetic_chain(self):
-        self._test_arithmetic_chain_frame(self.pdf1, self.pdf2, self.pdf3, 
check_extension=False)
-        self._test_arithmetic_chain_series(
-            self.pser1, self.pser2, self.pser3, check_extension=False
-        )
-
-    @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes 
are not available")
-    def test_arithmetic_chain_extension_dtypes(self):
-        self._test_arithmetic_chain_frame(
-            self.pdf1.astype("Int64"),
-            self.pdf2.astype("Int64"),
-            self.pdf3.astype("Int64"),
-            check_extension=True,
-        )
-        self._test_arithmetic_chain_series(
-            self.pser1.astype(int).astype("Int64"),
-            self.pser2.astype(int).astype("Int64"),
-            self.pser3.astype(int).astype("Int64"),
-            check_extension=True,
-        )
-
-    @unittest.skipIf(
-        not extension_float_dtypes_available, "pandas extension float dtypes 
are not available"
-    )
-    def test_arithmetic_chain_extension_float_dtypes(self):
-        self._test_arithmetic_chain_frame(
-            self.pdf1.astype("Float64"),
-            self.pdf2.astype("Float64"),
-            self.pdf3.astype("Float64"),
-            check_extension=True,
-        )
-        self._test_arithmetic_chain_series(
-            self.pser1.astype("Float64"),
-            self.pser2.astype("Float64"),
-            self.pser3.astype("Float64"),
-            check_extension=True,
-        )
-
-    def _test_arithmetic_chain_frame(self, pdf1, pdf2, pdf3, *, 
check_extension):
-        psdf1 = ps.from_pandas(pdf1)
-        psdf2 = ps.from_pandas(pdf2)
-        psdf3 = ps.from_pandas(pdf3)
-
-        # Series
-        self.assert_eq(
-            (psdf1.a - psdf2.b - psdf3.c).sort_index(), (pdf1.a - pdf2.b - 
pdf3.c).sort_index()
-        )
-
-        self.assert_eq(
-            (psdf1.a * (psdf2.a * psdf3.c)).sort_index(), (pdf1.a * (pdf2.a * 
pdf3.c)).sort_index()
-        )
-
-        if check_extension and not extension_float_dtypes_available:
-            self.assert_eq(
-                (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
-                (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
-            )
-        else:
-            self.assert_eq(
-                (psdf1["a"] / psdf2["a"] / psdf3["c"]).sort_index(),
-                (pdf1["a"] / pdf2["a"] / pdf3["c"]).sort_index(),
-            )
-
-        # DataFrame
-        self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 - 
pdf3).sort_index())
-
-        # Multi-index columns
-        columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")])
-        psdf1.columns = columns
-        psdf2.columns = columns
-        pdf1.columns = columns
-        pdf2.columns = columns
-        columns = pd.MultiIndex.from_tuples([("x", "b"), ("y", "c")])
-        psdf3.columns = columns
-        pdf3.columns = columns
-
-        # Series
-        self.assert_eq(
-            (psdf1[("x", "a")] - psdf2[("x", "b")] - psdf3[("y", 
"c")]).sort_index(),
-            (pdf1[("x", "a")] - pdf2[("x", "b")] - pdf3[("y", 
"c")]).sort_index(),
-        )
-
-        self.assert_eq(
-            (psdf1[("x", "a")] * (psdf2[("x", "b")] * psdf3[("y", 
"c")])).sort_index(),
-            (pdf1[("x", "a")] * (pdf2[("x", "b")] * pdf3[("y", 
"c")])).sort_index(),
-        )
-
-        # DataFrame
-        self.assert_eq((psdf1 + psdf2 - psdf3).sort_index(), (pdf1 + pdf2 - 
pdf3).sort_index())
-
-    def _test_arithmetic_chain_series(self, pser1, pser2, pser3, *, 
check_extension):
-        psser1 = ps.from_pandas(pser1)
-        psser2 = ps.from_pandas(pser2)
-        psser3 = ps.from_pandas(pser3)
-
-        # MultiIndex Series
-        self.assert_eq(
-            (psser1 + psser2 - psser3).sort_index(), (pser1 + pser2 - 
pser3).sort_index()
-        )
-
-        self.assert_eq(
-            (psser1 * psser2 * psser3).sort_index(), (pser1 * pser2 * 
pser3).sort_index()
-        )
-
-        if check_extension and not extension_float_dtypes_available:
-            self.assert_eq(
-                (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 / 
pser3).sort_index()
-            )
-        else:
-            self.assert_eq(
-                (psser1 - psser2 / psser3).sort_index(), (pser1 - pser2 / 
pser3).sort_index()
-            )
-
-        self.assert_eq(
-            (psser1 + psser2 * psser3).sort_index(), (pser1 + pser2 * 
pser3).sort_index()
-        )
-
     def test_mod(self):
         pser = pd.Series([100, None, -300, None, 500, -700])
         pser_other = pd.Series([-150] * 6)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to