This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8efc4c6c74be [SPARK-55967][PYTHON] Unify column conversion for connect dataframe
8efc4c6c74be is described below

commit 8efc4c6c74be08f05b5903ba390aa9c3aa21240a
Author: Tian Gao <[email protected]>
AuthorDate: Fri Mar 13 08:57:32 2026 +0800

    [SPARK-55967][PYTHON] Unify column conversion for connect dataframe
    
    ### What changes were proposed in this pull request?
    
    * A new column conversion helper function `_to_cols` is introduced.
    * All the methods that convert column arguments from a `list`, `Column`, `str`, or `int` now use this unified entry point (see the sketch below).
    * The `ConnectColumn` import is moved to module level.
    * The type hint for the input argument is changed from `List` to `Sequence`.
    * Incorrect overloads have been removed.
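
    A minimal sketch (not part of the patch) of the call patterns that now share a
    single conversion path, assuming a running `SparkSession` named `spark`:

    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "x"), (2, "y")], ["a", "b"])

    # Varargs of names or Columns, and a single sequence mixing both,
    # are all normalized by the same helper:
    df.select("a", "b")
    df.select(col("a"), "b")
    df.select([col("a"), "b"])  # a mixed sequence now type-checks too
    ```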
    
    ### Why are the changes needed?
    
    This reduces duplicated code. It also makes it easy to add ordinal support to
    more methods in the future (see the sketch below).
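
    For example, the connect `groupBy`, `rollup`, and `cube` now resolve 1-based
    ordinals through the shared helper; a sketch (not part of the patch), reusing
    `df` from the sketch above:

    ```python
    df.groupBy(1).count()  # groups by the first column, "a"
    df.sort(2)             # sorts by the second column, "b"
    # df.groupBy(0)        # would raise PySparkIndexError (INDEX_NOT_POSITIVE)
    ```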
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54764 from gaogaotiantian/unify-column-conversion.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/pandas/groupby.py                   |   2 +-
 python/pyspark/sql/classic/dataframe.py            |  95 ++++----
 python/pyspark/sql/connect/dataframe.py            | 251 +++++++++------------
 python/pyspark/sql/dataframe.py                    |  92 ++++----
 .../sql/tests/connect/test_connect_error.py        |   4 +-
 python/pyspark/sql/tests/typing/test_dataframe.yml |   8 +-
 6 files changed, 194 insertions(+), 258 deletions(-)

diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index f23422b43a22..ddb9cb032371 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -4248,7 +4248,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
                 F.col(f"{auxiliary_col_name}.{CORRELATION_CORR_OUTPUT_COLUMN}"),
             )
 
-        sdf = sdf.orderBy(groupkey_names + [index_1_col_name])  # type: ignore[arg-type]
+        sdf = sdf.orderBy(groupkey_names + [index_1_col_name])
 
         sdf = sdf.select(
             *[F.col(col) for col in groupkey_names + numeric_col_names],
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 243fa8a37b5a..e300d0d39293 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -544,15 +544,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
     def coalesce(self, numPartitions: int) -> ParentDataFrame:
         return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
 
-    @overload
-    def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    @overload
-    def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    def repartition(  # type: ignore[misc]
+    def repartition(
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> ParentDataFrame:
         if isinstance(numPartitions, int):
@@ -575,15 +567,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
                 },
             )
 
-    @overload
-    def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    @overload
-    def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    def repartitionByRange(  # type: ignore[misc]
+    def repartitionByRange(
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> ParentDataFrame:
         if isinstance(numPartitions, int):
@@ -914,7 +898,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
 
     def sortWithinPartitions(
         self,
-        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
         **kwargs: Any,
     ) -> ParentDataFrame:
         _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -923,7 +907,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
 
     def sort(
         self,
-        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
         **kwargs: Any,
     ) -> ParentDataFrame:
         _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -944,22 +928,32 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
         """Return a JVM Scala Map from a dict"""
         return to_scala_map(self.sparkSession._sc._jvm, jm)
 
-    def _jcols(self, *cols: "ColumnOrName") -> "JavaObject":
+    def _jcols(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"]) -> "JavaObject":
         """Return a JVM Seq of Columns from a list of Column or column names
 
         If `cols` has only one list in it, cols[0] will be used as the list.
         """
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
+        if (
+            len(cols) == 1
+            and not isinstance(cols[0], (str, Column))
+            and isinstance(cols[0], Sequence)
+        ):
+            cols = tuple(cols[0])
         return self._jseq(cols, _to_java_column)
 
-    def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject":
+    def _jcols_ordinal(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "JavaObject":
         """Return a JVM Seq of Columns from a list of Column or column names or column ordinals.
 
         If `cols` has only one list in it, cols[0] will be used as the list.
         """
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
+        if (
+            len(cols) == 1
+            and not isinstance(cols[0], (int, str, Column))
+            and isinstance(cols[0], Sequence)
+        ):
+            cols = tuple(cols[0])
 
         _cols = []
         for c in cols:
@@ -1048,10 +1042,10 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
         ...
 
     @overload
-    def select(self, __cols: Union[List[Column], List[str]]) -> ParentDataFrame:
+    def select(self, __cols: Sequence["ColumnOrName"]) -> ParentDataFrame:
         ...
 
-    def select(self, *cols: "ColumnOrName") -> ParentDataFrame:  # type: ignore[misc]
+    def select(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"]) -> ParentDataFrame:
         jdf = self._jdf.select(self._jcols(*cols))
         return DataFrame(jdf, self.sparkSession)
 
@@ -1086,51 +1080,59 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
         ...
 
     @overload
-    def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
+    def groupBy(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
+    def groupBy(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
         from pyspark.sql.group import GroupedData
 
         return GroupedData(jgd, self)
 
     @overload
-    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def rollup(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
+    def rollup(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         jgd = self._jdf.rollup(self._jcols_ordinal(*cols))
         from pyspark.sql.group import GroupedData
 
         return GroupedData(jgd, self)
 
     @overload
-    def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+    def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def cube(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    def cube(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
+    def cube(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         jgd = self._jdf.cube(self._jcols_ordinal(*cols))
         from pyspark.sql.group import GroupedData
 
         return GroupedData(jgd, self)
 
     def groupingSets(
-        self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
+        self,
+        groupingSets: Sequence[Sequence["ColumnOrNameOrOrdinal"]],
+        *cols: "ColumnOrNameOrOrdinal",
     ) -> "GroupedData":
         from pyspark.sql.group import GroupedData
 
-        jgrouping_sets = _to_seq(self._sc, [self._jcols(*inner) for inner in groupingSets])
+        jgrouping_sets = _to_seq(self._sc, [self._jcols_ordinal(*inner) for inner in groupingSets])
 
-        jgd = self._jdf.groupingSets(jgrouping_sets, self._jcols(*cols))
+        jgd = self._jdf.groupingSets(jgrouping_sets, self._jcols_ordinal(*cols))
         return GroupedData(jgd, self)
 
     def unpivot(
@@ -1776,19 +1778,8 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
     # make it "compatible" by adding aliases. Therefore, we stop adding such
     # aliases as of Spark 3.0. Two methods below remain just
     # for legacy users currently.
-    @overload
-    def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
-        ...
-
-    @overload
-    def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
-        ...
-
-    def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
-        return self.groupBy(*cols)
-
-    def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
-        return self.dropDuplicates(subset)
+    groupby = groupBy
+    drop_duplicates = dropDuplicates
 
     def writeTo(self, table: str) -> "DataFrameWriterV2":
         return DataFrameWriterV2(self, table)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 89846e36e718..06f94a17940e 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -28,6 +28,7 @@ from typing import (
     Dict,
     Iterator,
     List,
+    NoReturn,
     Optional,
     Tuple,
     Union,
@@ -75,10 +76,12 @@ from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.merge import MergeIntoWriter
 from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
+from pyspark.sql.connect.column import Column as ConnectColumn
 from pyspark.sql.column import Column
 from pyspark.sql.connect.expressions import (
     ColumnReference,
     DirectShufflePartitionID,
+    SortOrder,
     SubqueryExpression,
     UnresolvedRegex,
     UnresolvedStar,
@@ -210,6 +213,60 @@ class DataFrame(ParentDataFrame):
         else:
             return None
 
+    def _to_cols(
+        self,
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
+        arg_name: str = "cols",
+        allow_ordinal: bool = False,
+        allow_sequence: bool = True,
+        sort: bool = False,
+    ) -> List[Column]:
+        def _raise(c: Any) -> NoReturn:
+            if allow_ordinal:
+                raise PySparkTypeError(
+                    errorClass="NOT_COLUMN_OR_INT_OR_STR",
+                    messageParameters={"arg_name": arg_name, "arg_type": type(c).__name__},
+                )
+            else:
+                raise PySparkTypeError(
+                    errorClass="NOT_COLUMN_OR_STR",
+                    messageParameters={"arg_name": arg_name, "arg_type": type(c).__name__},
+                )
+
+        if (
+            len(cols) == 1
+            and not isinstance(cols[0], (int, str, Column))
+            and isinstance(cols[0], Sequence)
+        ):
+            if allow_sequence:
+                cols = tuple(cols[0])
+            else:
+                _raise(cols[0])
+
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                col = c
+            elif isinstance(c, str):
+                col = F.col(c)
+            elif isinstance(c, int) and not isinstance(c, bool):
+                if allow_ordinal:
+                    if c < 1:
+                        raise PySparkIndexError(
+                            errorClass="INDEX_NOT_POSITIVE", messageParameters={"index": str(c)}
+                        )
+                    col = self[c - 1]
+                else:
+                    _raise(c)
+            else:
+                _raise(c)
+
+            if sort:
+                if not isinstance(col._expr, SortOrder):
+                    col = col.asc()
+            _cols.append(col)
+        return _cols
+
     @property
     def write(self) -> "DataFrameWriter":
         def cb(qe: "ExecutionInfo") -> None:
@@ -226,19 +283,12 @@ class DataFrame(ParentDataFrame):
         ...
 
     @overload
-    def select(self, __cols: Union[List[Column], List[str]]) -> ParentDataFrame:
+    def select(self, __cols: Union[Sequence["ColumnOrName"], "ColumnOrName"]) -> ParentDataFrame:
         ...
 
-    def select(self, *cols: "ColumnOrName") -> ParentDataFrame:  # type: ignore[misc]
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-        if any(not isinstance(c, (str, Column)) for c in cols):
-            raise PySparkTypeError(
-                errorClass="NOT_LIST_OF_COLUMN_OR_STR",
-                messageParameters={"arg_name": "columns"},
-            )
+    def select(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"]) -> ParentDataFrame:
         return DataFrame(
-            plan.Project(self._plan, [F._to_col(c) for c in cols]),
+            plan.Project(self._plan, self._to_cols(*cols)),
             session=self._session,
         )
 
@@ -284,8 +334,6 @@ class DataFrame(ParentDataFrame):
         return self._col(colName, is_metadata_column=True)
 
     def colRegex(self, colName: str) -> Column:
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
         if not isinstance(colName, str):
             raise PySparkTypeError(
                 errorClass="NOT_STR",
@@ -342,15 +390,7 @@ class DataFrame(ParentDataFrame):
         res._cached_schema = self._cached_schema
         return res
 
-    @overload
-    def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    @overload
-    def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    def repartition(  # type: ignore[misc]
+    def repartition(
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> ParentDataFrame:
         if isinstance(numPartitions, int):
@@ -369,15 +409,13 @@ class DataFrame(ParentDataFrame):
                 )
             else:
                 res = DataFrame(
-                    plan.RepartitionByExpression(
-                        self._plan, numPartitions, [F._to_col(c) for c in cols]
-                    ),
+                    plan.RepartitionByExpression(self._plan, numPartitions, self._to_cols(cols)),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
             res = DataFrame(
-                plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
+                plan.RepartitionByExpression(self._plan, None, self._to_cols(cols)),
                 self.sparkSession,
             )
         else:
@@ -392,15 +430,7 @@ class DataFrame(ParentDataFrame):
         res._cached_schema = self._cached_schema
         return res
 
-    @overload
-    def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    @overload
-    def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame:
-        ...
-
-    def repartitionByRange(  # type: ignore[misc]
+    def repartitionByRange(
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> ParentDataFrame:
         if isinstance(numPartitions, int):
@@ -420,14 +450,14 @@ class DataFrame(ParentDataFrame):
             else:
                 res = DataFrame(
                     plan.RepartitionByExpression(
-                        self._plan, numPartitions, [F._sort_col(c) for c in cols]
+                        self._plan, numPartitions, self._to_cols(cols, sort=True)
                     ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             res = DataFrame(
                 plan.RepartitionByExpression(
-                    self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
+                    self._plan, None, self._to_cols((numPartitions,) + cols, sort=True)
                 ),
                 self.sparkSession,
             )
@@ -445,8 +475,6 @@ class DataFrame(ParentDataFrame):
     def repartitionById(
         self, numPartitions: int, partitionIdCol: "ColumnOrName"
     ) -> ParentDataFrame:
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
         if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
             raise PySparkTypeError(
                 errorClass="NOT_INT",
@@ -533,6 +561,8 @@ class DataFrame(ParentDataFrame):
         return res
 
     def drop(self, *cols: "ColumnOrName") -> ParentDataFrame:
+        # We can't convert names to columns here because drop has different behavior
+        # for names and columns.
         _cols = list(cols)
         if any(not isinstance(c, (str, Column)) for c in _cols):
             raise PySparkTypeError(
@@ -560,136 +590,67 @@ class DataFrame(ParentDataFrame):
     def first(self) -> Optional[Row]:
         return self.head()
 
-    @overload  # type: ignore[no-overload-impl]
-    def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+    @overload
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
+    def groupBy(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-
-        _cols: List[Column] = []
-        for c in cols:
-            if isinstance(c, Column):
-                _cols.append(c)
-            elif isinstance(c, str):
-                _cols.append(F.col(c))
-            elif isinstance(c, int) and not isinstance(c, bool):
-                if c < 1:
-                    raise PySparkIndexError(
-                        errorClass="INDEX_NOT_POSITIVE", messageParameters={"index": str(c)}
-                    )
-                # ordinal is 1-based
-                _cols.append(self[c - 1])
-            else:
-                raise PySparkTypeError(
-                    errorClass="NOT_COLUMN_OR_STR",
-                    messageParameters={"arg_name": "cols", "arg_type": type(c).__name__},
-                )
-
-        return GroupedData(df=self, group_type="groupby", grouping_cols=_cols)
+    def groupBy(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
+        return GroupedData(
+            df=self, group_type="groupby", grouping_cols=self._to_cols(*cols, 
allow_ordinal=True)
+        )
 
     groupby = groupBy  # type: ignore[assignment]
 
     @overload
-    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def rollup(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
-        _cols: List[Column] = []
-        for c in cols:
-            if isinstance(c, Column):
-                _cols.append(c)
-            elif isinstance(c, str):
-                _cols.append(F.col(c))
-            elif isinstance(c, int) and not isinstance(c, bool):
-                if c < 1:
-                    raise PySparkIndexError(
-                        errorClass="INDEX_NOT_POSITIVE", messageParameters={"index": str(c)}
-                    )
-                # ordinal is 1-based
-                _cols.append(self[c - 1])
-            else:
-                raise PySparkTypeError(
-                    errorClass="NOT_COLUMN_OR_STR",
-                    messageParameters={"arg_name": "cols", "arg_type": type(c).__name__},
-                )
-
-        return GroupedData(df=self, group_type="rollup", grouping_cols=_cols)
+    def rollup(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
+        return GroupedData(
+            df=self, group_type="rollup", grouping_cols=self._to_cols(*cols, 
allow_ordinal=True)
+        )
 
     @overload
     def cube(self, *cols: "ColumnOrName") -> "GroupedData":
         ...
 
     @overload
-    def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def cube(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    def cube(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
-        _cols: List[Column] = []
-        for c in cols:
-            if isinstance(c, Column):
-                _cols.append(c)
-            elif isinstance(c, str):
-                _cols.append(F.col(c))
-            elif isinstance(c, int) and not isinstance(c, bool):
-                if c < 1:
-                    raise PySparkIndexError(
-                        errorClass="INDEX_NOT_POSITIVE", messageParameters={"index": str(c)}
-                    )
-                # ordinal is 1-based
-                _cols.append(self[c - 1])
-            else:
-                raise PySparkTypeError(
-                    errorClass="NOT_COLUMN_OR_STR",
-                    messageParameters={"arg_name": "cols", "arg_type": type(c).__name__},
-                )
-
-        return GroupedData(df=self, group_type="cube", grouping_cols=_cols)
+    def cube(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
+        return GroupedData(
+            df=self, group_type="cube", grouping_cols=self._to_cols(*cols, 
allow_ordinal=True)
+        )
 
     def groupingSets(
-        self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
+        self,
+        groupingSets: Sequence[Sequence["ColumnOrNameOrOrdinal"]],
+        *cols: "ColumnOrNameOrOrdinal",
     ) -> "GroupedData":
         gsets: List[List[Column]] = []
         for grouping_set in groupingSets:
-            gset: List[Column] = []
-            for c in grouping_set:
-                if isinstance(c, Column):
-                    gset.append(c)
-                elif isinstance(c, str):
-                    gset.append(F.col(c))
-                else:
-                    raise PySparkTypeError(
-                        errorClass="NOT_COLUMN_OR_STR",
-                        messageParameters={
-                            "arg_name": "groupingSets",
-                            "arg_type": type(c).__name__,
-                        },
-                    )
-            gsets.append(gset)
-
-        gcols: List[Column] = []
-        for c in cols:
-            if isinstance(c, Column):
-                gcols.append(c)
-            elif isinstance(c, str):
-                gcols.append(F.col(c))
-            else:
-                raise PySparkTypeError(
-                    errorClass="NOT_COLUMN_OR_STR",
-                    messageParameters={"arg_name": "cols", "arg_type": type(c).__name__},
-                )
+            gsets.append(self._to_cols(grouping_set, arg_name="groupingSets", allow_ordinal=True))
 
         return GroupedData(
-            df=self, group_type="grouping_sets", grouping_cols=gcols, 
grouping_sets=gsets
+            df=self,
+            group_type="grouping_sets",
+            grouping_cols=self._to_cols(cols, allow_ordinal=True),
+            grouping_sets=gsets,
         )
 
     @overload
@@ -788,7 +749,7 @@ class DataFrame(ParentDataFrame):
 
     def sort(
         self,
-        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
         **kwargs: Any,
     ) -> ParentDataFrame:
         _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -807,7 +768,7 @@ class DataFrame(ParentDataFrame):
 
     def sortWithinPartitions(
         self,
-        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
         **kwargs: Any,
     ) -> ParentDataFrame:
         _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -1748,8 +1709,6 @@ class DataFrame(ParentDataFrame):
     def __getitem__(
         self, item: Union[int, str, Column, List, Tuple]
     ) -> Union[Column, ParentDataFrame]:
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
         if isinstance(item, str):
             if item == "*":
                 return ConnectColumn(
@@ -1791,8 +1750,6 @@ class DataFrame(ParentDataFrame):
             )
 
     def _col(self, name: str, is_metadata_column: bool = False) -> Column:
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
         return ConnectColumn(
             ColumnReference(
                 unparsed_identifier=name,
@@ -1863,13 +1820,9 @@ class DataFrame(ParentDataFrame):
         return ConnectTableArg(SubqueryExpression(self._plan, subquery_type="table_arg"))
 
     def scalar(self) -> Column:
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
         return ConnectColumn(SubqueryExpression(self._plan, subquery_type="scalar"))
 
     def exists(self) -> Column:
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
         return ConnectColumn(SubqueryExpression(self._plan, subquery_type="exists"))
 
     @property
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2ee3e4e9d703..a9f51f8fca20 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1766,15 +1766,7 @@ class DataFrame:
         """
         return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
 
-    @overload
-    def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
-        ...
-
-    @overload
-    def repartition(self, *cols: "ColumnOrName") -> "DataFrame":
-        ...
-
-    @dispatch_df_method  # type: ignore[misc]
+    @dispatch_df_method
     def repartition(
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
@@ -1882,15 +1874,7 @@ class DataFrame:
         """
         ...
 
-    @overload
-    def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
-        ...
-
-    @overload
-    def repartitionByRange(self, *cols: "ColumnOrName") -> "DataFrame":
-        ...
-
-    @dispatch_df_method  # type: ignore[misc]
+    @dispatch_df_method
     def repartitionByRange(
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
@@ -2974,7 +2958,7 @@ class DataFrame:
     @dispatch_df_method
     def sortWithinPartitions(
         self,
-        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
         **kwargs: Any,
     ) -> "DataFrame":
         """Returns a new :class:`DataFrame` with each partition sorted by the 
specified column(s).
@@ -3038,7 +3022,7 @@ class DataFrame:
     @dispatch_df_method
     def sort(
         self,
-        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"],
         **kwargs: Any,
     ) -> "DataFrame":
         """Returns a new :class:`DataFrame` sorted by the specified column(s).
@@ -3198,7 +3182,7 @@ class DataFrame:
     def _preapare_cols_for_sort(
         self,
         _to_col: Callable[[str], Column],
-        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        cols: Sequence[Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]],
         kwargs: Dict[str, Any],
     ) -> Sequence[Column]:
         from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkIndexError
@@ -3208,10 +3192,16 @@ class DataFrame:
                 errorClass="CANNOT_BE_EMPTY", messageParameters={"item": 
"cols"}
             )
 
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
+        if (
+            len(cols) == 1
+            and not isinstance(cols[0], (int, str, Column))
+            and isinstance(cols[0], Sequence)
+        ):
+            cols = tuple(cols[0])
 
-        def _get_col(c: Union[int, str, Column, List[int | str | Column]]) -> Column:
+        def _get_col(
+            c: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+        ) -> Column:
             if isinstance(c, int) and not isinstance(c, bool):
                 # ordinal is 1-based
                 if c > 0:
@@ -3621,11 +3611,11 @@ class DataFrame:
         ...
 
     @overload
-    def select(self, __cols: Union[List[Column], List[str]]) -> "DataFrame":
+    def select(self, __cols: Sequence["ColumnOrName"]) -> "DataFrame":
         ...
 
-    @dispatch_df_method  # type: ignore[misc]
-    def select(self, *cols: "ColumnOrName") -> "DataFrame":
+    @dispatch_df_method
+    def select(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"]) -> "DataFrame":
         """Projects a set of expressions and returns a new :class:`DataFrame`.
 
         .. versionadded:: 1.3.0
@@ -3871,11 +3861,13 @@ class DataFrame:
         ...
 
     @overload
-    def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
+    def groupBy(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    @dispatch_df_method  # type: ignore[misc]
-    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+    @dispatch_df_method
+    def groupBy(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         """
         Groups the :class:`DataFrame` by the specified columns so that aggregation
         can be performed on them.
@@ -3977,15 +3969,17 @@ class DataFrame:
         ...
 
     @overload
-    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def rollup(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    @dispatch_df_method  # type: ignore[misc]
-    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+    @dispatch_df_method
+    def rollup(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         """
         Create a multi-dimensional rollup for the current :class:`DataFrame` using
         the specified columns, allowing for aggregation on them.
@@ -4060,15 +4054,17 @@ class DataFrame:
         ...
 
     @overload
-    def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+    def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def cube(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    @dispatch_df_method  # type: ignore[misc]
-    def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+    @dispatch_df_method
+    def cube(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         """
         Create a multi-dimensional cube for the current :class:`DataFrame` using
         the specified columns, allowing aggregations to be performed on them.
@@ -4149,7 +4145,9 @@ class DataFrame:
 
     @dispatch_df_method
     def groupingSets(
-        self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
+        self,
+        groupingSets: Sequence[Sequence["ColumnOrNameOrOrdinal"]],
+        *cols: "ColumnOrNameOrOrdinal",
     ) -> "GroupedData":
         """
         Create multi-dimensional aggregation for the current :class:`DataFrame` using the specified
@@ -4815,14 +4813,6 @@ class DataFrame:
 
         .. versionadded:: 2.4.0
 
-        .. versionchanged:: 3.4.0
-            Supports Spark Connect.
-
-        Parameters
-        ----------
-        other : :class:`DataFrame`
-            Another :class:`DataFrame` that needs to be combined.
-
         Returns
         -------
         :class:`DataFrame`
@@ -6270,11 +6260,13 @@ class DataFrame:
         ...
 
     @overload
-    def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
+    def groupby(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
         ...
 
-    @dispatch_df_method  # type: ignore[misc]
-    def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+    @dispatch_df_method
+    def groupby(
+        self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"], "ColumnOrNameOrOrdinal"]
+    ) -> "GroupedData":
         """
         :func:`groupby` is an alias for :func:`groupBy`.
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_error.py b/python/pyspark/sql/tests/connect/test_connect_error.py
index f5da7d945922..c8f203c71e6d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_error.py
+++ b/python/pyspark/sql/tests/connect/test_connect_error.py
@@ -232,8 +232,8 @@ class SparkConnectErrorTests(ReusedConnectTestCase):
 
         self.check_error(
             exception=e1.exception,
-            errorClass="NOT_LIST_OF_COLUMN_OR_STR",
-            messageParameters={"arg_name": "columns"},
+            errorClass="NOT_COLUMN_OR_STR",
+            messageParameters={"arg_name": "cols", "arg_type": "NoneType"},
         )
 
     def test_ym_interval_in_collect(self):
diff --git a/python/pyspark/sql/tests/typing/test_dataframe.yml b/python/pyspark/sql/tests/typing/test_dataframe.yml
index 5e4b20d3588c..4bd408310ba1 100644
--- a/python/pyspark/sql/tests/typing/test_dataframe.yml
+++ b/python/pyspark/sql/tests/typing/test_dataframe.yml
@@ -54,7 +54,7 @@
     df.select(["name", "age"])
     df.select([col("name"), col("age")])
 
-    df.select(["name", col("age")])  # E: Argument 1 to "select" of 
"DataFrame" has incompatible type "list[object]"; expected "list[Column] | 
list[str]"  [arg-type]
+    df.select(["name", col("age")])
 
 
 - case: groupBy
@@ -71,7 +71,7 @@
     df.groupby(["name", "age"])
     df.groupBy([col("name"), col("age")])
     df.groupby([col("name"), col("age")])
-    df.groupBy(["name", col("age")])  # E: Argument 1 to "groupBy" of 
"DataFrame" has incompatible type "list[object]"; expected "list[Column] | 
list[str] | list[int]"  [arg-type]
+    df.groupBy(["name", col("age")])
 
 
 - case: rollup
@@ -88,7 +88,7 @@
     df.rollup([col("name"), col("age")])
 
 
-    df.rollup(["name", col("age")])  # E: Argument 1 to "rollup" of 
"DataFrame" has incompatible type "list[object]"; expected "list[Column] | 
list[str]"  [arg-type]
+    df.rollup(["name", col("age")])
 
 
 - case: cube
@@ -105,7 +105,7 @@
     df.cube([col("name"), col("age")])
 
 
-    df.cube(["name", col("age")])  # E: Argument 1 to "cube" of "DataFrame" 
has incompatible type "list[object]"; expected "list[Column] | list[str]"  
[arg-type]
+    df.cube(["name", col("age")])
 
 
 - case: dropColumns


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

