This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d869e3680fe [SPARK-41786][CONNECT][PYTHON] Deduplicate helper functions
d869e3680fe is described below

commit d869e3680fe91f2ae90614f9d08d44da42610f0f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Dec 31 10:24:42 2022 +0800

    [SPARK-41786][CONNECT][PYTHON] Deduplicate helper functions
    
    ### What changes were proposed in this pull request?
    Deduplicate helper functions: drop the module-local scalar_function, sql_expression and _build helpers in favor of the shared _invoke_function helpers in functions.py.
    
    ### Why are the changes needed?
    For simplicity: the same expression-building logic was previously duplicated across column.py, dataframe.py, function_builder.py and group.py.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #39307 from zhengruifeng/connect_function_cleanup.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/column.py           | 48 ++++++++++++--------------
 python/pyspark/sql/connect/dataframe.py        | 12 +++----
 python/pyspark/sql/connect/function_builder.py | 22 ++----------
 python/pyspark/sql/connect/functions.py        |  5 ++-
 python/pyspark/sql/connect/group.py            | 10 +++---
 5 files changed, 37 insertions(+), 60 deletions(-)
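
In short, the module-local helpers scalar_function and sql_expression in column.py (and the private _build helper in function_builder.py) are removed; callers now go through the shared helpers in functions.py (_invoke_function, _invoke_function_over_columns, and expr). A condensed before/after sketch of the calling pattern, illustrative rather than quoted verbatim from any single hunk:

    # Before: callers imported the column.py-local helper
    #     from pyspark.sql.connect.column import scalar_function
    #     df.agg(scalar_function("count", lit(1)))
    #
    # After: callers reuse the helper that already lives in functions.py
    #     from pyspark.sql.connect.functions import _invoke_function, lit
    #     df.agg(_invoke_function("count", lit(1)))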

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index d1e4b00f779..d9f96325c17 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -37,7 +37,6 @@ from pyspark.sql.connect.expressions import (
     Expression,
     UnresolvedFunction,
     UnresolvedExtractValue,
-    SQLExpression,
     LiteralExpression,
     CaseWhen,
     SortOrder,
@@ -60,7 +59,7 @@ if TYPE_CHECKING:
 
 def _func_op(name: str, doc: Optional[str] = "") -> Callable[["Column"], "Column"]:
     def wrapped(self: "Column") -> "Column":
-        return scalar_function(name, self)
+        return Column(UnresolvedFunction(name, [self._expr]))
 
     wrapped.__doc__ = doc
     return wrapped
@@ -70,16 +69,17 @@ def _bin_op(
     name: str, doc: Optional[str] = "binary function", reverse: bool = False
 ) -> Callable[["Column", Any], "Column"]:
     def wrapped(self: "Column", other: Any) -> "Column":
-        from pyspark.sql.connect.functions import lit
-
         if other is None or isinstance(
             other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
         ):
-            other = lit(other)
+            other_expr = LiteralExpression._from_value(other)
+        else:
+            other_expr = other._expr
+
         if not reverse:
-            return scalar_function(name, self, other)
+            return Column(UnresolvedFunction(name, [self._expr, other_expr]))
         else:
-            return scalar_function(name, other, self)
+            return Column(UnresolvedFunction(name, [other_expr, self._expr]))
 
     wrapped.__doc__ = doc
     return wrapped
@@ -87,20 +87,12 @@ def _bin_op(
 
 def _unary_op(name: str, doc: Optional[str] = "unary function") -> Callable[["Column"], "Column"]:
     def wrapped(self: "Column") -> "Column":
-        return scalar_function(name, self)
+        return Column(UnresolvedFunction(name, [self._expr]))
 
     wrapped.__doc__ = doc
     return wrapped
 
 
-def scalar_function(op: str, *args: "Column") -> "Column":
-    return Column(UnresolvedFunction(op, [arg._expr for arg in args]))
-
-
-def sql_expression(expr: str) -> "Column":
-    return Column(SQLExpression(expr))
-
-
 class Column:
     def __init__(self, expr: "Expression") -> None:
         if not isinstance(expr, Expression):
@@ -182,7 +174,7 @@ class Column:
         if isinstance(value, Column):
             _value = value._expr
         else:
-            _value = LiteralExpression(value, LiteralExpression._infer_type(value))
+            _value = LiteralExpression._from_value(value)
 
         _branches = self._expr._branches + [(condition._expr, _value)]
 
@@ -204,7 +196,7 @@ class Column:
         if isinstance(value, Column):
             _value = value._expr
         else:
-            _value = LiteralExpression(value, LiteralExpression._infer_type(value))
+            _value = LiteralExpression._from_value(value)
 
         return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
 
@@ -254,13 +246,14 @@ class Column:
         """Returns a binary expression with the current column as the left
         side and the other expression as the right side.
         """
-        from pyspark.sql.connect.functions import lit
-
         if other is None or isinstance(
             other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
         ):
-            other = lit(other)
-        return scalar_function("==", self, other)
+            other_expr = LiteralExpression._from_value(other)
+        else:
+            other_expr = other._expr
+
+        return Column(UnresolvedFunction("==", [self._expr, other_expr]))
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return self._expr.to_plan(session)
@@ -318,14 +311,19 @@ class Column:
     over.__doc__ = PySparkColumn.over.__doc__
 
     def isin(self, *cols: Any) -> "Column":
-        from pyspark.sql.connect.functions import lit
-
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
             _cols = list(cols[0])
         else:
             _cols = list(cols)
 
-        return Column(UnresolvedFunction("in", [self._expr] + [lit(c)._expr for c in _cols]))
+        _exprs = [self._expr]
+        for c in _cols:
+            if isinstance(c, Column):
+                _exprs.append(c._expr)
+            else:
+                _exprs.append(LiteralExpression._from_value(c))
+
+        return Column(UnresolvedFunction("in", _exprs))
 
     isin.__doc__ = PySparkColumn.isin.__doc__
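
The recurring change in column.py is that values are converted to expressions inline instead of routing through functions.lit, which removes the local "from pyspark.sql.connect.functions import lit" imports. A minimal sketch of that inline pattern; _to_expr is a hypothetical name used only for illustration, the patch repeats the logic at each call site:

    # Hypothetical helper (not part of the patch) capturing the inlined conversion.
    def _to_expr(value: Any) -> Expression:
        if isinstance(value, Column):
            return value._expr
        # _from_value is assumed to infer the literal type internally, replacing
        # LiteralExpression(value, LiteralExpression._infer_type(value)).
        return LiteralExpression._from_value(value)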
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 018785b77b0..c7583fb62c6 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -42,13 +42,9 @@ from pyspark.sql.types import DataType, StructType, Row
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter
-from pyspark.sql.connect.column import (
-    Column,
-    scalar_function,
-    sql_expression,
-)
+from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
-from pyspark.sql.connect.functions import col, lit
+from pyspark.sql.connect.functions import _invoke_function, col, lit, expr as sql_expression
 from pyspark.sql.dataframe import (
     DataFrame as PySparkDataFrame,
     DataFrameNaFunctions as PySparkDataFrameNaFunctions,
@@ -110,7 +106,7 @@ class DataFrame:
             raise ValueError("Argument 'exprs' must not be empty")
 
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            measures = [scalar_function(f, col(e)) for e, f in exprs[0].items()]
+            measures = [_invoke_function(f, col(e)) for e, f in exprs[0].items()]
             return self.groupBy().agg(*measures)
         else:
             # other expressions
@@ -152,7 +148,7 @@ class DataFrame:
     sparkSession.__doc__ = PySparkDataFrame.sparkSession.__doc__
 
     def count(self) -> int:
-        pdd = self.agg(scalar_function("count", lit(1))).toPandas()
+        pdd = self.agg(_invoke_function("count", lit(1))).toPandas()
         return pdd.iloc[0, 0]
 
     count.__doc__ = PySparkDataFrame.count.__doc__
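
In dataframe.py the change is mechanical: scalar_function and sql_expression are no longer imported from column.py, and the same expressions are built through the functions.py helpers (expr is imported under the old sql_expression name). For example, with an illustrative column name:

    # df.agg({"id": "max"}) now builds its measures as
    #     measures = [_invoke_function("max", col("id"))]
    # and df.count() delegates to
    #     self.agg(_invoke_function("count", lit(1)))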
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index 7eb0ffc26ae..081752da2fa 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -22,8 +22,8 @@ import pyspark.sql.types
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import Expression, UnresolvedFunction
-from pyspark.sql.connect.functions import col
+from pyspark.sql.connect.expressions import Expression
+from pyspark.sql.connect.functions import _invoke_function_over_columns
 
 
 if TYPE_CHECKING:
@@ -34,22 +34,6 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
 
 
-def _build(name: str, *args: "ColumnOrName") -> Column:
-    """
-    Simple wrapper function that converts the arguments into the appropriate types.
-    Parameters
-    ----------
-    name Name of the function to be called.
-    args The list of arguments.
-
-    Returns
-    -------
-    :class:`UnresolvedFunction`
-    """
-    cols = [arg if isinstance(arg, Column) else col(arg) for arg in args]
-    return Column(UnresolvedFunction(name, [col._expr for col in cols]))
-
-
 class UserDefinedFunction(Expression):
     """A user defied function is an expression that has a reference to the actual
     Python callable attached. During plan generation, the client sends a command to
@@ -81,7 +65,7 @@ class UserDefinedFunction(Expression):
         # Only do this once per session
         func_name = session.register_udf(self._func_ref, self._return_type)
         # Func name is used for the actual reference
-        return _build(func_name, *self._args).to_plan(session)
+        return _invoke_function_over_columns(func_name, *self._args).to_plan(session)
 
     def __str__(self) -> str:
         return f"UserDefinedFunction({self._func_name})"
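
function_builder.py drops its private _build helper and calls _invoke_function_over_columns from functions.py instead. That helper is not shown in this patch; judging from the removed _build, it can be assumed to behave roughly like the sketch below (the name/Column handling is an assumption, not quoted code):

    # Rough sketch of the assumed behaviour of _invoke_function_over_columns;
    # the real implementation lives in pyspark/sql/connect/functions.py.
    def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
        # Accept either Column objects or column names, mirroring the removed _build.
        _cols = [c if isinstance(c, Column) else col(c) for c in cols]
        return _invoke_function(name, *_cols)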
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 8dd95b0c626..bab3fe4f6f2 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -51,7 +51,7 @@ def _invoke_function(name: str, *args: Union[Column, Expression]) -> Column:
 
     Returns
     -------
-    :class:`UnresolvedFunction`
+    :class:`Column`
     """
     expressions: List[Expression] = []
     for arg in args:
@@ -173,8 +173,7 @@ def lit(col: Any) -> Column:
     elif isinstance(col, list):
         return array(*[lit(c) for c in col])
     else:
-        dataType = LiteralExpression._infer_type(col)
-        return Column(LiteralExpression(col, dataType))
+        return Column(LiteralExpression._from_value(col))
 
 
 lit.__doc__ = pysparkfuncs.lit.__doc__
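
Here only the docstring of _invoke_function is corrected (it returns a Column, not an UnresolvedFunction), and lit now delegates type inference to LiteralExpression._from_value. For context, the body of _invoke_function visible at the top of the hunk plausibly continues along these lines (a reconstruction, not part of the patch):

    # Assumed continuation of _invoke_function, shown only for context.
    def _invoke_function(name: str, *args: Union[Column, Expression]) -> Column:
        expressions: List[Expression] = []
        for arg in args:
            # Columns are unwrapped to their expression; bare Expressions pass through.
            if isinstance(arg, Column):
                expressions.append(arg._expr)
            else:
                expressions.append(arg)
        return Column(UnresolvedFunction(name, expressions))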
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index a6006c64158..e3852ce397c 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -31,8 +31,8 @@ from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.types import NumericType
 
 import pyspark.sql.connect.plan as plan
-from pyspark.sql.connect.column import Column, scalar_function
-from pyspark.sql.connect.functions import col, lit
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, col, lit
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import LiteralType
@@ -83,7 +83,7 @@ class GroupedData:
             # There is a special case for count(*) which is rewritten into count(1).
             # Convert the dict into key value pairs
             aggregate_cols = [
-                scalar_function(
+                _invoke_function(
                     exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k)
                 )
                 for k in exprs[0]
@@ -139,7 +139,7 @@ class GroupedData:
                 child=self._df._plan,
                 group_type=self._group_type,
                 grouping_cols=self._grouping_cols,
-                aggregate_cols=[scalar_function(function, col(c)) for c in agg_cols],
+                aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
                 pivot_col=self._pivot_col,
                 pivot_values=self._pivot_values,
             ),
@@ -169,7 +169,7 @@ class GroupedData:
     mean = avg
 
     def count(self) -> "DataFrame":
-        return self.agg(scalar_function("count", lit(1)))
+        return self.agg(_invoke_function("count", lit(1)))
 
     count.__doc__ = PySparkGroupedData.count.__doc__
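
group.py applies the same substitution; the count(*) special case still rewrites to count(1). For an illustrative grouping:

    # For a dict entry {"*": "count"} the special case produces
    #     _invoke_function("count", lit(1))
    # while any other entry such as {"v": "max"} produces
    #     _invoke_function("max", col("v"))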
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
