This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f124c30ddd3 [SPARK-46168][PS] Add axis argument for idxmax
3f124c30ddd3 is described below

commit 3f124c30ddd3a169e6c2534f685cf006fae6fda1
Author: Devin Petersohn <[email protected]>
AuthorDate: Mon Feb 23 12:41:34 2026 -0800

    [SPARK-46168][PS] Add axis argument for idxmax
    
    ### What changes were proposed in this pull request?
    
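    Add support for the `axis` argument to `DataFrame.idxmax`. With `axis=1` (or `'columns'`),
    it returns the column label of the maximum value in each row.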
    
    ### Why are the changes needed?
    
    `DataFrame.idxmax` previously accepted only `axis=0`. This adds the missing `axis=1`
    support to match pandas.
    
    ### Does this PR introduce _any_ user-facing change?
    
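    Yes. `DataFrame.idxmax` now accepts `axis=1` / `axis='columns'`. For `axis=1`, MultiIndex
    columns are not supported yet and raise `NotImplementedError`.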
    
    ### How was this patch tested?
    
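    Added unit tests in `test_idxmax_idxmin.py` and a Spark Connect parity suite
    (`test_parity_idxmax_idxmin.py`); both run in CI.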
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Co-authored-by: Claude Sonnet 4.5
    
    Closes #54044 from devin-petersohn/devin/idxmax_axis.
    
    Authored-by: Devin Petersohn <[email protected]>
    Signed-off-by: Holden Karau <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   2 +
 python/pyspark/pandas/frame.py                     |  92 ++++++++---
 .../pandas/tests/computation/test_idxmax_idxmin.py | 170 +++++++++++++++++++++
 .../computation/test_parity_idxmax_idxmin.py       |  34 +++++
 4 files changed, 281 insertions(+), 17 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index d62bf4414ffd..e1851602d020 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -874,6 +874,7 @@ pyspark_pandas = Module(
         "pyspark.pandas.tests.computation.test_cumulative",
         "pyspark.pandas.tests.computation.test_describe",
         "pyspark.pandas.tests.computation.test_eval",
+        "pyspark.pandas.tests.computation.test_idxmax_idxmin",
         "pyspark.pandas.tests.computation.test_melt",
         "pyspark.pandas.tests.computation.test_missing_data",
         "pyspark.pandas.tests.computation.test_pivot",
@@ -1322,6 +1323,7 @@ pyspark_pandas_connect = Module(
         "pyspark.pandas.tests.connect.computation.test_parity_cumulative",
         "pyspark.pandas.tests.connect.computation.test_parity_describe",
         "pyspark.pandas.tests.connect.computation.test_parity_eval",
+        "pyspark.pandas.tests.connect.computation.test_parity_idxmax_idxmin",
         "pyspark.pandas.tests.connect.computation.test_parity_melt",
         "pyspark.pandas.tests.connect.computation.test_parity_missing_data",
         "pyspark.pandas.tests.connect.computation.test_parity_pivot",
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 2fa90e8e15cf..5609f76cd719 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -12311,7 +12311,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         return self._apply_series_op(op, should_resolve=True)
 
-    # TODO(SPARK-46168): axis = 1
     def idxmax(self, axis: Axis = 0) -> "Series":
         """
         Return index of first occurrence of maximum over requested axis.
@@ -12322,8 +12321,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         Parameters
         ----------
-        axis : 0 or 'index'
-            Can only be set to 0 now.
+        axis : {0 or 'index', 1 or 'columns'}, default 0
+            The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise.
 
         Returns
         -------
@@ -12351,6 +12350,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         c    2
         dtype: int64
 
+        For axis=1, return the column label of the maximum value in each row:
+
+        >>> psdf.idxmax(axis=1)
+        0    c
+        1    c
+        2    c
+        3    c
+        dtype: object
+
         For Multi-column Index
 
         >>> psdf = ps.DataFrame({'a': [1, 2, 3, 2],
@@ -12371,23 +12379,73 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         c  z    2
         dtype: int64
         """
-        max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns)
-        sdf_max = self._internal.spark_frame.select(*max_cols).head()
-        # `sdf_max` looks like below
-        # +------+------+------+
-        # |(a, x)|(b, y)|(c, z)|
-        # +------+------+------+
-        # |     3|   4.0|   400|
-        # +------+------+------+
+        axis = validate_axis(axis)
+        if axis == 0:
+            max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns)
+            sdf_max = self._internal.spark_frame.select(*max_cols).head()
+            # `sdf_max` looks like below
+            # +------+------+------+
+            # |(a, x)|(b, y)|(c, z)|
+            # +------+------+------+
+            # |     3|   4.0|   400|
+            # +------+------+------+
+
+            conds = (
+                scol == max_val for scol, max_val in zip(self._internal.data_spark_columns, sdf_max)
+            )
+            cond = reduce(lambda x, y: x | y, conds)
 
-        conds = (
-            scol == max_val for scol, max_val in zip(self._internal.data_spark_columns, sdf_max)
-        )
-        cond = reduce(lambda x, y: x | y, conds)
+            psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
 
-        psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
+            return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax()))
+        else:
+            from pyspark.pandas.series import first_series
+
+            column_labels = self._internal.column_labels
+
+            if len(column_labels) == 0:
+                # Check if DataFrame has rows - if yes, raise error; if no, return empty Series
+                # to match pandas behavior
+                if len(self) > 0:
+                    raise ValueError("attempt to get argmax of an empty sequence")
+                else:
+                    return ps.Series([], dtype=np.int64)
+
+            if self._internal.column_labels_level > 1:
+                raise NotImplementedError(
+                    "idxmax with axis=1 does not support MultiIndex columns yet"
+                )
+
+            max_value = F.greatest(
+                *[
+                    F.coalesce(self._internal.spark_column_for(label), F.lit(float("-inf")))
+                    for label in column_labels
+                ],
+                F.lit(float("-inf")),
+            )
+
+            result = None
+            # Iterate over the column labels in reverse order to get the first occurrence of the
+            # maximum value.
+            for label in reversed(column_labels):
+                scol = self._internal.spark_column_for(label)
+                label_value = label[0] if len(label) == 1 else label
+                condition = (scol == max_value) & scol.isNotNull()
+
+                result = (
+                    F.when(condition, F.lit(label_value))
+                    if result is None
+                    else F.when(condition, F.lit(label_value)).otherwise(result)
+                )
+
+            result = F.when(max_value == float("-inf"), F.lit(None)).otherwise(result)
+
+            internal = self._internal.with_new_columns(
+                [result.alias(SPARK_DEFAULT_SERIES_NAME)],
+                column_labels=[None],
+            )
 
-        return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax()))
+            return first_series(DataFrame(internal))
 
     # TODO(SPARK-46168): axis = 1
     def idxmin(self, axis: Axis = 0) -> "Series":
diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
new file mode 100644
index 000000000000..2de14fccbc43
--- /dev/null
+++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
@@ -0,0 +1,170 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+class FrameIdxMaxMinMixin:
+    def test_idxmax(self):
+        # Test basic axis=0 (default)
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3, 2],
+                "b": [4.0, 2.0, 3.0, 1.0],
+                "c": [300, 200, 400, 200],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(), pdf.idxmax())
+        self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0))
+        self.assert_eq(psdf.idxmax(axis="index"), pdf.idxmax(axis="index"))
+
+        # Test axis=1
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+        self.assert_eq(psdf.idxmax(axis="columns"), pdf.idxmax(axis="columns"))
+
+        # Test with NAs
+        pdf = pd.DataFrame(
+            {
+                "a": [1.0, None, 3.0],
+                "b": [None, 2.0, None],
+                "c": [3.0, 4.0, None],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(), pdf.idxmax())
+        self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0))
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test with all-NA row
+        pdf = pd.DataFrame(
+            {
+                "a": [1.0, None],
+                "b": [2.0, None],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test with ties (first occurrence should win)
+        pdf = pd.DataFrame(
+            {
+                "a": [3, 2, 1],
+                "b": [3, 5, 1],
+                "c": [1, 5, 1],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test with single column
+        pdf = pd.DataFrame({"a": [1, 2, 3]})
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test with empty DataFrame
+        pdf = pd.DataFrame({})
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test with different data types
+        pdf = pd.DataFrame(
+            {
+                "int_col": [1, 2, 3],
+                "float_col": [1.5, 2.5, 0.5],
+                "negative": [-5, -10, -1],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test with custom index
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3],
+                "b": [4, 5, 6],
+                "c": [7, 8, 9],
+            },
+            index=["row1", "row2", "row3"],
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+    def test_idxmax_multiindex_columns(self):
+        # Test that MultiIndex columns raise NotImplementedError for axis=1
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3],
+                "b": [4, 5, 6],
+                "c": [7, 8, 9],
+            }
+        )
+        pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")])
+        psdf = ps.from_pandas(pdf)
+
+        # axis=0 should work fine (it uses pandas internally)
+        self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0))
+
+        # axis=1 should raise NotImplementedError
+        with self.assertRaises(NotImplementedError):
+            psdf.idxmax(axis=1)
+
+    def test_idxmax_empty_dataframe(self):
+        # Test empty DataFrame with no rows and no columns - should return empty Series
+        pdf = pd.DataFrame({})
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+        # Test empty DataFrame with rows but no columns - should raise ValueError
+        pdf = pd.DataFrame(index=range(3))
+        psdf = ps.from_pandas(pdf)
+
+        with self.assertRaises(ValueError) as pdf_context:
+            pdf.idxmax(axis=1)
+
+        with self.assertRaises(ValueError) as psdf_context:
+            psdf.idxmax(axis=1)
+
+        # Verify both raise the same error message
+        self.assertEqual(str(pdf_context.exception), str(psdf_context.exception))
+
+
+class FrameIdxMaxMinTests(
+    FrameIdxMaxMinMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.testing import main
+
+    main()
diff --git a/python/pyspark/pandas/tests/connect/computation/test_parity_idxmax_idxmin.py b/python/pyspark/pandas/tests/connect/computation/test_parity_idxmax_idxmin.py
new file mode 100644
index 000000000000..06e723d39708
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/computation/test_parity_idxmax_idxmin.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.pandas.tests.computation.test_idxmax_idxmin import FrameIdxMaxMinMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class FrameParityIdxMaxMinTests(
+    FrameIdxMaxMinMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.testing import main
+
+    main()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to