Repository: spark
Updated Branches:
  refs/heads/master a471880af -> 9ea8d3d31


[SPARK-22362][SQL] Add unit test for Window Aggregate Functions

## What changes were proposed in this pull request?

Improves the test coverage of window functions, focusing on the missing tests for
window aggregate functions. No new UDAF test is added, as UDAFs are already
covered by existing tests.
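
For context, the API these tests exercise evaluates an aggregate function once
per input row, over a window frame within an ordered partition. A minimal sketch
of the pattern (illustrative only, not part of the patch; the data and column
names are made up, and a running SparkSession named `spark` is assumed, e.g. in
spark-shell):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import spark.implicits._

    // Toy data: (key, partition, value); any similar shape works.
    val df = Seq(("a", "p1", 1.0), ("b", "p1", 2.0), ("c", "p2", 3.0))
      .toDF("key", "partition", "value")

    // A frame spanning the whole partition, as used throughout the new tests.
    val w = Window.partitionBy("partition").orderBy("key")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

    // Each aggregate is computed per row over that frame.
    df.select($"key", stddev_pop("value").over(w), collect_list("value").over(w)).show()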

## How was this patch tested?

Only new tests were added; the automated tests were executed.

Author: "attilapiros" <[email protected]>
Author: Attila Zsolt Piros <[email protected]>

Closes #20046 from attilapiros/SPARK-22362.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9ea8d3d3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9ea8d3d3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9ea8d3d3

Branch: refs/heads/master
Commit: 9ea8d3d31b75246bf61118ac7934bc92c18b5f19
Parents: a471880
Author: "attilapiros" <[email protected]>
Authored: Thu Apr 19 18:55:59 2018 +0200
Committer: Herman van Hovell <[email protected]>
Committed: Thu Apr 19 18:55:59 2018 +0200

----------------------------------------------------------------------
 .../test/resources/sql-tests/inputs/window.sql  |  10 +-
 .../resources/sql-tests/results/window.sql.out  |  30 ++-
 .../sql/DataFrameWindowFunctionsSuite.scala     | 266 +++++++++++++++++++
 3 files changed, 294 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9ea8d3d3/sql/core/src/test/resources/sql-tests/inputs/window.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index c4bea34..cda4db4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile,
 row_number() OVER w AS row_number,
 var_pop(val) OVER w AS var_pop,
 var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
 FROM testData
 WINDOW w AS (PARTITION BY cate ORDER BY val)
 ORDER BY cate, val;
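
For reference, window.sql.out below is a generated golden file. Assuming the
SQLQueryTestSuite workflow in use at the time, it can be regenerated with:

    SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/test-only *SQLQueryTestSuite"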

http://git-wip-us.apache.org/repos/asf/spark/blob/9ea8d3d3/sql/core/src/test/resources/sql-tests/results/window.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out
index 133458a..4afbcd6 100644
--- a/sql/core/src/test/resources/sql-tests/results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out
@@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile,
 row_number() OVER w AS row_number,
 var_pop(val) OVER w AS var_pop,
 var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
 FROM testData
 WINDOW w AS (PARTITION BY cate ORDER BY val)
 ORDER BY cate, val
 -- !query 17 schema
-struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint>
+struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint,covar_pop:double,corr:double,stddev_samp:double,stddev_pop:double,collect_list:array<int>,collect_set:array<int>,skewness:double,kurtosis:double>
 -- !query 17 output
-NULL	NULL	NULL	NULL	NULL	0	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	1	1	0.5	0.0	1	1	NULL	NULL	0
-3	NULL	3	3	3	1	3	3.0	NaN	NULL	3	NULL	3	3	3	2	2	1.0	1.0	2	2	0.0	NaN	1
-NULL	a	NULL	NULL	NULL	0	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	1	1	0.25	0.0	1	1	NULL	NULL	0
-1	a	1	1	1	2	2	1.0	0.0	NULL	1	NULL	1	1	1	2	2	0.75	0.3333333333333333	1	2	0.0	0.0	1
-1	a	1	1	1	2	2	1.0	0.0	NULL	1	NULL	1	1	1	2	2	0.75	0.3333333333333333	2	3	0.0	0.0	1
-2	a	2	1	1	3	4	1.3333333333333333	0.5773502691896258	NULL	1	NULL	2	2	2	4	3	1.0	1.0	2	4	0.22222222222222224	0.33333333333333337	2
-1	b	1	1	1	1	1	1.0	NaN	1	1	1	1	1	1	1	1	0.3333333333333333	0.0	1	1	0.0	NaN	1
-2	b	2	1	1	2	3	1.5	0.7071067811865476	1	1	1	2	2	2	2	2	0.6666666666666666	0.5	1	2	0.25	0.5	2
-3	b	3	1	1	3	6	2.0	1.0	1	1	1	3	3	3	3	3	1.0	1.0	2	3	0.6666666666666666	1.0	3
+NULL	NULL	NULL	NULL	NULL	0	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	1	1	0.5	0.0	1	1	NULL	NULL	0	NULL	NULL	NULL	NULL	[]	[]	NULL	NULL
+3	NULL	3	3	3	1	3	3.0	NaN	NULL	3	NULL	3	3	3	2	2	1.0	1.0	2	2	0.0	NaN	1	0.0	NaN	NaN	0.0	[3]	[3]	NaN	NaN
+NULL	a	NULL	NULL	NULL	0	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	1	1	0.25	0.0	1	1	NULL	NULL	0	NULL	NULL	NULL	NULL	[]	[]	NaN	NaN
+1	a	1	1	1	2	2	1.0	0.0	NULL	1	NULL	1	1	1	2	2	0.75	0.3333333333333333	1	2	0.0	0.0	1	0.0	NULL	0.0	0.0	[1,1]	[1]	0.7071067811865476	-1.5
+1	a	1	1	1	2	2	1.0	0.0	NULL	1	NULL	1	1	1	2	2	0.75	0.3333333333333333	2	3	0.0	0.0	1	0.0	NULL	0.0	0.0	[1,1]	[1]	0.7071067811865476	-1.5
+2	a	2	1	1	3	4	1.3333333333333333	0.5773502691896258	NULL	1	NULL	2	2	2	4	3	1.0	1.0	2	4	0.22222222222222224	0.33333333333333337	2	4.772185885555555E8	1.0	0.5773502691896258	0.4714045207910317	[1,1,2]	[1,2]	1.1539890888012805	-0.6672217220327235
+1	b	1	1	1	1	1	1.0	NaN	1	1	1	1	1	1	1	1	0.3333333333333333	0.0	1	1	0.0	NaN	1	NULL	NULL	NaN	0.0	[1]	[1]	NaN	NaN
+2	b	2	1	1	2	3	1.5	0.7071067811865476	1	1	1	2	2	2	2	2	0.6666666666666666	0.5	1	2	0.25	0.5	2	0.0	NaN	0.7071067811865476	0.5	[1,2]	[1,2]	0.0	-2.0000000000000013
+3	b	3	1	1	3	6	2.0	1.0	1	1	1	3	3	3	3	3	1.0	1.0	2	3	0.6666666666666666	1.0	3	5.3687091175E8	1.0	1.0	0.816496580927726	[1,2,3]	[1,2,3]	0.7057890433107311	-1.4999999999999984
 
 
 -- !query 18

http://git-wip-us.apache.org/repos/asf/spark/blob/9ea8d3d3/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 2811478..3ea398a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
 
+import scala.collection.mutable
+
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
@@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("requires window to be ordered"))
   }
 
+  test("corr, covar_pop, stddev_pop functions in specific window") {
+    val df = Seq(
+      ("a", "p1", 10.0, 20.0),
+      ("b", "p1", 20.0, 10.0),
+      ("c", "p2", 20.0, 20.0),
+      ("d", "p2", 20.0, 20.0),
+      ("e", "p3", 0.0, 0.0),
+      ("f", "p3", 6.0, 12.0),
+      ("g", "p3", 6.0, 12.0),
+      ("h", "p3", 8.0, 16.0),
+      ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
+    checkAnswer(
+      df.select(
+        $"key",
+        corr("value1", "value2").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        covar_pop("value1", "value2")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        var_pop("value1")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev_pop("value1")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        var_pop("value2")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev_pop("value2")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
+
+      // As stddev_pop(expr) = sqrt(var_pop(expr))
+      // the "stddev_pop" column can be calculated from the "var_pop" column.
+      //
+      // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2))
+      // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns.
+      Seq(
+        Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
+        Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
+        Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0),
+        Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0),
+        Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0)))
+  }
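
The identities in the comment above can be hand-checked on the p3 partition,
where value1 = 0, 6, 6, 8 and value2 = 0, 12, 12, 16. A small self-contained
Scala sketch (not part of the patch, needs no Spark):

    val v1 = Seq(0.0, 6.0, 6.0, 8.0)     // value1 for rows e, f, g, h
    val v2 = Seq(0.0, 12.0, 12.0, 16.0)  // value2 for rows e, f, g, h

    def mean(xs: Seq[Double]): Double = xs.sum / xs.size
    def varPop(xs: Seq[Double]): Double = mean(xs.map(x => math.pow(x - mean(xs), 2)))

    // Population covariance: the mean of the products of the deviations.
    val covarPop = v1.zip(v2).map { case (a, b) => (a - mean(v1)) * (b - mean(v2)) }.sum / v1.size
    val corr = covarPop / (math.sqrt(varPop(v1)) * math.sqrt(varPop(v2)))

This gives varPop(v1) = 9.0 (stddev_pop 3.0), varPop(v2) = 36.0 (stddev_pop 6.0)
and covarPop = 18.0, so corr = 18.0 / (3.0 * 6.0) = 1.0, matching
Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0) above.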
+
+  test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") {
+    val df = Seq(
+      ("a", "p1", 10.0, 20.0),
+      ("b", "p1", 20.0, 10.0),
+      ("c", "p2", 20.0, 20.0),
+      ("d", "p2", 20.0, 20.0),
+      ("e", "p3", 0.0, 0.0),
+      ("f", "p3", 6.0, 12.0),
+      ("g", "p3", 6.0, 12.0),
+      ("h", "p3", 8.0, 16.0),
+      ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
+    checkAnswer(
+      df.select(
+        $"key",
+        covar_samp("value1", "value2").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        var_samp("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        variance("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev_samp("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
+      ),
+      Seq(
+        Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
+        Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
+        Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
+        Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
+        Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
+        Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
+        Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
+        Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
+        Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
+  }
+
+  test("collect_list in ascending ordered window") {
+    val df = Seq(
+      ("a", "p1", "1"),
+      ("b", "p1", "2"),
+      ("c", "p1", "2"),
+      ("d", "p1", null),
+      ("e", "p1", "3"),
+      ("f", "p2", "10"),
+      ("g", "p2", "11"),
+      ("h", "p3", "20"),
+      ("i", "p4", null)).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        sort_array(
+          collect_list("value").over(Window.partitionBy($"partition").orderBy($"value")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
+      Seq(
+        Row("a", Array("1", "2", "2", "3")),
+        Row("b", Array("1", "2", "2", "3")),
+        Row("c", Array("1", "2", "2", "3")),
+        Row("d", Array("1", "2", "2", "3")),
+        Row("e", Array("1", "2", "2", "3")),
+        Row("f", Array("10", "11")),
+        Row("g", Array("10", "11")),
+        Row("h", Array("20")),
+        Row("i", Array())))
+  }
+
+  test("collect_list in descending ordered window") {
+    val df = Seq(
+      ("a", "p1", "1"),
+      ("b", "p1", "2"),
+      ("c", "p1", "2"),
+      ("d", "p1", null),
+      ("e", "p1", "3"),
+      ("f", "p2", "10"),
+      ("g", "p2", "11"),
+      ("h", "p3", "20"),
+      ("i", "p4", null)).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        sort_array(
+          collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc)
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
+      Seq(
+        Row("a", Array("1", "2", "2", "3")),
+        Row("b", Array("1", "2", "2", "3")),
+        Row("c", Array("1", "2", "2", "3")),
+        Row("d", Array("1", "2", "2", "3")),
+        Row("e", Array("1", "2", "2", "3")),
+        Row("f", Array("10", "11")),
+        Row("g", Array("10", "11")),
+        Row("h", Array("20")),
+        Row("i", Array())))
+  }
+
+  test("collect_set in window") {
+    val df = Seq(
+      ("a", "p1", "1"),
+      ("b", "p1", "2"),
+      ("c", "p1", "2"),
+      ("d", "p1", "3"),
+      ("e", "p1", "3"),
+      ("f", "p2", "10"),
+      ("g", "p2", "11"),
+      ("h", "p3", "20")).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        sort_array(
+          collect_set("value").over(Window.partitionBy($"partition").orderBy($"value")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
+      Seq(
+        Row("a", Array("1", "2", "3")),
+        Row("b", Array("1", "2", "3")),
+        Row("c", Array("1", "2", "3")),
+        Row("d", Array("1", "2", "3")),
+        Row("e", Array("1", "2", "3")),
+        Row("f", Array("10", "11")),
+        Row("g", Array("10", "11")),
+        Row("h", Array("20"))))
+  }
+
+  test("skewness and kurtosis functions in window") {
+    val df = Seq(
+      ("a", "p1", 1.0),
+      ("b", "p1", 1.0),
+      ("c", "p1", 2.0),
+      ("d", "p1", 2.0),
+      ("e", "p1", 3.0),
+      ("f", "p1", 3.0),
+      ("g", "p1", 3.0),
+      ("h", "p2", 1.0),
+      ("i", "p2", 2.0),
+      ("j", "p2", 5.0)).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        skewness("value").over(Window.partitionBy("partition").orderBy($"key")
+          .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        kurtosis("value").over(Window.partitionBy("partition").orderBy($"key")
+          .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
+      // results are checked by scipy.stats.skew() and scipy.stats.kurtosis()
+      Seq(
+        Row("a", -0.27238010581457267, -1.506920415224914),
+        Row("b", -0.27238010581457267, -1.506920415224914),
+        Row("c", -0.27238010581457267, -1.506920415224914),
+        Row("d", -0.27238010581457267, -1.506920415224914),
+        Row("e", -0.27238010581457267, -1.506920415224914),
+        Row("f", -0.27238010581457267, -1.506920415224914),
+        Row("g", -0.27238010581457267, -1.506920415224914),
+        Row("h", 0.5280049792181881, -1.5000000000000013),
+        Row("i", 0.5280049792181881, -1.5000000000000013),
+        Row("j", 0.5280049792181881, -1.5000000000000013)))
+  }
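
The scipy-checked constants above can be reproduced by hand: for the p2
partition (values 1, 2, 5), skewness is the population moment ratio
m3 / m2^(3/2), and the expected values show Spark's kurtosis is excess
kurtosis, m4 / m2^2 - 3. A small self-contained Scala sketch (not part of the
patch, needs no Spark):

    val xs = Seq(1.0, 2.0, 5.0)
    val m = xs.sum / xs.size  // mean = 8/3
    def moment(k: Int): Double = xs.map(x => math.pow(x - m, k)).sum / xs.size

    val skew = moment(3) / math.pow(moment(2), 1.5)      // ~= 0.5280049792181881
    val kurt = moment(4) / math.pow(moment(2), 2) - 3.0  // ~= -1.5 (excess kurtosis)

Both agree with the expected p2 rows to floating-point precision.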
+
+  test("aggregation function on invalid column") {
+    val df = Seq((1, "1")).toDF("key", "value")
+    val e = intercept[AnalysisException](
+      df.select($"key", count("invalid").over()))
+    assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]"))
+  }
+
+  test("numerical aggregate functions on string column") {
+    val df = Seq((1, "a", "b")).toDF("key", "value1", "value2")
+    checkAnswer(
+      df.select($"key",
+        var_pop("value1").over(),
+        variance("value1").over(),
+        stddev_pop("value1").over(),
+        stddev("value1").over(),
+        sum("value1").over(),
+        mean("value1").over(),
+        avg("value1").over(),
+        corr("value1", "value2").over(),
+        covar_pop("value1", "value2").over(),
+        covar_samp("value1", "value2").over(),
+        skewness("value1").over(),
+        kurtosis("value1").over()),
+      Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null)))
+  }
+
   test("statistical functions") {
     val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
       toDF("key", "value")
@@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
         Row("b", 2, null, null, null, null, null, null)))
   }
 
+  test("last/first on descending ordered window") {
+    val nullStr: String = null
+    val df = Seq(
+      ("a", 0, nullStr),
+      ("a", 1, "x"),
+      ("a", 2, "y"),
+      ("a", 3, "z"),
+      ("a", 4, "v"),
+      ("b", 1, "k"),
+      ("b", 2, "l"),
+      ("b", 3, nullStr)).
+      toDF("key", "order", "value")
+    val window = Window.partitionBy($"key").orderBy($"order".desc)
+    checkAnswer(
+      df.select(
+        $"key",
+        $"order",
+        first($"value").over(window),
+        first($"value", ignoreNulls = false).over(window),
+        first($"value", ignoreNulls = true).over(window),
+        last($"value").over(window),
+        last($"value", ignoreNulls = false).over(window),
+        last($"value", ignoreNulls = true).over(window)),
+      Seq(
+        Row("a", 0, "v", "v", "v", null, null, "x"),
+        Row("a", 1, "v", "v", "v", "x", "x", "x"),
+        Row("a", 2, "v", "v", "v", "y", "y", "y"),
+        Row("a", 3, "v", "v", "v", "z", "z", "z"),
+        Row("a", 4, "v", "v", "v", "v", "v", "v"),
+        Row("b", 1, null, null, "l", "k", "k", "k"),
+        Row("b", 2, null, null, "l", "l", "l", "l"),
+        Row("b", 3, null, null, null, null, null, null)))
+  }
+
   test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
     val src = Seq((0, 3, 5)).toDF("a", "b", "c")
       .withColumn("Data", struct("a", "b"))

