Repository: spark
Updated Branches:
  refs/heads/master 7c5b64180 -> 09ad9533d


[SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function

Return Double.NaN for mean/average when count == 0 for all numeric types that
are converted to Double; the Decimal type continues to return null.
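
As a quick illustration of the new semantics (a sketch, assuming a SQLContext
named sqlContext with its implicits and org.apache.spark.sql.functions._ in
scope, as in the test suites below):

    import sqlContext.implicits._
    import org.apache.spark.sql.functions._

    // count == 0: the statistic is now null rather than NaN
    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
    empty.agg(stddev('a), kurtosis('a)).collect()          // Row(null, null)

    // count == 1: sample variants stay NaN (they divide by n - 1 == 0),
    // population variants return 0.0
    val single = Seq((1, 2)).toDF("a", "b")
    single.agg(stddev_samp('a), stddev_pop('a)).collect()  // Row(NaN, 0.0)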

Author: JihongMa <linlin200...@gmail.com>

Closes #9705 from JihongMA/SPARK-11720.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/09ad9533
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/09ad9533
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/09ad9533

Branch: refs/heads/master
Commit: 09ad9533d5760652de59fa4830c24cb8667958ac
Parents: 7c5b641
Author: JihongMa <linlin200...@gmail.com>
Authored: Wed Nov 18 13:03:37 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Nov 18 13:03:37 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                   |  2 +-
 .../expressions/aggregate/CentralMomentAgg.scala  |  2 +-
 .../catalyst/expressions/aggregate/Kurtosis.scala |  9 +++++----
 .../catalyst/expressions/aggregate/Skewness.scala |  9 +++++----
 .../catalyst/expressions/aggregate/Stddev.scala   | 18 ++++++++++++++----
 .../catalyst/expressions/aggregate/Variance.scala | 18 ++++++++++++++----
 .../spark/sql/DataFrameAggregateSuite.scala       | 18 ++++++++++++------
 .../org/apache/spark/sql/DataFrameSuite.scala     |  2 +-
 8 files changed, 53 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ad6ad02..0dd75ba 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -761,7 +761,7 @@ class DataFrame(object):
         +-------+------------------+-----+
         |  count|                 2|    2|
         |   mean|               3.5| null|
-        | stddev|2.1213203435596424|  NaN|
+        | stddev|2.1213203435596424| null|
         |    min|                 2|Alice|
         |    max|                 5|  Bob|
         +-------+------------------+-----+

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index de5872a..d07d4c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
    * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
    *                       needed to compute the aggregate stat.
    */
-  def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double
+  def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any
 
   override final def eval(buffer: InternalRow): Any = {
     val n = buffer.getDouble(nOffset)
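
(Aside, not part of the patch: the signature widens from Double to Any because
Scala's Double is an unboxed value type and cannot hold null; returning Any
lets an implementation yield null on empty input while still producing a boxed
Double on the normal path. A minimal standalone sketch of the pattern, with a
hypothetical helper name:)

    // stat is a made-up name; it mirrors the branch structure used below.
    def stat(n: Double, m2: Double): Any =
      if (n == 0.0) null                      // no rows: SQL-style null
      else if (n == 1.0) Double.NaN           // sample statistic undefined
      else math.sqrt(m2 / (n - 1.0))          // normal path: boxed Double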

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index 8fa3aac..c2bf2cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -37,16 +37,17 @@ case class Kurtosis(child: Expression,
   override protected val momentOrder = 4
 
   // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
     require(moments.length == momentOrder + 1,
       s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
     val m2 = moments(2)
     val m4 = moments(4)
 
-    if (n == 0.0 || m2 == 0.0) {
+    if (n == 0.0) {
+      null
+    } else if (m2 == 0.0) {
       Double.NaN
-    }
-    else {
+    } else {
       n * m4 / (m2 * m2) - 3.0
     }
   }
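
For reference, the moments array holds un-normalized central moments
m_k = sum((x_i - mean)^k), so the final branch above computes the excess
kurtosis n * m4 / m2^2 - 3. A quick hand check in plain Scala (illustrative
only, not part of the patch):

    // x = [0, 0, 1, 1]: a symmetric two-point sample
    val x = Seq(0.0, 0.0, 1.0, 1.0)
    val n = x.length.toDouble
    val mean = x.sum / n                              // 0.5
    val m2 = x.map(v => math.pow(v - mean, 2)).sum    // 1.0
    val m4 = x.map(v => math.pow(v - mean, 4)).sum    // 0.25
    n * m4 / (m2 * m2) - 3.0                          // -2.0, the two-point minimum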

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index e1c01a5..9411bce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -36,16 +36,17 @@ case class Skewness(child: Expression,
 
   override protected val momentOrder = 3
 
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
     require(moments.length == momentOrder + 1,
       s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
     val m2 = moments(2)
     val m3 = moments(3)
 
-    if (n == 0.0 || m2 == 0.0) {
+    if (n == 0.0) {
+      null
+    } else if (m2 == 0.0) {
       Double.NaN
-    }
-    else {
+    } else {
       math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
     }
   }
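
The skewness statistic is sqrt(n) * m3 / m2^(3/2), so a constant column
(m2 == 0) still yields NaN while an empty one now yields null. A matching
hand check (illustrative only):

    // x = [0, 0, 0, 1]: a right-skewed sample
    val x = Seq(0.0, 0.0, 0.0, 1.0)
    val n = x.length.toDouble
    val mean = x.sum / n                              // 0.25
    val m2 = x.map(v => math.pow(v - mean, 2)).sum    // 0.75
    val m3 = x.map(v => math.pow(v - mean, 3)).sum    // 0.375
    math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)       // ~1.1547, positive skew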

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 05dd5e3..eec79a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -36,11 +36,17 @@ case class StddevSamp(child: Expression,
 
   override protected val momentOrder = 2
 
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
     require(moments.length == momentOrder + 1,
       s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
 
-    if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
+    if (n == 0.0) {
+      null
+    } else if (n == 1.0) {
+      Double.NaN
+    } else {
+      math.sqrt(moments(2) / (n - 1.0))
+    }
   }
 }
 
@@ -62,10 +68,14 @@ case class StddevPop(
 
   override protected val momentOrder = 2
 
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
     require(moments.length == momentOrder + 1,
       s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
 
-    if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
+    if (n == 0.0) {
+      null
+    } else {
+      math.sqrt(moments(2) / n)
+    }
   }
 }
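
Note the asymmetry the "zero moments" test below exercises: at n == 1,
StddevSamp divides by n - 1 and returns NaN, while StddevPop divides by n and
returns 0.0; an empty input is null for both. Sketch of the single-row case
(illustrative only):

    // A single row: the one value equals the mean, so m2 == 0.0
    val (n, m2) = (1.0, 0.0)
    val samp: Any =
      if (n == 0.0) null
      else if (n == 1.0) Double.NaN        // 0.0 / 0.0 avoided explicitly
      else math.sqrt(m2 / (n - 1.0))
    val pop: Any =
      if (n == 0.0) null
      else math.sqrt(m2 / n)               // 0.0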

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
index ede2da2..cf3a740 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
@@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression,
 
   override protected val momentOrder = 2
 
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
     require(moments.length == momentOrder + 1,
       s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
 
-    if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
+    if (n == 0.0) {
+      null
+    } else if (n == 1.0) {
+      Double.NaN
+    } else {
+      moments(2) / (n - 1.0)
+    }
   }
 }
 
@@ -62,10 +68,14 @@ case class VariancePop(
 
   override protected val momentOrder = 2
 
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
     require(moments.length == momentOrder + 1,
       s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
 
-    if (n == 0.0) Double.NaN else moments(2) / n
+    if (n == 0.0) {
+      null
+    } else {
+      moments(2) / n
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 432e8d1..71adf21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
     emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
-    Row(Double.NaN, Double.NaN, Double.NaN))
+    Row(null, null, null))
   }
 
   test("zero sum") {
@@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   test("zero moments") {
     val input = Seq((1, 2)).toDF("a", "b")
     checkAnswer(
-      input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
-      Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))
+      input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a),
+        var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
+      Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0,
+        Double.NaN, Double.NaN))
 
     checkAnswer(
       input.agg(
+        expr("stddev(a)"),
+        expr("stddev_samp(a)"),
+        expr("stddev_pop(a)"),
         expr("variance(a)"),
         expr("var_samp(a)"),
         expr("var_pop(a)"),
         expr("skewness(a)"),
         expr("kurtosis(a)")),
-      Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))
+      Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0,
+        Double.NaN, Double.NaN))
   }
 
   test("null moments") {
@@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
     checkAnswer(
       emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
-      Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
+      Row(null, null, null, null, null))
 
     checkAnswer(
       emptyTableData.agg(
@@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         expr("var_pop(a)"),
         expr("skewness(a)"),
         expr("kurtosis(a)")),
-      Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
+      Row(null, null, null, null, null))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/09ad9533/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5a7f246..6399b01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val emptyDescribeResult = Seq(
       Row("count", "0", "0"),
       Row("mean", null, null),
-      Row("stddev", "NaN", "NaN"),
+      Row("stddev", null, null),
       Row("min", null, null),
       Row("max", null, null))
 

