This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 65d46f7e400 [SPARK-41382][CONNECT][PYTHON] Implement `product` function
65d46f7e400 is described below

commit 65d46f7e4000fba514878287f5218dc93961c999
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 7 11:41:46 2022 +0800

    [SPARK-41382][CONNECT][PYTHON] Implement `product` function
    
    ### What changes were proposed in this pull request?
    Implement `product` function
    
    ### Why are the changes needed?
    For API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    New API.
    
    ### How was this patch tested?
    Added test.
    
    Closes #38915 from zhengruifeng/connect_function_product.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 27 +++++++++-
 python/pyspark/sql/connect/functions.py            | 61 +++++++++++-----------
 .../sql/tests/connect/test_connect_function.py     |  1 +
 3 files changed, 56 insertions(+), 33 deletions(-)
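
For reference, a minimal usage sketch of the new API (assuming a Spark
Connect session bound to `spark`); it mirrors the doctest added in this
commit:

    from pyspark.sql.connect.functions import col, product

    df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3)
    df.groupBy("mod3").agg(product("x").alias("product")).orderBy("mod3").show()
    # +----+-------+
    # |mod3|product|
    # +----+-------+
    # |   0|  162.0|
    # |   1|   28.0|
    # |   2|   80.0|
    # +----+-------+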

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 2a9f4260ff2..20cf68c3c08 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -406,7 +406,8 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
         transformUnresolvedExpression(exp)
       case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
-        transformScalarFunction(exp.getUnresolvedFunction)
+        transformUnregisteredFunction(exp.getUnresolvedFunction)
+          .getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
       case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
       case proto.Expression.ExprTypeCase.EXPRESSION_STRING =>
         transformExpressionString(exp.getExpressionString)
@@ -538,7 +539,8 @@ class SparkConnectPlanner(session: SparkSession) {
    *   Proto representation of the function call.
    * @return
    */
-  private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
+  private def transformUnresolvedFunction(
+      fun: proto.Expression.UnresolvedFunction): Expression = {
     if (fun.getIsUserDefinedFunction) {
       UnresolvedFunction(
         session.sessionState.sqlParser.parseFunctionIdentifier(fun.getFunctionName),
@@ -552,6 +554,27 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  /**
+   * For some reason, not all functions are registered in 'FunctionRegistry'. For an unregistered
+   * function, we can still wrap it under the proto 'UnresolvedFunction', and then resolve it in
+   * this method.
+   */
+  private def transformUnregisteredFunction(
+      fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
+    fun.getFunctionName match {
+      case "product" =>
+        if (fun.getArgumentsCount != 1) {
+          throw InvalidPlanInput("Product requires single child expression")
+        }
+        Some(
+          aggregate
+            .Product(transformExpression(fun.getArgumentsList.asScala.head))
+            .toAggregateExpression())
+
+      case _ => None
+    }
+  }
+
   private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
     if (alias.getNameCount == 1) {
       val md = if (alias.hasMetadata()) {
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index a52fc58fd0a..b576a092f99 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2920,37 +2920,36 @@ def percentile_approx(
     return _invoke_function("percentile_approx", _to_col(col), percentage_col, 
lit(accuracy))
 
 
-# TODO(SPARK-41382): add product in FunctionRegistry?
-# def product(col: "ColumnOrName") -> Column:
-#     """
-#     Aggregate function: returns the product of the values in a group.
-#
-#     .. versionadded:: 3.4.0
-#
-#     Parameters
-#     ----------
-#     col : str, :class:`Column`
-#         column containing values to be multiplied together
-#
-#     Returns
-#     -------
-#     :class:`~pyspark.sql.Column`
-#         the column for computed results.
-#
-#     Examples
-#     --------
-#     >>> df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3)
-#     >>> prods = df.groupBy('mod3').agg(product('x').alias('product'))
-#     >>> prods.orderBy('mod3').show()
-#     +----+-------+
-#     |mod3|product|
-#     +----+-------+
-#     |   0|  162.0|
-#     |   1|   28.0|
-#     |   2|   80.0|
-#     +----+-------+
-#     """
-#     return _invoke_function_over_columns("product", col)
+def product(col: "ColumnOrName") -> Column:
+    """
+    Aggregate function: returns the product of the values in a group.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : str, :class:`Column`
+        column containing values to be multiplied together
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        the column for computed results.
+
+    Examples
+    --------
+    >>> df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3)
+    >>> prods = df.groupBy('mod3').agg(product('x').alias('product'))
+    >>> prods.orderBy('mod3').show()
+    +----+-------+
+    |mod3|product|
+    +----+-------+
+    |   0|  162.0|
+    |   1|   28.0|
+    |   2|   80.0|
+    +----+-------+
+    """
+    return _invoke_function_over_columns("product", col)
 
 
 def skewness(col: "ColumnOrName") -> Column:
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index b2625aed3dd..2f1e4968942 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -359,6 +359,7 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
             (CF.median, SF.median),
             (CF.min, SF.min),
             (CF.mode, SF.mode),
+            (CF.product, SF.product),
             (CF.skewness, SF.skewness),
             (CF.stddev, SF.stddev),
             (CF.stddev_pop, SF.stddev_pop),
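
The planner change above uses a try-then-fall-back dispatch:
`transformUnregisteredFunction` returns an `Option`, and the planner falls
back to the generic `transformUnresolvedFunction` path via `getOrElse` when
it returns `None`. A rough Python rendering of that control flow (the names
and string "expressions" here are illustrative, not the actual Scala API):

    from typing import List, Optional

    def transform_unregistered_function(name: str, args: List[str]) -> Optional[str]:
        # Special-case functions that are missing from FunctionRegistry.
        if name == "product":
            if len(args) != 1:
                raise ValueError("Product requires single child expression")
            return f"Product({args[0]}).toAggregateExpression()"
        return None  # not handled here; defer to the generic path

    def transform_function(name: str, args: List[str]) -> str:
        # Mirrors transformUnregisteredFunction(...).getOrElse(transformUnresolvedFunction(...)).
        special = transform_unregistered_function(name, args)
        return special if special is not None else f"UnresolvedFunction({name}, {args})"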

