This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d869e3680fe [SPARK-41786][CONNECT][PYTHON] Deduplicate helper functions
d869e3680fe is described below
commit d869e3680fe91f2ae90614f9d08d44da42610f0f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Dec 31 10:24:42 2022 +0800
[SPARK-41786][CONNECT][PYTHON] Deduplicate helper functions
### What changes were proposed in this pull request?
Deduplicate the helper functions that build unresolved function expressions: remove `scalar_function` and `sql_expression` from `column.py` and the private `_build` from `function_builder.py`, reusing `_invoke_function`, `_invoke_function_over_columns`, and `expr` from `functions.py` instead. Literal construction is likewise unified on `LiteralExpression._from_value`.
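A minimal sketch of the consolidation, using identifiers from the diff below (the bodies shown are simplified for illustration, not the full implementations):

```python
# Assumed imports, matching the module paths in the diff:
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedFunction
from pyspark.sql.connect.functions import _invoke_function, col

# Before: column.py carried its own near-identical helper.
def scalar_function(op: str, *args: Column) -> Column:
    return Column(UnresolvedFunction(op, [arg._expr for arg in args]))

# After: call sites either build the expression directly, e.g.
#   Column(UnresolvedFunction(name, [self._expr]))
# or reuse the single helper kept in functions.py:
measures = [_invoke_function(f, col(e)) for e, f in {"x": "max"}.items()]
```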
### Why are the changes needed?
For simplicity; it removes several duplicated, near-identical helpers.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing unit tests.
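As a hedged sketch (based on the `column.py` and `functions.py` hunks below), the same cleanup collapses two-step literal construction into the existing `LiteralExpression._from_value` factory:

```python
from pyspark.sql.connect.expressions import LiteralExpression

# Before: infer the data type, then construct the literal explicitly.
_value = LiteralExpression(1, LiteralExpression._infer_type(1))

# After: the factory infers the type internally.
_value = LiteralExpression._from_value(1)
```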
Closes #39307 from zhengruifeng/connect_function_cleanup.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/column.py | 48 ++++++++++++--------------
python/pyspark/sql/connect/dataframe.py | 12 +++----
python/pyspark/sql/connect/function_builder.py | 22 ++----------
python/pyspark/sql/connect/functions.py | 5 ++-
python/pyspark/sql/connect/group.py | 10 +++---
5 files changed, 37 insertions(+), 60 deletions(-)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index d1e4b00f779..d9f96325c17 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -37,7 +37,6 @@ from pyspark.sql.connect.expressions import (
Expression,
UnresolvedFunction,
UnresolvedExtractValue,
- SQLExpression,
LiteralExpression,
CaseWhen,
SortOrder,
@@ -60,7 +59,7 @@ if TYPE_CHECKING:
def _func_op(name: str, doc: Optional[str] = "") -> Callable[["Column"], "Column"]:
def wrapped(self: "Column") -> "Column":
- return scalar_function(name, self)
+ return Column(UnresolvedFunction(name, [self._expr]))
wrapped.__doc__ = doc
return wrapped
@@ -70,16 +69,17 @@ def _bin_op(
name: str, doc: Optional[str] = "binary function", reverse: bool = False
) -> Callable[["Column", Any], "Column"]:
def wrapped(self: "Column", other: Any) -> "Column":
- from pyspark.sql.connect.functions import lit
-
if other is None or isinstance(
other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
):
- other = lit(other)
+ other_expr = LiteralExpression._from_value(other)
+ else:
+ other_expr = other._expr
+
if not reverse:
- return scalar_function(name, self, other)
+ return Column(UnresolvedFunction(name, [self._expr, other_expr]))
else:
- return scalar_function(name, other, self)
+ return Column(UnresolvedFunction(name, [other_expr, self._expr]))
wrapped.__doc__ = doc
return wrapped
@@ -87,20 +87,12 @@ def _bin_op(
def _unary_op(name: str, doc: Optional[str] = "unary function") -> Callable[["Column"], "Column"]:
def wrapped(self: "Column") -> "Column":
- return scalar_function(name, self)
+ return Column(UnresolvedFunction(name, [self._expr]))
wrapped.__doc__ = doc
return wrapped
-def scalar_function(op: str, *args: "Column") -> "Column":
- return Column(UnresolvedFunction(op, [arg._expr for arg in args]))
-
-
-def sql_expression(expr: str) -> "Column":
- return Column(SQLExpression(expr))
-
-
class Column:
def __init__(self, expr: "Expression") -> None:
if not isinstance(expr, Expression):
@@ -182,7 +174,7 @@ class Column:
if isinstance(value, Column):
_value = value._expr
else:
- _value = LiteralExpression(value, LiteralExpression._infer_type(value))
+ _value = LiteralExpression._from_value(value)
_branches = self._expr._branches + [(condition._expr, _value)]
@@ -204,7 +196,7 @@ class Column:
if isinstance(value, Column):
_value = value._expr
else:
- _value = LiteralExpression(value, LiteralExpression._infer_type(value))
+ _value = LiteralExpression._from_value(value)
return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
@@ -254,13 +246,14 @@ class Column:
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
- from pyspark.sql.connect.functions import lit
-
if other is None or isinstance(
other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
):
- other = lit(other)
- return scalar_function("==", self, other)
+ other_expr = LiteralExpression._from_value(other)
+ else:
+ other_expr = other._expr
+
+ return Column(UnresolvedFunction("==", [self._expr, other_expr]))
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
return self._expr.to_plan(session)
@@ -318,14 +311,19 @@ class Column:
over.__doc__ = PySparkColumn.over.__doc__
def isin(self, *cols: Any) -> "Column":
- from pyspark.sql.connect.functions import lit
-
if len(cols) == 1 and isinstance(cols[0], (list, set)):
_cols = list(cols[0])
else:
_cols = list(cols)
- return Column(UnresolvedFunction("in", [self._expr] + [lit(c)._expr
for c in _cols]))
+ _exprs = [self._expr]
+ for c in _cols:
+ if isinstance(c, Column):
+ _exprs.append(c._expr)
+ else:
+ _exprs.append(LiteralExpression._from_value(c))
+
+ return Column(UnresolvedFunction("in", _exprs))
isin.__doc__ = PySparkColumn.isin.__doc__
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 018785b77b0..c7583fb62c6 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -42,13 +42,9 @@ from pyspark.sql.types import DataType, StructType, Row
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.group import GroupedData
from pyspark.sql.connect.readwriter import DataFrameWriter
-from pyspark.sql.connect.column import (
- Column,
- scalar_function,
- sql_expression,
-)
+from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedRegex
-from pyspark.sql.connect.functions import col, lit
+from pyspark.sql.connect.functions import _invoke_function, col, lit, expr as sql_expression
from pyspark.sql.dataframe import (
DataFrame as PySparkDataFrame,
DataFrameNaFunctions as PySparkDataFrameNaFunctions,
@@ -110,7 +106,7 @@ class DataFrame:
raise ValueError("Argument 'exprs' must not be empty")
if len(exprs) == 1 and isinstance(exprs[0], dict):
- measures = [scalar_function(f, col(e)) for e, f in exprs[0].items()]
+ measures = [_invoke_function(f, col(e)) for e, f in exprs[0].items()]
return self.groupBy().agg(*measures)
else:
# other expressions
@@ -152,7 +148,7 @@ class DataFrame:
sparkSession.__doc__ = PySparkDataFrame.sparkSession.__doc__
def count(self) -> int:
- pdd = self.agg(scalar_function("count", lit(1))).toPandas()
+ pdd = self.agg(_invoke_function("count", lit(1))).toPandas()
return pdd.iloc[0, 0]
count.__doc__ = PySparkDataFrame.count.__doc__
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index 7eb0ffc26ae..081752da2fa 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -22,8 +22,8 @@ import pyspark.sql.types
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import Expression, UnresolvedFunction
-from pyspark.sql.connect.functions import col
+from pyspark.sql.connect.expressions import Expression
+from pyspark.sql.connect.functions import _invoke_function_over_columns
if TYPE_CHECKING:
@@ -34,22 +34,6 @@ if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
-def _build(name: str, *args: "ColumnOrName") -> Column:
- """
- Simple wrapper function that converts the arguments into the appropriate types.
- Parameters
- ----------
- name Name of the function to be called.
- args The list of arguments.
-
- Returns
- -------
- :class:`UnresolvedFunction`
- """
- cols = [arg if isinstance(arg, Column) else col(arg) for arg in args]
- return Column(UnresolvedFunction(name, [col._expr for col in cols]))
-
-
class UserDefinedFunction(Expression):
"""A user defied function is an expression that has a reference to the
actual
Python callable attached. During plan generation, the client sends a
command to
@@ -81,7 +65,7 @@ class UserDefinedFunction(Expression):
# Only do this once per session
func_name = session.register_udf(self._func_ref, self._return_type)
# Func name is used for the actual reference
- return _build(func_name, *self._args).to_plan(session)
+ return _invoke_function_over_columns(func_name, *self._args).to_plan(session)
def __str__(self) -> str:
return f"UserDefinedFunction({self._func_name})"
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 8dd95b0c626..bab3fe4f6f2 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -51,7 +51,7 @@ def _invoke_function(name: str, *args: Union[Column, Expression]) -> Column:
Returns
-------
- :class:`UnresolvedFunction`
+ :class:`Column`
"""
expressions: List[Expression] = []
for arg in args:
@@ -173,8 +173,7 @@ def lit(col: Any) -> Column:
elif isinstance(col, list):
return array(*[lit(c) for c in col])
else:
- dataType = LiteralExpression._infer_type(col)
- return Column(LiteralExpression(col, dataType))
+ return Column(LiteralExpression._from_value(col))
lit.__doc__ = pysparkfuncs.lit.__doc__
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index a6006c64158..e3852ce397c 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -31,8 +31,8 @@ from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.types import NumericType
import pyspark.sql.connect.plan as plan
-from pyspark.sql.connect.column import Column, scalar_function
-from pyspark.sql.connect.functions import col, lit
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, col, lit
if TYPE_CHECKING:
from pyspark.sql.connect._typing import LiteralType
@@ -83,7 +83,7 @@ class GroupedData:
# There is a special case for count(*) which is rewritten into count(1).
# Convert the dict into key value pairs
aggregate_cols = [
- scalar_function(
+ _invoke_function(
exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k)
)
for k in exprs[0]
@@ -139,7 +139,7 @@ class GroupedData:
child=self._df._plan,
group_type=self._group_type,
grouping_cols=self._grouping_cols,
- aggregate_cols=[scalar_function(function, col(c)) for c in agg_cols],
+ aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
),
@@ -169,7 +169,7 @@ class GroupedData:
mean = avg
def count(self) -> "DataFrame":
- return self.agg(scalar_function("count", lit(1)))
+ return self.agg(_invoke_function("count", lit(1)))
count.__doc__ = PySparkGroupedData.count.__doc__
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]