This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d59f71c627b [SPARK-40698][PS][SQL] Improve the precision of `product` for integral inputs
d59f71c627b is described below
commit d59f71c627b4a1499a25c1c53e1e73447d31bbbc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 11 11:32:36 2022 +0900
[SPARK-40698][PS][SQL] Improve the precision of `product` for integral inputs
### What changes were proposed in this pull request?
Add a dedicated expression for `product` (a usage sketch follows this list):
1. for integral inputs, use `LongType` directly to avoid rounding errors;
2. when `ignoreNA` is true, skip the remaining values once a `zero` is encountered;
3. when `ignoreNA` is false, skip the remaining values once a `zero` or `null` is encountered.
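A rough illustration of the intended semantics from the user-facing side (a hedged sketch, assuming this patch is applied; the data is made up):
```
import numpy as np
import pyspark.pandas as ps

# Integral input: computed with LongType internally, so no rounding error
print(ps.Series([2, 3, 4]).prod())        # 24

# Fractional input with a missing value
psser = ps.Series([2.0, np.nan, 4.0])
print(psser.prod())                       # 8.0 (skipna=True ignores the NaN)
print(psser.prod(skipna=False))           # nan (any null makes the result null)
```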
### Why are the changes needed?
1. The existing computation logic on the PySpark side is overly complex; with a dedicated expression, we can simplify the PySpark side and apply it in more cases.
2. The existing computation of `product` is likely to introduce rounding errors for integral inputs, for example `55108 x 55108 x 55108 x 55108` in the following case:
before:
```
In [14]: df = pd.DataFrame({"a": [55108, 55108, 55108, 55108], "b": [55108.0, 55108.0, 55108.0, 55108.0], "c": [1, 2, 3, 4]})
In [15]: df.a.prod()
Out[15]: 9222710978872688896
In [16]: type(df.a.prod())
Out[16]: numpy.int64
In [17]: df.b.prod()
Out[17]: 9.222710978872689e+18
In [18]: type(df.b.prod())
Out[18]: numpy.float64
In [19]: psdf = ps.from_pandas(df)
In [20]: psdf.a.prod()
Out[20]: 9222710978872658944
In [21]: type(psdf.a.prod())
Out[21]: int
In [22]: psdf.b.prod()
Out[22]: 9.222710978872659e+18
In [23]: type(psdf.b.prod())
Out[23]: float
In [24]: df.a.prod() - psdf.a.prod()
Out[24]: 29952
```
after:
```
In [1]: import pyspark.pandas as ps
In [2]: import pandas as pd
In [3]: df = pd.DataFrame({"a": [55108, 55108, 55108, 55108], "b": [55108.0, 55108.0, 55108.0, 55108.0], "c": [1, 2, 3, 4]})
In [4]: df.a.prod()
Out[4]: 9222710978872688896
In [5]: psdf = ps.from_pandas(df)
In [6]: psdf.a.prod()
Out[6]: 9222710978872688896
In [7]: df.a.prod() - psdf.a.prod()
Out[7]: 0
```
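The root cause is that the old implementation reconstructed integral products as `round(exp(sum(log(abs(x)))))`: near `9.2e18`, adjacent `float64` values are 1024 apart, so the low digits are unrecoverable. A standalone illustration in plain Python (no Spark involved; the exact error differs from the one above but is likewise nonzero):
```
import math

vals = [55108] * 4
exact = math.prod(vals)        # 9222710978872688896, the exact 64-bit integer product
approx = round(math.exp(sum(math.log(v) for v in vals)))
print(exact - approx)          # nonzero: float64 cannot represent the exact value
```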
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing unit tests and newly added unit tests.
Closes #38148 from zhengruifeng/ps_new_prod.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/generic.py | 23 +----
python/pyspark/pandas/groupby.py | 53 +++-------
python/pyspark/pandas/spark/functions.py | 5 +
.../pyspark/pandas/tests/test_generic_functions.py | 16 +++
python/pyspark/pandas/tests/test_groupby.py | 28 ++++-
.../catalyst/expressions/aggregate/Product.scala | 115 ++++++++++++++++++++-
.../spark/sql/api/python/PythonSQLUtils.scala | 4 +
7 files changed, 180 insertions(+), 64 deletions(-)
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 5b94da44125..9db688d9134 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -45,7 +45,6 @@ from pyspark.sql import Column, functions as F
from pyspark.sql.types import (
BooleanType,
DoubleType,
- IntegralType,
LongType,
NumericType,
)
@@ -1421,32 +1420,16 @@ class Frame(object, metaclass=ABCMeta):
def prod(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
-
- if not skipna:
- spark_column = F.when(spark_column.isNull(), np.nan).otherwise(spark_column)
-
if isinstance(spark_type, BooleanType):
- scol = F.min(F.coalesce(spark_column, F.lit(True))).cast(LongType())
- elif isinstance(spark_type, NumericType):
- num_zeros = F.sum(F.when(spark_column == 0, 1).otherwise(0))
- sign = F.when(
- F.sum(F.when(spark_column < 0, 1).otherwise(0)) % 2 == 0, 1
- ).otherwise(-1)
-
- scol = F.when(num_zeros > 0, 0).otherwise(
- sign * F.exp(F.sum(F.log(F.abs(spark_column))))
- )
-
- if isinstance(spark_type, IntegralType):
- scol = F.round(scol).cast(LongType())
- else:
+ spark_column = spark_column.cast(LongType())
+ elif not isinstance(spark_type, NumericType):
raise TypeError(
"Could not convert {} ({}) to numeric".format(
spark_type_to_pandas_dtype(spark_type),
spark_type.simpleString()
)
)
- return F.coalesce(scol, F.lit(1))
+ return SF.product(spark_column, skipna)
return self._reduce_for_stat_function(
prod,
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 08a136aa268..b2525ce9a60 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -62,7 +62,6 @@ from pyspark.sql.types import (
StructField,
StructType,
StringType,
- IntegralType,
)
from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
@@ -1320,52 +1319,28 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
1 NaN 2.0 0.0
2 NaN NaN NaN
"""
+ if not isinstance(min_count, int):
+ raise TypeError("min_count must be integer")
self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")
- groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
- internal, agg_columns, sdf = self._prepare_reduce(
- groupkey_names=groupkey_names,
- accepted_spark_types=(NumericType, BooleanType),
- bool_to_numeric=True,
- )
-
- psdf: DataFrame = DataFrame(internal)
- if len(psdf._internal.column_labels) > 0:
-
- stat_exprs = []
- for label in psdf._internal.column_labels:
- psser = psdf._psser_for(label)
- column = psser._dtype_op.nan_to_null(psser).spark.column
- data_type = psser.spark.data_type
- aggregating = (
- F.product(column).cast("long")
- if isinstance(data_type, IntegralType)
- else F.product(column)
- )
-
- if min_count > 0:
- prod_scol = F.when(
- F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None)
- ).otherwise(aggregating)
- else:
- prod_scol = aggregating
-
-
- stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0]))
+ if min_count > 0:
- sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+ def prod(col: Column) -> Column:
+ return F.when(
+ F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
+ ).otherwise(SF.product(col, True))
else:
- sdf = sdf.select(*groupkey_names).distinct()
- internal = internal.copy(
- spark_frame=sdf,
- index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
- data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
- data_fields=None,
- )
+ def prod(col: Column) -> Column:
+ return SF.product(col, True)
- return self._prepare_return(DataFrame(internal))
+ return self._reduce_for_stat_function(
+ prod,
+ accepted_spark_types=(NumericType, BooleanType),
+ bool_to_numeric=True,
+ )
def all(self, skipna: bool = True) -> FrameLike:
"""
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index f9311296a57..658d3459b24 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -27,6 +27,11 @@ from pyspark.sql.column import (
)
+def product(col: Column, dropna: bool) -> Column:
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
def stddev(col: Column, ddof: int) -> Column:
sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
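The new `product` helper follows the same pattern as the existing `stddev` wrapper, delegating to the JVM-side `PythonSQLUtils.pandasProduct`. A minimal usage sketch (hedged; assumes an active SparkSession and this patch applied):
```
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(1, 5)  # one LongType column `id` holding 1, 2, 3, 4

# dropna=True mirrors pandas' skipna; integral input keeps a LongType result
sdf.select(SF.product(sdf.id, True).alias("prod")).show()
# expected: a single row with prod = 24
```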
diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py
index 7c252c8356d..d4763022059 100644
--- a/python/pyspark/pandas/tests/test_generic_functions.py
+++ b/python/pyspark/pandas/tests/test_generic_functions.py
@@ -200,6 +200,22 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis())
self.assert_eq(pdf.c.kurtosis(), psdf.c.kurtosis())
+ def test_prod_precision(self):
+ pdf = pd.DataFrame(
+ {
+ "a": [np.nan, np.nan, np.nan, np.nan],
+ "b": [1, np.nan, np.nan, -4],
+ "c": [1, -2, 3, -4],
+ "d": [55108, 55108, 55108, 55108],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(pdf.prod(), psdf.prod())
+ self.assert_eq(pdf.prod(skipna=False), psdf.prod(skipna=False))
+ self.assert_eq(pdf.prod(min_count=3), psdf.prod(min_count=3))
+ self.assert_eq(pdf.prod(skipna=False, min_count=3), psdf.prod(skipna=False, min_count=3))
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index 33ba8155641..a203f77717e 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1493,13 +1493,35 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
self.psdf.groupby("B").nth("x")
def test_prod(self):
+ pdf = pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2, 1],
+ "B": [3.1, 4.1, 4.1, 3.1, 0.1],
+ "C": ["a", "b", "b", "a", "c"],
+ "D": [True, False, False, True, False],
+ "E": [-1, -2, 3, -4, -2],
+ "F": [-1.5, np.nan, -3.2, 0.1, 0],
+ "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
for n in [0, 1, 2, 128, -1, -2, -128]:
- self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n))
self._test_stat_func(
- lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n)
+ lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False
)
self._test_stat_func(
- lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n)
+ lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n),
+ check_exact=False,
+ )
+ self._test_stat_func(
+ lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
+ check_exact=False,
+ )
+ self.assert_eq(
+ pdf.groupby("A").prod(min_count=n).sort_index(),
+ psdf.groupby("A").prod(min_count=n).sort_index(),
+ almost=True,
)
def test_cumcount(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
index 3af3944fd47..3325c8f16a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, Exp, Expression, If, ImplicitCastInputTypes, IsNull, Literal, Log}
import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, DoubleType, IntegralType, LongType, NumericType}
/** Multiply numerical values within an aggregation group */
@@ -63,3 +63,114 @@ case class Product(child: Expression)
override protected def withNewChildInternal(newChild: Expression): Product =
copy(child = newChild)
}
+
+/**
+ * Product in Pandas' fashion. This expression is dedicated to the Pandas API on Spark.
+ * It differs from `Product` in three main ways:
+ * 1, it computes the product of `Fractional` inputs in a more numerically stable way;
+ * 2, it computes the product of `Integral` inputs with `LongType` variables internally;
+ * 3, it accepts NULLs when `ignoreNA` is false.
+ */
+case class PandasProduct(
+ child: Expression,
+ ignoreNA: Boolean)
+ extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] {
+
+ override def nullable: Boolean = !ignoreNA
+
+ override def dataType: DataType = child.dataType match {
+ case _: IntegralType => LongType
+ case _ => DoubleType
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ private lazy val product =
+ AttributeReference("product", LongType, nullable = false)()
+ private lazy val logSum =
+ AttributeReference("logSum", DoubleType, nullable = false)()
+ private lazy val positive =
+ AttributeReference("positive", BooleanType, nullable = false)()
+ private lazy val containsZero =
+ AttributeReference("containsZero", BooleanType, nullable = false)()
+ private lazy val containsNull =
+ AttributeReference("containsNull", BooleanType, nullable = false)()
+
+ override lazy val aggBufferAttributes = child.dataType match {
+ case _: IntegralType =>
+ Seq(product, containsNull)
+ case _ =>
+ Seq(logSum, positive, containsZero, containsNull)
+ }
+
+ override lazy val initialValues: Seq[Expression] = child.dataType match {
+ case _: IntegralType =>
+ Seq(Literal(1L), Literal(false))
+ case _ =>
+ Seq(Literal(0.0), Literal(true), Literal(false), Literal(false))
+ }
+
+ override lazy val updateExpressions: Seq[Expression] = child.dataType match {
+ case _: IntegralType =>
+ Seq(
+ If(IsNull(child), product, product * child),
+ containsNull || IsNull(child)
+ )
+ case _ =>
+ val newLogSum = logSum + Log(Abs(child))
+ val newPositive = If(child < Literal(0.0), !positive, positive)
+ val newContainsZero = containsZero || child <=> Literal(0.0)
+ val newContainsNull = containsNull || IsNull(child)
+ if (ignoreNA) {
+ Seq(
+ If(IsNull(child) || newContainsZero, logSum, newLogSum),
+ newPositive,
+ newContainsZero,
+ newContainsNull
+ )
+ } else {
+ Seq(
+ If(newContainsNull || newContainsZero, logSum, newLogSum),
+ newPositive,
+ newContainsZero,
+ newContainsNull
+ )
+ }
+ }
+
+ override lazy val mergeExpressions: Seq[Expression] = child.dataType match {
+ case _: IntegralType =>
+ Seq(
+ product.left * product.right,
+ containsNull.left || containsNull.right
+ )
+ case _ =>
+ Seq(
+ logSum.left + logSum.right,
+ positive.left === positive.right,
+ containsZero.left || containsZero.right,
+ containsNull.left || containsNull.right
+ )
+ }
+
+ override lazy val evaluateExpression: Expression = child.dataType match {
+ case _: IntegralType =>
+ if (ignoreNA) {
+ product
+ } else {
+ If(containsNull, Literal(null, LongType), product)
+ }
+ case _ =>
+ val product = If(positive, Exp(logSum), -Exp(logSum))
+ if (ignoreNA) {
+ If(containsZero, Literal(0.0), product)
+ } else {
+ If(containsNull, Literal(null, DoubleType),
+ If(containsZero, Literal(0.0), product))
+ }
+ }
+
+ override def prettyName: String = "pandas_product"
+ override protected def withNewChildInternal(newChild: Expression): PandasProduct =
+ copy(child = newChild)
+}
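For fractional inputs the buffer never stores the running product itself: it keeps the sum of log-magnitudes plus a sign-parity flag, and zeros and nulls are tracked in separate flags because `log(0)` is undefined. The evaluation step above reconstructs (a restatement of the code, not part of the patch):
```
\prod_i x_i \;=\; s \cdot \exp\Big(\sum_{i:\, x_i \neq 0} \ln\lvert x_i\rvert\Big),
\qquad s = (-1)^{\#\{\, i \;:\; x_i < 0 \,\}}
```
This also explains the `positive.left === positive.right` merge rule: the combined product is positive exactly when both partial products share the same sign parity, i.e. an XNOR of the two flags.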
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index d43a9060677..70474f4d5c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -155,6 +155,10 @@ private[sql] object PythonSQLUtils extends Logging {
Column(TimestampDiff(unit, start.expr, end.expr))
}
+ def pandasProduct(e: Column, ignoreNA: Boolean): Column = {
+ Column(PandasProduct(e.expr, ignoreNA).toAggregateExpression(false))
+ }
+
def pandasStddev(e: Column, ddof: Int): Column = {
Column(PandasStddev(e.expr, ddof).toAggregateExpression(false))
}