Github user attilapiros commented on a diff in the pull request:
https://github.com/apache/spark/pull/20046#discussion_r158456767
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
---
@@ -154,6 +154,217 @@ class DataFrameWindowFunctionsSuite extends QueryTest
with SharedSQLContext {
Row(2.0d), Row(2.0d)))
}
+ test("corr, covar_pop, stddev_pop functions in specific window") {
+ val df = Seq(
+ ("a", "p1", 10.0, 20.0),
+ ("b", "p1", 20.0, 10.0),
+ ("c", "p2", 20.0, 20.0),
+ ("d", "p2", 20.0, 20.0),
+ ("e", "p3", 0.0, 0.0),
+ ("f", "p3", 6.0, 12.0),
+ ("g", "p3", 6.0, 12.0),
+ ("h", "p3", 8.0, 16.0),
+ ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
+ checkAnswer(
+ df.select(
+ $"key",
+ corr("value1", "value2").over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ covar_pop("value1", "value2")
+ .over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ var_pop("value1")
+ .over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ stddev_pop("value1")
+ .over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ var_pop("value2")
+ .over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ stddev_pop("value2")
+ .over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing))),
+
+ // As stddev_pop(expr) = sqrt(var_pop(expr))
+ // the "stddev_pop" column can be calculated from the "var_pop" column.
+ //
+ // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2))
+ // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns.
+ Seq(
+ Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
+ Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
+ Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0),
+ Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0),
+ Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+ Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+ Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+ Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+ Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0)))
+ }
+
+ test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in
specific window") {
+ val df = Seq(
+ ("a", "p1", 10.0, 20.0),
+ ("b", "p1", 20.0, 10.0),
+ ("c", "p2", 20.0, 20.0),
+ ("d", "p2", 20.0, 20.0),
+ ("e", "p3", 0.0, 0.0),
+ ("f", "p3", 6.0, 12.0),
+ ("g", "p3", 6.0, 12.0),
+ ("h", "p3", 8.0, 16.0),
+ ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
+ checkAnswer(
+ df.select(
+ $"key",
+ covar_samp("value1",
"value2").over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ var_samp("value1").over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ variance("value1").over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ stddev_samp("value1").over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+ stddev("value1").over(Window.partitionBy("partitionId")
+ .orderBy("key").rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing))
+ ),
+ Seq(
+ Row("a", -50.0, 50.0, 50.0, 7.0710678118654755,
7.0710678118654755),
+ Row("b", -50.0, 50.0, 50.0, 7.0710678118654755,
7.0710678118654755),
+ Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
+ Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
+ Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544
),
+ Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544
),
+ Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544
),
+ Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544
),
+ Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN,
Double.NaN)))
+ }
+
+ test("collect_list in ascending ordered window") {
+ val df = Seq(
+ ("a", "p1", "1"),
+ ("b", "p1", "2"),
+ ("c", "p1", "2"),
+ ("d", "p1", null),
+ ("e", "p1", "3"),
+ ("f", "p2", "10"),
+ ("g", "p2", "11"),
+ ("h", "p3", "20"),
+ ("i", "p4", null)).toDF("key", "partition", "value")
+ checkAnswer(
+ df.select(
+ $"key",
+
collect_list("value").over(Window.partitionBy($"partition").orderBy($"value")
+ .rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing))),
+ Seq(
+ Row("a", Array("1", "2", "2", "3")),
+ Row("b", Array("1", "2", "2", "3")),
+ Row("c", Array("1", "2", "2", "3")),
+ Row("d", Array("1", "2", "2", "3")),
+ Row("e", Array("1", "2", "2", "3")),
+ Row("f", Array("10", "11")),
+ Row("g", Array("10", "11")),
+ Row("h", Array("20")),
+ Row("i", Array())))
+ }
+
+ test("collect_list in descending ordered window") {
+ val df = Seq(
+ ("a", "p1", "1"),
+ ("b", "p1", "2"),
+ ("c", "p1", "2"),
+ ("d", "p1", null),
+ ("e", "p1", "3"),
+ ("f", "p2", "10"),
+ ("g", "p2", "11"),
+ ("h", "p3", "20"),
+ ("i", "p4", null)).toDF("key", "partition", "value")
+ checkAnswer(
+ df.select(
+ $"key",
+
collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc)
+ .rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing))),
+ Seq(
+ Row("a", Array("3", "2", "2", "1")),
+ Row("b", Array("3", "2", "2", "1")),
+ Row("c", Array("3", "2", "2", "1")),
+ Row("d", Array("3", "2", "2", "1")),
+ Row("e", Array("3", "2", "2", "1")),
+ Row("f", Array("11", "10")),
+ Row("g", Array("11", "10")),
+ Row("h", Array("20")),
+ Row("i", Array())))
+ }
+
+ test("collect_set in window") {
+ val df = Seq(
+ ("a", "p1", 1),
+ ("b", "p1", 2),
+ ("c", "p1", 2),
+ ("d", "p1", 3),
+ ("e", "p1", 3),
+ ("f", "p2", 10),
+ ("g", "p2", 11),
+ ("h", "p3", 20)).toDF("key", "partition", "value")
+ checkAnswer(
+ df.select(
+ $"key",
+
collect_set("value").over(Window.partitionBy($"partition").orderBy($"value")
+ .rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing))),
+ Seq(
+ Row("a", Array(1, 2, 3)),
+ Row("b", Array(1, 2, 3)),
+ Row("c", Array(1, 2, 3)),
+ Row("d", Array(1, 2, 3)),
+ Row("e", Array(1, 2, 3)),
+ Row("f", Array(10, 11)),
+ Row("g", Array(10, 11)),
+ Row("h", Array(20))))
+ }
+
+ test("skewness and kurtosis functions in window") {
+ val df = Seq(
+ ("a", "p1", 1.0),
+ ("b", "p1", 1.0),
+ ("c", "p1", 2.0),
+ ("d", "p1", 2.0),
+ ("e", "p1", 3.0),
+ ("f", "p1", 3.0),
+ ("g", "p1", 3.0),
+ ("h", "p2", 1.0),
+ ("i", "p2", 2.0),
+ ("j", "p2", 5.0)).toDF("key", "partition", "value")
+ checkAnswer(
+ df.select(
+ $"key",
+
skewness("value").over(Window.partitionBy("partition").orderBy($"key")
+ .rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)),
+
kurtosis("value").over(Window.partitionBy("partition").orderBy($"key")
+ .rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing))),
+ // results are checked by scipy.stats.skew() and scipy.stats.kurtosis()
+ Seq(
+ Row("a", -0.27238010581457267, -1.506920415224914),
+ Row("b", -0.27238010581457267, -1.506920415224914),
+ Row("c", -0.27238010581457267, -1.506920415224914),
+ Row("d", -0.27238010581457267, -1.506920415224914),
+ Row("e", -0.27238010581457267, -1.506920415224914),
+ Row("f", -0.27238010581457267, -1.506920415224914),
+ Row("g", -0.27238010581457267, -1.506920415224914),
+ Row("h", 0.5280049792181881, -1.5000000000000013),
+ Row("i", 0.5280049792181881, -1.5000000000000013),
+ Row("j", 0.5280049792181881, -1.5000000000000013)))
+ }
+
+ test("aggregation function on invalid column") {
+ val df = Seq((1, "1")).toDF("key", "value")
+ val e = intercept[AnalysisException](
+ df.select(
+ $"key",
+ count("invalid").over(
--- End diff --
Thanks, I will remove the unnecessary parts.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]