This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f9a2077fd32f [SPARK-49810][PYTHON] Extract the preparation of
`DataFrame.sort` to parent class
f9a2077fd32f is described below
commit f9a2077fd32faf63796a68cbb3483b486f220b1c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Sep 28 16:21:30 2024 +0900
[SPARK-49810][PYTHON] Extract the preparation of `DataFrame.sort` to parent
class
### What changes were proposed in this pull request?
Extract the preparation of df.sort to parent class
### Why are the changes needed?
Deduplicate code; the logic in the two classes is similar.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48282 from zhengruifeng/py_sql_sort.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/classic/dataframe.py | 52 +++---------------------------
python/pyspark/sql/connect/dataframe.py | 53 +++----------------------------
python/pyspark/sql/dataframe.py | 56 +++++++++++++++++++++++++++++++++
3 files changed, 65 insertions(+), 96 deletions(-)
diff --git a/python/pyspark/sql/classic/dataframe.py
b/python/pyspark/sql/classic/dataframe.py
index 0dd66a9d8654..9f9dedbd3820 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -55,6 +55,7 @@ from pyspark.serializers import BatchedSerializer,
CPickleSerializer, UTF8Deseri
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.column import Column
+from pyspark.sql.functions import builtin as F
from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.merge import MergeIntoWriter
@@ -873,7 +874,8 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
- jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
+ jdf = self._jdf.sortWithinPartitions(self._jseq(_cols,
_to_java_column))
return DataFrame(jdf, self.sparkSession)
def sort(
@@ -881,7 +883,8 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
- jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
+ jdf = self._jdf.sort(self._jseq(_cols, _to_java_column))
return DataFrame(jdf, self.sparkSession)
orderBy = sort
@@ -928,51 +931,6 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
_cols.append(c) # type: ignore[arg-type]
return self._jseq(_cols, _to_java_column)
- def _sort_cols(
- self,
- cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
- kwargs: Dict[str, Any],
- ) -> "JavaObject":
- """Return a JVM Seq of Columns that describes the sort order"""
- if not cols:
- raise PySparkValueError(
- errorClass="CANNOT_BE_EMPTY",
- messageParameters={"item": "column"},
- )
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
-
- jcols = []
- for c in cols:
- if isinstance(c, int) and not isinstance(c, bool):
- # ordinal is 1-based
- if c > 0:
- _c = self[c - 1]
- # negative ordinal means sort by desc
- elif c < 0:
- _c = self[-c - 1].desc()
- else:
- raise PySparkIndexError(
- errorClass="ZERO_INDEX",
- messageParameters={},
- )
- else:
- _c = c # type: ignore[assignment]
- jcols.append(_to_java_column(cast("ColumnOrName", _c)))
-
- ascending = kwargs.get("ascending", True)
- if isinstance(ascending, (bool, int)):
- if not ascending:
- jcols = [jc.desc() for jc in jcols]
- elif isinstance(ascending, list):
- jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending,
jcols)]
- else:
- raise PySparkTypeError(
- errorClass="NOT_BOOL_OR_LIST",
- messageParameters={"arg_name": "ascending", "arg_type":
type(ascending).__name__},
- )
- return self._jseq(jcols)
-
def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame:
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0] # type: ignore[assignment]
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 146cfe11bc50..136fe60532df 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -739,62 +739,16 @@ class DataFrame(ParentDataFrame):
def tail(self, num: int) -> List[Row]:
return DataFrame(plan.Tail(child=self._plan, limit=num),
session=self._session).collect()
- def _sort_cols(
- self,
- cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
- kwargs: Dict[str, Any],
- ) -> List[Column]:
- """Return a JVM Seq of Columns that describes the sort order"""
- if cols is None:
- raise PySparkValueError(
- errorClass="CANNOT_BE_EMPTY",
- messageParameters={"item": "cols"},
- )
-
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
-
- _cols: List[Column] = []
- for c in cols:
- if isinstance(c, int) and not isinstance(c, bool):
- # ordinal is 1-based
- if c > 0:
- _c = self[c - 1]
- # negative ordinal means sort by desc
- elif c < 0:
- _c = self[-c - 1].desc()
- else:
- raise PySparkIndexError(
- errorClass="ZERO_INDEX",
- messageParameters={},
- )
- else:
- _c = c # type: ignore[assignment]
- _cols.append(F._to_col(cast("ColumnOrName", _c)))
-
- ascending = kwargs.get("ascending", True)
- if isinstance(ascending, (bool, int)):
- if not ascending:
- _cols = [c.desc() for c in _cols]
- elif isinstance(ascending, list):
- _cols = [c if asc else c.desc() for asc, c in zip(ascending,
_cols)]
- else:
- raise PySparkTypeError(
- errorClass="NOT_BOOL_OR_LIST",
- messageParameters={"arg_name": "ascending", "arg_type":
type(ascending).__name__},
- )
-
- return [F._sort_col(c) for c in _cols]
-
def sort(
self,
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
res = DataFrame(
plan.Sort(
self._plan,
- columns=self._sort_cols(cols, kwargs),
+ columns=[F._sort_col(c) for c in _cols],
is_global=True,
),
session=self._session,
@@ -809,10 +763,11 @@ class DataFrame(ParentDataFrame):
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
res = DataFrame(
plan.Sort(
self._plan,
- columns=self._sort_cols(cols, kwargs),
+ columns=[F._sort_col(c) for c in _cols],
is_global=False,
),
session=self._session,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 142034583dbd..5906108163b4 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2891,6 +2891,62 @@ class DataFrame:
"""
...
+ def _preapare_cols_for_sort(
+ self,
+ _to_col: Callable[[str], Column],
+ cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+ kwargs: Dict[str, Any],
+ ) -> Sequence[Column]:
+ from pyspark.errors import PySparkTypeError, PySparkValueError,
PySparkIndexError
+
+ if not cols:
+ raise PySparkValueError(
+ errorClass="CANNOT_BE_EMPTY", messageParameters={"item":
"cols"}
+ )
+
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+
+ _cols: List[Column] = []
+ for c in cols:
+ if isinstance(c, int) and not isinstance(c, bool):
+ # ordinal is 1-based
+ if c > 0:
+ _cols.append(self[c - 1])
+ # negative ordinal means sort by desc
+ elif c < 0:
+ _cols.append(self[-c - 1].desc())
+ else:
+ raise PySparkIndexError(
+ errorClass="ZERO_INDEX",
+ messageParameters={},
+ )
+ elif isinstance(c, Column):
+ _cols.append(c)
+ elif isinstance(c, str):
+ _cols.append(_to_col(c))
+ else:
+ raise PySparkTypeError(
+ errorClass="NOT_COLUMN_OR_INT_OR_STR",
+ messageParameters={
+ "arg_name": "col",
+ "arg_type": type(c).__name__,
+ },
+ )
+
+ ascending = kwargs.get("ascending", True)
+ if isinstance(ascending, (bool, int)):
+ if not ascending:
+ _cols = [c.desc() for c in _cols]
+ elif isinstance(ascending, list):
+ _cols = [c if asc else c.desc() for asc, c in zip(ascending,
_cols)]
+ else:
+ raise PySparkTypeError(
+ errorClass="NOT_COLUMN_OR_INT_OR_STR",
+ messageParameters={"arg_name": "ascending", "arg_type":
type(ascending).__name__},
+ )
+ return _cols
+
orderBy = sort
@dispatch_df_method
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]