spark git commit: [SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 d358298f1 -> 0a69477a1


[SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.

## What changes were proposed in this pull request?

We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, which differs from the original Hive implementation. This PR fixes that by adding a null check to the `Collect.update` method.

## How was this patch tested?

Added a regression test to `DataFrameAggregateSuite`.

Author: Herman van Hovell

Closes #15208 from hvanhovell/SPARK-17641.

(cherry picked from commit 7d09232028967978d9db314ec041a762599f636b)
Signed-off-by: Reynold Xin

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a69477a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a69477a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a69477a

Branch: refs/heads/branch-2.0
Commit: 0a69477a10adb3969a20ae870436299ef5152788
Parents: d358298
Author: Herman van Hovell
Authored: Wed Sep 28 16:25:10 2016 -0700
Committer: Reynold Xin
Committed: Wed Sep 28 16:25:31 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregate/collect.scala    |  7 ++++++-
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala  | 12 ++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a69477a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 896ff61..78a388d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
+    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += value
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/0a69477a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index cb505ac..3454caf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
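The regression test above doubles as a usage illustration. Here is a minimal spark-shell-style sketch of the behavior after this patch; the expected contents in the comments come straight from the test's checkAnswer expectations, while the `spark` session and the use of show() are illustrative assumptions, not part of the commit:

    import org.apache.spark.sql.functions.{collect_list, collect_set}
    import spark.implicits._  // assumes a SparkSession named `spark`, as in spark-shell

    // Column "a" contains a null; column "b" does not.
    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")

    // With the patch, the null in "a" is skipped rather than collected:
    //   collect_list(a) -> ["1", "1"]    collect_list(b) -> [2, 2, 4]
    df.select(collect_list($"a"), collect_list($"b")).show()

    // collect_set also de-duplicates, again skipping the null:
    //   collect_set(a) -> ["1"]          collect_set(b) -> [2, 4]
    df.select(collect_set($"a"), collect_set($"b")).show()

Because update never adds nulls to the aggregation buffer, merge and eval need no changes; this mirrors Hive's GenericUDAFMkCollectionEvaluator, which the new code comment cites.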
spark git commit: [SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.
Repository: spark
Updated Branches:
  refs/heads/master 557d6e322 -> 7d0923202


[SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.

## What changes were proposed in this pull request?

We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, which differs from the original Hive implementation. This PR fixes that by adding a null check to the `Collect.update` method.

## How was this patch tested?

Added a regression test to `DataFrameAggregateSuite`.

Author: Herman van Hovell

Closes #15208 from hvanhovell/SPARK-17641.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d092320
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d092320
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d092320

Branch: refs/heads/master
Commit: 7d09232028967978d9db314ec041a762599f636b
Parents: 557d6e3
Author: Herman van Hovell
Authored: Wed Sep 28 16:25:10 2016 -0700
Committer: Reynold Xin
Committed: Wed Sep 28 16:25:10 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregate/collect.scala    |  7 ++++++-
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala  | 12 ++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d092320/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 896ff61..78a388d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
+    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += value
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7d092320/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e172be..7aa4f00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),