This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8efc4c6c74be [SPARK-55967][PYTHON] Unify column conversion for connect dataframe
8efc4c6c74be is described below
commit 8efc4c6c74be08f05b5903ba390aa9c3aa21240a
Author: Tian Gao <[email protected]>
AuthorDate: Fri Mar 13 08:57:32 2026 +0800
[SPARK-55967][PYTHON] Unify column conversion for connect dataframe
### What changes were proposed in this pull request?
* A new column conversion helper function `_to_cols` is introduced (see the sketch after this list).
* All the methods that convert columns from `list`, `Column`, `str`, or `int` now use this unified entry point.
* The `ConnectColumn` import is moved to module level.
* The type hints for input arguments are changed from `List` to `Sequence`.
* Incorrect overloads have been removed.
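For reference, a simplified, standalone sketch of the normalization that `_to_cols` performs. This is illustrative only: the real helper is a method on the connect `DataFrame`, raises `PySparkTypeError`/`PySparkIndexError`, and takes extra options (`arg_name`, `allow_sequence`, `sort`); the free function and plain built-in exceptions below are assumptions made to keep the sketch self-contained.

```python
from typing import List, Sequence, Union

from pyspark.sql import functions as F
from pyspark.sql.column import Column

ColumnOrNameOrOrdinal = Union[Column, str, int]


def to_cols(
    df,  # a DataFrame, only used to resolve 1-based ordinals
    *cols: Union[Sequence[ColumnOrNameOrOrdinal], ColumnOrNameOrOrdinal],
    allow_ordinal: bool = False,
) -> List[Column]:
    # A single argument that is a Sequence (but not a str/Column) is
    # unpacked and treated as the full column list.
    if (
        len(cols) == 1
        and not isinstance(cols[0], (int, str, Column))
        and isinstance(cols[0], Sequence)
    ):
        cols = tuple(cols[0])

    out: List[Column] = []
    for c in cols:
        if isinstance(c, Column):
            out.append(c)
        elif isinstance(c, str):
            out.append(F.col(c))
        elif allow_ordinal and isinstance(c, int) and not isinstance(c, bool):
            if c < 1:
                raise IndexError(f"ordinal must be positive, got {c}")
            out.append(df[c - 1])  # ordinals are 1-based
        else:
            raise TypeError(f"unsupported column argument: {type(c).__name__}")
    return out
```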
### Why are the changes needed?
Reduces duplicated code. It also makes it easy to add ordinal support to more methods in the future.
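As an illustration of the call patterns the unified entry covers (assuming an active `SparkSession`; the mixed-type list cases are exactly the ones whose typing-test expectations change below):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.range(3).withColumn("v", F.lit(1))

df.select("id", "v")             # varargs of names/Columns
df.select(["id", F.col("v")])    # a single sequence; mixed names and Columns now type-check
df.groupBy(("id", "v")).count()  # any Sequence is accepted, not just list
df.groupBy(1).count()            # 1-based ordinal, resolved to the first column
```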
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54764 from gaogaotiantian/unify-column-conversion.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/pandas/groupby.py | 2 +-
python/pyspark/sql/classic/dataframe.py | 95 ++++----
python/pyspark/sql/connect/dataframe.py | 251 +++++++++------------
python/pyspark/sql/dataframe.py | 92 ++++----
.../sql/tests/connect/test_connect_error.py | 4 +-
python/pyspark/sql/tests/typing/test_dataframe.yml | 8 +-
6 files changed, 194 insertions(+), 258 deletions(-)
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index f23422b43a22..ddb9cb032371 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -4248,7 +4248,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
F.col(f"{auxiliary_col_name}.{CORRELATION_CORR_OUTPUT_COLUMN}"),
)
- sdf = sdf.orderBy(groupkey_names + [index_1_col_name]) # type:
ignore[arg-type]
+ sdf = sdf.orderBy(groupkey_names + [index_1_col_name])
sdf = sdf.select(
*[F.col(col) for col in groupkey_names + numeric_col_names],
diff --git a/python/pyspark/sql/classic/dataframe.py
b/python/pyspark/sql/classic/dataframe.py
index 243fa8a37b5a..e300d0d39293 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -544,15 +544,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
def coalesce(self, numPartitions: int) -> ParentDataFrame:
return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
- @overload
- def repartition(self, numPartitions: int, *cols: "ColumnOrName") ->
ParentDataFrame:
- ...
-
- @overload
- def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
- ...
-
- def repartition( # type: ignore[misc]
+ def repartition(
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> ParentDataFrame:
if isinstance(numPartitions, int):
@@ -575,15 +567,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
},
)
- @overload
- def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") ->
ParentDataFrame:
- ...
-
- @overload
- def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame:
- ...
-
- def repartitionByRange( # type: ignore[misc]
+ def repartitionByRange(
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> ParentDataFrame:
if isinstance(numPartitions, int):
@@ -914,7 +898,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
def sortWithinPartitions(
self,
- *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
**kwargs: Any,
) -> ParentDataFrame:
_cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -923,7 +907,7 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
def sort(
self,
- *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
**kwargs: Any,
) -> ParentDataFrame:
_cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -944,22 +928,32 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
"""Return a JVM Scala Map from a dict"""
return to_scala_map(self.sparkSession._sc._jvm, jm)
- def _jcols(self, *cols: "ColumnOrName") -> "JavaObject":
+ def _jcols(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"])
-> "JavaObject":
"""Return a JVM Seq of Columns from a list of Column or column names
If `cols` has only one list in it, cols[0] will be used as the list.
"""
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
+ if (
+ len(cols) == 1
+ and not isinstance(cols[0], (str, Column))
+ and isinstance(cols[0], Sequence)
+ ):
+ cols = tuple(cols[0])
return self._jseq(cols, _to_java_column)
- def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject":
+ def _jcols_ordinal(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "JavaObject":
"""Return a JVM Seq of Columns from a list of Column or column names
or column ordinals.
If `cols` has only one list in it, cols[0] will be used as the list.
"""
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
+ if (
+ len(cols) == 1
+ and not isinstance(cols[0], (int, str, Column))
+ and isinstance(cols[0], Sequence)
+ ):
+ cols = tuple(cols[0])
_cols = []
for c in cols:
@@ -1048,10 +1042,10 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
...
@overload
- def select(self, __cols: Union[List[Column], List[str]]) ->
ParentDataFrame:
+ def select(self, __cols: Sequence["ColumnOrName"]) -> ParentDataFrame:
...
- def select(self, *cols: "ColumnOrName") -> ParentDataFrame: # type:
ignore[misc]
+ def select(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"])
-> ParentDataFrame:
jdf = self._jdf.select(self._jcols(*cols))
return DataFrame(jdf, self.sparkSession)
@@ -1086,51 +1080,59 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
...
@overload
- def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) ->
"GroupedData":
+ def groupBy(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": #
type: ignore[misc]
+ def groupBy(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
@overload
- def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+ def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
- def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ def rollup(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": #
type: ignore[misc]
+ def rollup(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
jgd = self._jdf.rollup(self._jcols_ordinal(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
@overload
- def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+ def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
- def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ def cube(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
...
- def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type:
ignore[misc]
+ def cube(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
jgd = self._jdf.cube(self._jcols_ordinal(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
def groupingSets(
- self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols:
"ColumnOrName"
+ self,
+ groupingSets: Sequence[Sequence["ColumnOrNameOrOrdinal"]],
+ *cols: "ColumnOrNameOrOrdinal",
) -> "GroupedData":
from pyspark.sql.group import GroupedData
- jgrouping_sets = _to_seq(self._sc, [self._jcols(*inner) for inner in
groupingSets])
+ jgrouping_sets = _to_seq(self._sc, [self._jcols_ordinal(*inner) for
inner in groupingSets])
- jgd = self._jdf.groupingSets(jgrouping_sets, self._jcols(*cols))
+ jgd = self._jdf.groupingSets(jgrouping_sets,
self._jcols_ordinal(*cols))
return GroupedData(jgd, self)
def unpivot(
@@ -1776,19 +1778,8 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
# make it "compatible" by adding aliases. Therefore, we stop adding such
# aliases as of Spark 3.0. Two methods below remain just
# for legacy users currently.
- @overload
- def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
- ...
-
- @overload
- def groupby(self, __cols: Union[List[Column], List[str], List[int]]) ->
"GroupedData":
- ...
-
- def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": #
type: ignore[misc]
- return self.groupBy(*cols)
-
- def drop_duplicates(self, subset: Optional[List[str]] = None) ->
ParentDataFrame:
- return self.dropDuplicates(subset)
+ groupby = groupBy
+ drop_duplicates = dropDuplicates
def writeTo(self, table: str) -> "DataFrameWriterV2":
return DataFrameWriterV2(self, table)
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 89846e36e718..06f94a17940e 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -28,6 +28,7 @@ from typing import (
Dict,
Iterator,
List,
+ NoReturn,
Optional,
Tuple,
Union,
@@ -75,10 +76,12 @@ from pyspark.sql.connect.group import GroupedData
from pyspark.sql.connect.merge import MergeIntoWriter
from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
+from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.column import Column
from pyspark.sql.connect.expressions import (
ColumnReference,
DirectShufflePartitionID,
+ SortOrder,
SubqueryExpression,
UnresolvedRegex,
UnresolvedStar,
@@ -210,6 +213,60 @@ class DataFrame(ParentDataFrame):
else:
return None
+ def _to_cols(
+ self,
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
+ arg_name: str = "cols",
+ allow_ordinal: bool = False,
+ allow_sequence: bool = True,
+ sort: bool = False,
+ ) -> List[Column]:
+ def _raise() -> NoReturn:
+ if allow_ordinal:
+ raise PySparkTypeError(
+ errorClass="NOT_COLUMN_OR_INT_OR_STR",
+ messageParameters={"arg_name": arg_name, "arg_type":
type(c).__name__},
+ )
+ else:
+ raise PySparkTypeError(
+ errorClass="NOT_COLUMN_OR_STR",
+ messageParameters={"arg_name": arg_name, "arg_type":
type(c).__name__},
+ )
+
+ if (
+ len(cols) == 1
+ and not isinstance(cols[0], (int, str, Column))
+ and isinstance(cols[0], Sequence)
+ ):
+ if allow_sequence:
+ cols = tuple(cols[0])
+ else:
+ _raise()
+
+ _cols: List[Column] = []
+ for c in cols:
+ if isinstance(c, Column):
+ col = c
+ elif isinstance(c, str):
+ col = F.col(c)
+ elif isinstance(c, int) and not isinstance(c, bool):
+ if allow_ordinal:
+ if c < 1:
+ raise PySparkIndexError(
+ errorClass="INDEX_NOT_POSITIVE",
messageParameters={"index": str(c)}
+ )
+ col = self[c - 1]
+ else:
+ _raise()
+ else:
+ _raise()
+
+ if sort:
+ if not isinstance(col._expr, SortOrder):
+ col = col.asc()
+ _cols.append(col)
+ return _cols
+
@property
def write(self) -> "DataFrameWriter":
def cb(qe: "ExecutionInfo") -> None:
@@ -226,19 +283,12 @@ class DataFrame(ParentDataFrame):
...
@overload
- def select(self, __cols: Union[List[Column], List[str]]) ->
ParentDataFrame:
+ def select(self, __cols: Union[Sequence["ColumnOrName"], "ColumnOrName"])
-> ParentDataFrame:
...
- def select(self, *cols: "ColumnOrName") -> ParentDataFrame: # type:
ignore[misc]
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
- if any(not isinstance(c, (str, Column)) for c in cols):
- raise PySparkTypeError(
- errorClass="NOT_LIST_OF_COLUMN_OR_STR",
- messageParameters={"arg_name": "columns"},
- )
+ def select(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"])
-> ParentDataFrame:
return DataFrame(
- plan.Project(self._plan, [F._to_col(c) for c in cols]),
+ plan.Project(self._plan, self._to_cols(*cols)),
session=self._session,
)
@@ -284,8 +334,6 @@ class DataFrame(ParentDataFrame):
return self._col(colName, is_metadata_column=True)
def colRegex(self, colName: str) -> Column:
- from pyspark.sql.connect.column import Column as ConnectColumn
-
if not isinstance(colName, str):
raise PySparkTypeError(
errorClass="NOT_STR",
@@ -342,15 +390,7 @@ class DataFrame(ParentDataFrame):
res._cached_schema = self._cached_schema
return res
- @overload
- def repartition(self, numPartitions: int, *cols: "ColumnOrName") ->
ParentDataFrame:
- ...
-
- @overload
- def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
- ...
-
- def repartition( # type: ignore[misc]
+ def repartition(
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> ParentDataFrame:
if isinstance(numPartitions, int):
@@ -369,15 +409,13 @@ class DataFrame(ParentDataFrame):
)
else:
res = DataFrame(
- plan.RepartitionByExpression(
- self._plan, numPartitions, [F._to_col(c) for c in cols]
- ),
+ plan.RepartitionByExpression(self._plan, numPartitions,
self._to_cols(cols)),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
res = DataFrame(
- plan.RepartitionByExpression(self._plan, None, [F._to_col(c)
for c in cols]),
+ plan.RepartitionByExpression(self._plan, None,
self._to_cols(cols)),
self.sparkSession,
)
else:
@@ -392,15 +430,7 @@ class DataFrame(ParentDataFrame):
res._cached_schema = self._cached_schema
return res
- @overload
- def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") ->
ParentDataFrame:
- ...
-
- @overload
- def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame:
- ...
-
- def repartitionByRange( # type: ignore[misc]
+ def repartitionByRange(
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> ParentDataFrame:
if isinstance(numPartitions, int):
@@ -420,14 +450,14 @@ class DataFrame(ParentDataFrame):
else:
res = DataFrame(
plan.RepartitionByExpression(
- self._plan, numPartitions, [F._sort_col(c) for c in
cols]
+ self._plan, numPartitions, self._to_cols(cols,
sort=True)
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
res = DataFrame(
plan.RepartitionByExpression(
- self._plan, None, [F._sort_col(c) for c in [numPartitions]
+ list(cols)]
+ self._plan, None, self._to_cols((numPartitions,) + cols,
sort=True)
),
self.sparkSession,
)
@@ -445,8 +475,6 @@ class DataFrame(ParentDataFrame):
def repartitionById(
self, numPartitions: int, partitionIdCol: "ColumnOrName"
) -> ParentDataFrame:
- from pyspark.sql.connect.column import Column as ConnectColumn
-
if not isinstance(numPartitions, int) or isinstance(numPartitions,
bool):
raise PySparkTypeError(
errorClass="NOT_INT",
@@ -533,6 +561,8 @@ class DataFrame(ParentDataFrame):
return res
def drop(self, *cols: "ColumnOrName") -> ParentDataFrame:
+ # We can't convert names to columns here because drop has different
behavior
+ # for names and columns.
_cols = list(cols)
if any(not isinstance(c, (str, Column)) for c in _cols):
raise PySparkTypeError(
@@ -560,136 +590,67 @@ class DataFrame(ParentDataFrame):
def first(self) -> Optional[Row]:
return self.head()
- @overload # type: ignore[no-overload-impl]
- def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+ @overload
+ def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
- def groupby(self, __cols: Union[List[Column], List[str], List[int]]) ->
"GroupedData":
+ def groupBy(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
-
- _cols: List[Column] = []
- for c in cols:
- if isinstance(c, Column):
- _cols.append(c)
- elif isinstance(c, str):
- _cols.append(F.col(c))
- elif isinstance(c, int) and not isinstance(c, bool):
- if c < 1:
- raise PySparkIndexError(
- errorClass="INDEX_NOT_POSITIVE",
messageParameters={"index": str(c)}
- )
- # ordinal is 1-based
- _cols.append(self[c - 1])
- else:
- raise PySparkTypeError(
- errorClass="NOT_COLUMN_OR_STR",
- messageParameters={"arg_name": "cols", "arg_type":
type(c).__name__},
- )
-
- return GroupedData(df=self, group_type="groupby", grouping_cols=_cols)
+ def groupBy(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
+ return GroupedData(
+ df=self, group_type="groupby", grouping_cols=self._to_cols(*cols,
allow_ordinal=True)
+ )
groupby = groupBy # type: ignore[assignment]
@overload
- def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+ def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
- def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ def rollup(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": #
type: ignore[misc]
- _cols: List[Column] = []
- for c in cols:
- if isinstance(c, Column):
- _cols.append(c)
- elif isinstance(c, str):
- _cols.append(F.col(c))
- elif isinstance(c, int) and not isinstance(c, bool):
- if c < 1:
- raise PySparkIndexError(
- errorClass="INDEX_NOT_POSITIVE",
messageParameters={"index": str(c)}
- )
- # ordinal is 1-based
- _cols.append(self[c - 1])
- else:
- raise PySparkTypeError(
- errorClass="NOT_COLUMN_OR_STR",
- messageParameters={"arg_name": "cols", "arg_type":
type(c).__name__},
- )
-
- return GroupedData(df=self, group_type="rollup", grouping_cols=_cols)
+ def rollup(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
+ return GroupedData(
+ df=self, group_type="rollup", grouping_cols=self._to_cols(*cols,
allow_ordinal=True)
+ )
@overload
def cube(self, *cols: "ColumnOrName") -> "GroupedData":
...
@overload
- def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ def cube(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
...
- def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type:
ignore[misc]
- _cols: List[Column] = []
- for c in cols:
- if isinstance(c, Column):
- _cols.append(c)
- elif isinstance(c, str):
- _cols.append(F.col(c))
- elif isinstance(c, int) and not isinstance(c, bool):
- if c < 1:
- raise PySparkIndexError(
- errorClass="INDEX_NOT_POSITIVE",
messageParameters={"index": str(c)}
- )
- # ordinal is 1-based
- _cols.append(self[c - 1])
- else:
- raise PySparkTypeError(
- errorClass="NOT_COLUMN_OR_STR",
- messageParameters={"arg_name": "cols", "arg_type":
type(c).__name__},
- )
-
- return GroupedData(df=self, group_type="cube", grouping_cols=_cols)
+ def cube(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
+ return GroupedData(
+ df=self, group_type="cube", grouping_cols=self._to_cols(*cols,
allow_ordinal=True)
+ )
def groupingSets(
- self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols:
"ColumnOrName"
+ self,
+ groupingSets: Sequence[Sequence["ColumnOrNameOrOrdinal"]],
+ *cols: "ColumnOrNameOrOrdinal",
) -> "GroupedData":
gsets: List[List[Column]] = []
for grouping_set in groupingSets:
- gset: List[Column] = []
- for c in grouping_set:
- if isinstance(c, Column):
- gset.append(c)
- elif isinstance(c, str):
- gset.append(F.col(c))
- else:
- raise PySparkTypeError(
- errorClass="NOT_COLUMN_OR_STR",
- messageParameters={
- "arg_name": "groupingSets",
- "arg_type": type(c).__name__,
- },
- )
- gsets.append(gset)
-
- gcols: List[Column] = []
- for c in cols:
- if isinstance(c, Column):
- gcols.append(c)
- elif isinstance(c, str):
- gcols.append(F.col(c))
- else:
- raise PySparkTypeError(
- errorClass="NOT_COLUMN_OR_STR",
- messageParameters={"arg_name": "cols", "arg_type":
type(c).__name__},
- )
+ gsets.append(self._to_cols(grouping_set, arg_name="groupingSets",
allow_ordinal=True))
return GroupedData(
- df=self, group_type="grouping_sets", grouping_cols=gcols,
grouping_sets=gsets
+ df=self,
+ group_type="grouping_sets",
+ grouping_cols=self._to_cols(cols, allow_ordinal=True),
+ grouping_sets=gsets,
)
@overload
@@ -788,7 +749,7 @@ class DataFrame(ParentDataFrame):
def sort(
self,
- *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
**kwargs: Any,
) -> ParentDataFrame:
_cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -807,7 +768,7 @@ class DataFrame(ParentDataFrame):
def sortWithinPartitions(
self,
- *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
**kwargs: Any,
) -> ParentDataFrame:
_cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
@@ -1748,8 +1709,6 @@ class DataFrame(ParentDataFrame):
def __getitem__(
self, item: Union[int, str, Column, List, Tuple]
) -> Union[Column, ParentDataFrame]:
- from pyspark.sql.connect.column import Column as ConnectColumn
-
if isinstance(item, str):
if item == "*":
return ConnectColumn(
@@ -1791,8 +1750,6 @@ class DataFrame(ParentDataFrame):
)
def _col(self, name: str, is_metadata_column: bool = False) -> Column:
- from pyspark.sql.connect.column import Column as ConnectColumn
-
return ConnectColumn(
ColumnReference(
unparsed_identifier=name,
@@ -1863,13 +1820,9 @@ class DataFrame(ParentDataFrame):
return ConnectTableArg(SubqueryExpression(self._plan,
subquery_type="table_arg"))
def scalar(self) -> Column:
- from pyspark.sql.connect.column import Column as ConnectColumn
-
return ConnectColumn(SubqueryExpression(self._plan,
subquery_type="scalar"))
def exists(self) -> Column:
- from pyspark.sql.connect.column import Column as ConnectColumn
-
return ConnectColumn(SubqueryExpression(self._plan,
subquery_type="exists"))
@property
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2ee3e4e9d703..a9f51f8fca20 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1766,15 +1766,7 @@ class DataFrame:
"""
return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
- @overload
- def repartition(self, numPartitions: int, *cols: "ColumnOrName") ->
"DataFrame":
- ...
-
- @overload
- def repartition(self, *cols: "ColumnOrName") -> "DataFrame":
- ...
-
- @dispatch_df_method # type: ignore[misc]
+ @dispatch_df_method
def repartition(
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> "DataFrame":
@@ -1882,15 +1874,7 @@ class DataFrame:
"""
...
- @overload
- def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") ->
"DataFrame":
- ...
-
- @overload
- def repartitionByRange(self, *cols: "ColumnOrName") -> "DataFrame":
- ...
-
- @dispatch_df_method # type: ignore[misc]
+ @dispatch_df_method
def repartitionByRange(
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> "DataFrame":
@@ -2974,7 +2958,7 @@ class DataFrame:
@dispatch_df_method
def sortWithinPartitions(
self,
- *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
**kwargs: Any,
) -> "DataFrame":
"""Returns a new :class:`DataFrame` with each partition sorted by the
specified column(s).
@@ -3038,7 +3022,7 @@ class DataFrame:
@dispatch_df_method
def sort(
self,
- *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+ *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"],
**kwargs: Any,
) -> "DataFrame":
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
@@ -3198,7 +3182,7 @@ class DataFrame:
def _preapare_cols_for_sort(
self,
_to_col: Callable[[str], Column],
- cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+ cols: Sequence[Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]],
kwargs: Dict[str, Any],
) -> Sequence[Column]:
from pyspark.errors import PySparkTypeError, PySparkValueError,
PySparkIndexError
@@ -3208,10 +3192,16 @@ class DataFrame:
errorClass="CANNOT_BE_EMPTY", messageParameters={"item":
"cols"}
)
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
+ if (
+ len(cols) == 1
+ and not isinstance(cols[0], (int, str, Column))
+ and isinstance(cols[0], Sequence)
+ ):
+ cols = tuple(cols[0])
- def _get_col(c: Union[int, str, Column, List[int | str | Column]]) ->
Column:
+ def _get_col(
+ c: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> Column:
if isinstance(c, int) and not isinstance(c, bool):
# ordinal is 1-based
if c > 0:
@@ -3621,11 +3611,11 @@ class DataFrame:
...
@overload
- def select(self, __cols: Union[List[Column], List[str]]) -> "DataFrame":
+ def select(self, __cols: Sequence["ColumnOrName"]) -> "DataFrame":
...
- @dispatch_df_method # type: ignore[misc]
- def select(self, *cols: "ColumnOrName") -> "DataFrame":
+ @dispatch_df_method
+ def select(self, *cols: Union[Sequence["ColumnOrName"], "ColumnOrName"])
-> "DataFrame":
"""Projects a set of expressions and returns a new :class:`DataFrame`.
.. versionadded:: 1.3.0
@@ -3871,11 +3861,13 @@ class DataFrame:
...
@overload
- def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) ->
"GroupedData":
+ def groupBy(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- @dispatch_df_method # type: ignore[misc]
- def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+ @dispatch_df_method
+ def groupBy(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
"""
Groups the :class:`DataFrame` by the specified columns so that
aggregation
can be performed on them.
@@ -3977,15 +3969,17 @@ class DataFrame:
...
@overload
- def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+ def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
- def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ def rollup(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- @dispatch_df_method # type: ignore[misc]
- def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+ @dispatch_df_method
+ def rollup(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
"""
Create a multi-dimensional rollup for the current :class:`DataFrame`
using
the specified columns, allowing for aggregation on them.
@@ -4060,15 +4054,17 @@ class DataFrame:
...
@overload
- def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+ def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
- def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ def cube(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) -> "GroupedData":
...
- @dispatch_df_method # type: ignore[misc]
- def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+ @dispatch_df_method
+ def cube(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
"""
Create a multi-dimensional cube for the current :class:`DataFrame`
using
the specified columns, allowing aggregations to be performed on them.
@@ -4149,7 +4145,9 @@ class DataFrame:
@dispatch_df_method
def groupingSets(
- self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols:
"ColumnOrName"
+ self,
+ groupingSets: Sequence[Sequence["ColumnOrNameOrOrdinal"]],
+ *cols: "ColumnOrNameOrOrdinal",
) -> "GroupedData":
"""
Create multi-dimensional aggregation for the current
:class:`DataFrame` using the specified
@@ -4815,14 +4813,6 @@ class DataFrame:
.. versionadded:: 2.4.0
- .. versionchanged:: 3.4.0
- Supports Spark Connect.
-
- Parameters
- ----------
- other : :class:`DataFrame`
- Another :class:`DataFrame` that needs to be combined.
-
Returns
-------
:class:`DataFrame`
@@ -6270,11 +6260,13 @@ class DataFrame:
...
@overload
- def groupby(self, __cols: Union[List[Column], List[str], List[int]]) ->
"GroupedData":
+ def groupby(self, __cols: Sequence["ColumnOrNameOrOrdinal"]) ->
"GroupedData":
...
- @dispatch_df_method # type: ignore[misc]
- def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
+ @dispatch_df_method
+ def groupby(
+ self, *cols: Union[Sequence["ColumnOrNameOrOrdinal"],
"ColumnOrNameOrOrdinal"]
+ ) -> "GroupedData":
"""
:func:`groupby` is an alias for :func:`groupBy`.
diff --git a/python/pyspark/sql/tests/connect/test_connect_error.py
b/python/pyspark/sql/tests/connect/test_connect_error.py
index f5da7d945922..c8f203c71e6d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_error.py
+++ b/python/pyspark/sql/tests/connect/test_connect_error.py
@@ -232,8 +232,8 @@ class SparkConnectErrorTests(ReusedConnectTestCase):
self.check_error(
exception=e1.exception,
- errorClass="NOT_LIST_OF_COLUMN_OR_STR",
- messageParameters={"arg_name": "columns"},
+ errorClass="NOT_COLUMN_OR_STR",
+ messageParameters={"arg_name": "cols", "arg_type": "NoneType"},
)
def test_ym_interval_in_collect(self):
diff --git a/python/pyspark/sql/tests/typing/test_dataframe.yml
b/python/pyspark/sql/tests/typing/test_dataframe.yml
index 5e4b20d3588c..4bd408310ba1 100644
--- a/python/pyspark/sql/tests/typing/test_dataframe.yml
+++ b/python/pyspark/sql/tests/typing/test_dataframe.yml
@@ -54,7 +54,7 @@
df.select(["name", "age"])
df.select([col("name"), col("age")])
- df.select(["name", col("age")]) # E: Argument 1 to "select" of
"DataFrame" has incompatible type "list[object]"; expected "list[Column] |
list[str]" [arg-type]
+ df.select(["name", col("age")])
- case: groupBy
@@ -71,7 +71,7 @@
df.groupby(["name", "age"])
df.groupBy([col("name"), col("age")])
df.groupby([col("name"), col("age")])
- df.groupBy(["name", col("age")]) # E: Argument 1 to "groupBy" of
"DataFrame" has incompatible type "list[object]"; expected "list[Column] |
list[str] | list[int]" [arg-type]
+ df.groupBy(["name", col("age")])
- case: rollup
@@ -88,7 +88,7 @@
df.rollup([col("name"), col("age")])
- df.rollup(["name", col("age")]) # E: Argument 1 to "rollup" of
"DataFrame" has incompatible type "list[object]"; expected "list[Column] |
list[str]" [arg-type]
+ df.rollup(["name", col("age")])
- case: cube
@@ -105,7 +105,7 @@
df.cube([col("name"), col("age")])
- df.cube(["name", col("age")]) # E: Argument 1 to "cube" of "DataFrame"
has incompatible type "list[object]"; expected "list[Column] | list[str]"
[arg-type]
+ df.cube(["name", col("age")])
- case: dropColumns
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]