This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0eb96ae6eb86 [SPARK-47620][PYTHON][CONNECT] Add a helper function to sort columns
0eb96ae6eb86 is described below
commit 0eb96ae6eb8680155d4c6974dadaeebd7475a1fc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Mar 28 12:22:27 2024 +0800
[SPARK-47620][PYTHON][CONNECT] Add a helper function to sort columns
### What changes were proposed in this pull request?
Add a helper function `_sort_col` to sort columns
### Why are the changes needed?
simple code refactoring
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #45743 from zhengruifeng/connect_sort_col.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 16 +++-------------
python/pyspark/sql/connect/functions/builtin.py | 12 ++++++++++++
python/pyspark/sql/connect/plan.py | 15 ++-------------
3 files changed, 17 insertions(+), 26 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 82929ccfbc4f..672ac8b9c25c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -72,7 +72,6 @@ from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
- SortOrder,
ColumnReference,
UnresolvedRegex,
UnresolvedStar,
@@ -349,15 +348,6 @@ class DataFrame:
def repartitionByRange( # type: ignore[misc]
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> "DataFrame":
- def _convert_col(col: "ColumnOrName") -> Column:
- if isinstance(col, Column):
- if isinstance(col._expr, SortOrder):
- return col
- else:
- return col.asc()
- else:
- return F.col(col).asc()
-
if isinstance(numPartitions, int):
if not numPartitions > 0:
raise PySparkValueError(
@@ -375,14 +365,14 @@ class DataFrame:
else:
return DataFrame(
plan.RepartitionByExpression(
- self._plan, numPartitions, [_convert_col(c) for c in cols]
+ self._plan, numPartitions, [F._sort_col(c) for c in cols]
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
return DataFrame(
plan.RepartitionByExpression(
- self._plan, None, [_convert_col(c) for c in [numPartitions] + list(cols)]
+ self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
),
self.sparkSession,
)
@@ -729,7 +719,7 @@ class DataFrame:
message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
)
- return _cols
+ return [F._sort_col(c) for c in _cols]
def sort(
self,
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index c423c5f188ef..57c1f881eebf 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -45,6 +45,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
CaseWhen,
+ SortOrder,
Expression,
LiteralExpression,
ColumnReference,
@@ -88,6 +89,17 @@ def _to_col(col: "ColumnOrName") -> Column:
return col if isinstance(col, Column) else column(col)
+def _sort_col(col: "ColumnOrName") -> Column:
+ assert isinstance(col, (Column, str))
+ if isinstance(col, Column):
+ if isinstance(col._expr, SortOrder):
+ return col
+ else:
+ return col.asc()
+ else:
+ return column(col).asc()
+
+
def _invoke_function(name: str, *args: Union[Column, Expression]) -> Column:
"""
Simple wrapper function that converts the arguments into the appropriate types.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d9dd8874398c..863c27fabf6b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -46,10 +46,7 @@ from pyspark.sql.types import DataType
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.conversion import storage_level_to_proto
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import (
- Expression,
- SortOrder,
-)
+from pyspark.sql.connect.expressions import Expression
from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
from pyspark.errors import (
PySparkValueError,
@@ -674,19 +671,11 @@ class Sort(LogicalPlan):
self.columns = columns
self.is_global = is_global
- def _convert_col(
- self, col: Column, session: "SparkConnectClient"
- ) -> proto.Expression.SortOrder:
- if isinstance(col._expr, SortOrder):
- return col._expr.to_plan(session).sort_order
- else:
- return SortOrder(col._expr).to_plan(session).sort_order
-
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.sort.input.CopyFrom(self._child.plan(session))
- plan.sort.order.extend([self._convert_col(c, session) for c in self.columns])
+ plan.sort.order.extend([c.to_plan(session).sort_order for c in self.columns])
plan.sort.is_global = self.is_global
return plan
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]