This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d59f71c627b [SPARK-40698][PS][SQL] Improve the precision of `product` for integral inputs
d59f71c627b is described below

commit d59f71c627b4a1499a25c1c53e1e73447d31bbbc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 11 11:32:36 2022 +0900

    [SPARK-40698][PS][SQL] Improve the precision of `product` for integral inputs
    
    ### What changes were proposed in this pull request?
    add a dedicated expression for `product`:
    
    1. for integral inputs, use `LongType` accumulation directly to avoid rounding errors;
    2. when `ignoreNA` is true, stop accumulating once a `zero` is encountered;
    3. when `ignoreNA` is false, stop accumulating once a `zero` or a `null` is encountered (see the sketch below).
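    
    A minimal Python sketch of these semantics (an illustration only, not the
    actual Catalyst implementation; `values` is a plain list standing in for a
    column, with `None` standing in for SQL NULL):
    ```
    def pandas_product(values, ignore_na):
        if not ignore_na and any(v is None for v in values):
            return None  # a NULL poisons the result when ignoreNA is false
        result = 1  # integral inputs accumulate in a LongType-style integer
        for v in values:
            if v is None:
                continue  # ignoreNA: NULLs are skipped
            if v == 0:
                return 0  # once a zero appears, the remaining values are skipped
            result *= v
        return result
    ```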
    
    ### Why are the changes needed?
    
    1. the existing computation logic on the PySpark side is too complex; with a dedicated expression, we can simplify the PySpark side and apply it in more cases.
    2. the existing computation of `product` is likely to introduce rounding errors for integral inputs, for example `55108 x 55108 x 55108 x 55108` in the following case:
    
    before:
    ```
    In [14]: df = pd.DataFrame({"a": [55108, 55108, 55108, 55108], "b": [55108.0, 55108.0, 55108.0, 55108.0], "c": [1, 2, 3, 4]})
    
    In [15]: df.a.prod()
    Out[15]: 9222710978872688896
    
    In [16]: type(df.a.prod())
    Out[16]: numpy.int64
    
    In [17]: df.b.prod()
    Out[17]: 9.222710978872689e+18
    
    In [18]: type(df.b.prod())
    Out[18]: numpy.float64
    
    In [19]: psdf = ps.from_pandas(df)
    
    In [20]: psdf.a.prod()
    Out[20]: 9222710978872658944
    
    In [21]: type(psdf.a.prod())
    Out[21]: int
    
    In [22]: psdf.b.prod()
    Out[22]: 9.222710978872659e+18
    
    In [23]: type(psdf.b.prod())
    Out[23]: float
    
    In [24]: df.a.prod() - psdf.a.prod()
    Out[24]: 29952
    ```
    
    after:
    ```
    In [1]: import pyspark.pandas as ps
    
    In [2]: import pandas as pd
    
    In [3]: df = pd.DataFrame({"a": [55108, 55108, 55108, 55108], "b": [55108.0, 55108.0, 55108.0, 55108.0], "c": [1, 2, 3, 4]})
    
    In [4]: df.a.prod()
    Out[4]: 9222710978872688896
    
    In [5]: psdf = ps.from_pandas(df)
    
    In [6]: psdf.a.prod()
    Out[6]: 9222710978872688896
    
    In [7]: df.a.prod() - psdf.a.prod()
    Out[7]: 0
    ```
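    
    The gap in the old implementation traces back to computing the product as
    `exp(sum(log(abs(x))))` in double precision. A standalone reproduction
    sketch (plain Python, standard library only; the exact low-order digits may
    vary by platform):
    ```
    import math

    exact = 55108 ** 4  # 9222710978872688896
    approx = int(math.exp(sum(math.log(55108) for _ in range(4))))
    print(approx)          # close to, but not equal to, the exact product
    print(exact - approx)  # a nonzero error, tens of thousands as reported above
    ```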
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests and newly added unit tests.
    
    Closes #38148 from zhengruifeng/ps_new_prod.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/generic.py                   |  23 +----
 python/pyspark/pandas/groupby.py                   |  53 +++-------
 python/pyspark/pandas/spark/functions.py           |   5 +
 .../pyspark/pandas/tests/test_generic_functions.py |  16 +++
 python/pyspark/pandas/tests/test_groupby.py        |  28 ++++-
 .../catalyst/expressions/aggregate/Product.scala   | 115 ++++++++++++++++++++-
 .../spark/sql/api/python/PythonSQLUtils.scala      |   4 +
 7 files changed, 180 insertions(+), 64 deletions(-)

diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 5b94da44125..9db688d9134 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -45,7 +45,6 @@ from pyspark.sql import Column, functions as F
 from pyspark.sql.types import (
     BooleanType,
     DoubleType,
-    IntegralType,
     LongType,
     NumericType,
 )
@@ -1421,32 +1420,16 @@ class Frame(object, metaclass=ABCMeta):
         def prod(psser: "Series") -> Column:
             spark_type = psser.spark.data_type
             spark_column = psser.spark.column
-
-            if not skipna:
-                spark_column = F.when(spark_column.isNull(), np.nan).otherwise(spark_column)
-
             if isinstance(spark_type, BooleanType):
-                scol = F.min(F.coalesce(spark_column, F.lit(True))).cast(LongType())
-            elif isinstance(spark_type, NumericType):
-                num_zeros = F.sum(F.when(spark_column == 0, 1).otherwise(0))
-                sign = F.when(
-                    F.sum(F.when(spark_column < 0, 1).otherwise(0)) % 2 == 0, 1
-                ).otherwise(-1)
-
-                scol = F.when(num_zeros > 0, 0).otherwise(
-                    sign * F.exp(F.sum(F.log(F.abs(spark_column))))
-                )
-
-                if isinstance(spark_type, IntegralType):
-                    scol = F.round(scol).cast(LongType())
-            else:
+                spark_column = spark_column.cast(LongType())
+            elif not isinstance(spark_type, NumericType):
                 raise TypeError(
                     "Could not convert {} ({}) to numeric".format(
                         spark_type_to_pandas_dtype(spark_type), spark_type.simpleString()
                     )
                 )
 
-            return F.coalesce(scol, F.lit(1))
+            return SF.product(spark_column, skipna)
 
         return self._reduce_for_stat_function(
             prod,
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 08a136aa268..b2525ce9a60 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -62,7 +62,6 @@ from pyspark.sql.types import (
     StructField,
     StructType,
     StringType,
-    IntegralType,
 )
 
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
@@ -1320,52 +1319,28 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         1  NaN  2.0  0.0
         2  NaN NaN  NaN
         """
+        if not isinstance(min_count, int):
+            raise TypeError("min_count must be integer")
 
         self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")
 
-        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
-        internal, agg_columns, sdf = self._prepare_reduce(
-            groupkey_names=groupkey_names,
-            accepted_spark_types=(NumericType, BooleanType),
-            bool_to_numeric=True,
-        )
-
-        psdf: DataFrame = DataFrame(internal)
-        if len(psdf._internal.column_labels) > 0:
-
-            stat_exprs = []
-            for label in psdf._internal.column_labels:
-                psser = psdf._psser_for(label)
-                column = psser._dtype_op.nan_to_null(psser).spark.column
-                data_type = psser.spark.data_type
-                aggregating = (
-                    F.product(column).cast("long")
-                    if isinstance(data_type, IntegralType)
-                    else F.product(column)
-                )
-
-                if min_count > 0:
-                    prod_scol = F.when(
-                        F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None)
-                    ).otherwise(aggregating)
-                else:
-                    prod_scol = aggregating
-
-                stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0]))
+        if min_count > 0:
 
-            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+            def prod(col: Column) -> Column:
+                return F.when(
+                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
+                ).otherwise(SF.product(col, True))
 
         else:
-            sdf = sdf.select(*groupkey_names).distinct()
 
-        internal = internal.copy(
-            spark_frame=sdf,
-            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
-            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
-            data_fields=None,
-        )
+            def prod(col: Column) -> Column:
+                return SF.product(col, True)
 
-        return self._prepare_return(DataFrame(internal))
+        return self._reduce_for_stat_function(
+            prod,
+            accepted_spark_types=(NumericType, BooleanType),
+            bool_to_numeric=True,
+        )
 
     def all(self, skipna: bool = True) -> FrameLike:
         """
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index f9311296a57..658d3459b24 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -27,6 +27,11 @@ from pyspark.sql.column import (
 )
 
 
+def product(col: Column, dropna: bool) -> Column:
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
 def stddev(col: Column, ddof: int) -> Column:
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
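
The new `product` helper mirrors the existing `stddev` wrapper and can be
exercised directly against a Spark DataFrame (a sketch, assuming an active
SparkSession; the column is made up):

```
from pyspark.sql import SparkSession, functions as F
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(4).select(F.lit(55108).alias("v"))

# dropna=True maps to ignoreNA on the JVM side
sdf.agg(SF.product(F.col("v"), True).alias("prod")).show()
# expected: 9222710978872688896, the exact LongType product
```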
diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py
index 7c252c8356d..d4763022059 100644
--- a/python/pyspark/pandas/tests/test_generic_functions.py
+++ b/python/pyspark/pandas/tests/test_generic_functions.py
@@ -200,6 +200,22 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
         self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis())
         self.assert_eq(pdf.c.kurtosis(), psdf.c.kurtosis())
 
+    def test_prod_precision(self):
+        pdf = pd.DataFrame(
+            {
+                "a": [np.nan, np.nan, np.nan, np.nan],
+                "b": [1, np.nan, np.nan, -4],
+                "c": [1, -2, 3, -4],
+                "d": [55108, 55108, 55108, 55108],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(pdf.prod(), psdf.prod())
+        self.assert_eq(pdf.prod(skipna=False), psdf.prod(skipna=False))
+        self.assert_eq(pdf.prod(min_count=3), psdf.prod(min_count=3))
+        self.assert_eq(pdf.prod(skipna=False, min_count=3), psdf.prod(skipna=False, min_count=3))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index 33ba8155641..a203f77717e 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1493,13 +1493,35 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
             self.psdf.groupby("B").nth("x")
 
     def test_prod(self):
+        pdf = pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2, 1],
+                "B": [3.1, 4.1, 4.1, 3.1, 0.1],
+                "C": ["a", "b", "b", "a", "c"],
+                "D": [True, False, False, True, False],
+                "E": [-1, -2, 3, -4, -2],
+                "F": [-1.5, np.nan, -3.2, 0.1, 0],
+                "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
         for n in [0, 1, 2, 128, -1, -2, -128]:
-            self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n))
             self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n)
+                lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False
             )
             self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n)
+                lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n),
+                check_exact=False,
+            )
+            self._test_stat_func(
+                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("A").prod(min_count=n).sort_index(),
+                psdf.groupby("A").prod(min_count=n).sort_index(),
+                almost=True,
             )
 
     def test_cumcount(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
index 3af3944fd47..3325c8f16a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, Exp, Expression, If, ImplicitCastInputTypes, IsNull, Literal, Log}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, DoubleType, IntegralType, LongType, NumericType}
 
 
 /** Multiply numerical values within an aggregation group */
@@ -63,3 +63,114 @@ case class Product(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): Product =
     copy(child = newChild)
 }
+
+/**
+ * Product in Pandas' fashion. This expression is dedicated only for Pandas API on Spark.
+ * It has three main differences from `Product`:
+ * 1, it computes the product of `Fractional` inputs in a more numerically stable way;
+ * 2, it computes the product of `Integral` inputs with `LongType` variables internally;
+ * 3, it accepts NULLs when `ignoreNA` is false;
+ */
+case class PandasProduct(
+    child: Expression,
+    ignoreNA: Boolean)
+    extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] {
+
+  override def nullable: Boolean = !ignoreNA
+
+  override def dataType: DataType = child.dataType match {
+    case _: IntegralType => LongType
+    case _ => DoubleType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+  private lazy val product =
+    AttributeReference("product", LongType, nullable = false)()
+  private lazy val logSum =
+    AttributeReference("logSum", DoubleType, nullable = false)()
+  private lazy val positive =
+    AttributeReference("positive", BooleanType, nullable = false)()
+  private lazy val containsZero =
+    AttributeReference("containsZero", BooleanType, nullable = false)()
+  private lazy val containsNull =
+    AttributeReference("containsNull", BooleanType, nullable = false)()
+
+  override lazy val aggBufferAttributes = child.dataType match {
+    case _: IntegralType =>
+      Seq(product, containsNull)
+    case _ =>
+      Seq(logSum, positive, containsZero, containsNull)
+  }
+
+  override lazy val initialValues: Seq[Expression] = child.dataType match {
+    case _: IntegralType =>
+      Seq(Literal(1L), Literal(false))
+    case _ =>
+      Seq(Literal(0.0), Literal(true), Literal(false), Literal(false))
+  }
+
+  override lazy val updateExpressions: Seq[Expression] = child.dataType match {
+    case _: IntegralType =>
+      Seq(
+        If(IsNull(child), product, product * child),
+        containsNull || IsNull(child)
+      )
+    case _ =>
+      val newLogSum = logSum + Log(Abs(child))
+      val newPositive = If(child < Literal(0.0), !positive, positive)
+      val newContainsZero = containsZero || child <=> Literal(0.0)
+      val newContainsNull = containsNull || IsNull(child)
+      if (ignoreNA) {
+        Seq(
+          If(IsNull(child) || newContainsZero, logSum, newLogSum),
+          newPositive,
+          newContainsZero,
+          newContainsNull
+        )
+      } else {
+        Seq(
+          If(newContainsNull || newContainsZero, logSum, newLogSum),
+          newPositive,
+          newContainsZero,
+          newContainsNull
+        )
+      }
+  }
+
+  override lazy val mergeExpressions: Seq[Expression] = child.dataType match {
+    case _: IntegralType =>
+      Seq(
+        product.left * product.right,
+        containsNull.left || containsNull.right
+      )
+    case _ =>
+      Seq(
+        logSum.left + logSum.right,
+        positive.left === positive.right,
+        containsZero.left || containsZero.right,
+        containsNull.left || containsNull.right
+      )
+  }
+
+  override lazy val evaluateExpression: Expression = child.dataType match {
+    case _: IntegralType =>
+      if (ignoreNA) {
+        product
+      } else {
+        If(containsNull, Literal(null, LongType), product)
+      }
+    case _ =>
+      val product = If(positive, Exp(logSum), -Exp(logSum))
+      if (ignoreNA) {
+        If(containsZero, Literal(0.0), product)
+      } else {
+        If(containsNull, Literal(null, DoubleType),
+          If(containsZero, Literal(0.0), product))
+      }
+  }
+
+  override def prettyName: String = "pandas_product"
+  override protected def withNewChildInternal(newChild: Expression): PandasProduct =
+    copy(child = newChild)
+}
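
In the fractional path above, the aggregation buffer is `(logSum, positive,
containsZero, containsNull)`, and merging two partial buffers combines signs via
`positive.left === positive.right` (an even count of negative inputs yields a
positive product). A rough Python model of the final evaluate step (an
illustration only, not the Catalyst code):

```
import math

def evaluate(log_sum, positive, contains_zero, contains_null, ignore_na):
    if not ignore_na and contains_null:
        return None  # NULL poisons the result when ignoreNA is false
    if contains_zero:
        return 0.0
    magnitude = math.exp(log_sum)  # |product| recovered from the log-domain sum
    return magnitude if positive else -magnitude
```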
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index d43a9060677..70474f4d5c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -155,6 +155,10 @@ private[sql] object PythonSQLUtils extends Logging {
     Column(TimestampDiff(unit, start.expr, end.expr))
   }
 
+  def pandasProduct(e: Column, ignoreNA: Boolean): Column = {
+    Column(PandasProduct(e.expr, ignoreNA).toAggregateExpression(false))
+  }
+
   def pandasStddev(e: Column, ddof: Int): Column = {
     Column(PandasStddev(e.expr, ddof).toAggregateExpression(false))
   }

