Repository: spark Updated Branches: refs/heads/master 9053054c7 -> d2e44d7db
[SPARK-16192][SQL] Add type checks in CollectSet ## What changes were proposed in this pull request? `CollectSet` cannot have map-typed data because MapTypeData does not implement `equals`. So, this PR adds a type check to `CollectSet` (by overriding `checkInputDataTypes`) so that queries passing map-typed data to it fail during analysis. ## How was this patch tested? Added a test to check that analysis fails when map-typed data is passed to `CollectSet`. Author: Takeshi YAMAMURO <linguin....@gmail.com> Closes #13892 from maropu/SPARK-16192. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d2e44d7d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d2e44d7d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d2e44d7d Branch: refs/heads/master Commit: d2e44d7db82ff3c3326af7bf7ea69c803803698e Parents: 9053054 Author: Takeshi YAMAMURO <linguin....@gmail.com> Authored: Fri Jun 24 21:07:03 2016 -0700 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Fri Jun 24 21:07:03 2016 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 4 ++-- .../sql/catalyst/expressions/aggregate/collect.scala | 9 +++++++++ .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d2e44d7d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 8992276..ac9693e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -73,9 
+73,9 @@ trait CheckAnalysis extends PredicateHelper { s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") case g: Grouping => - failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup") + failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") case g: GroupingID => - failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") + failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") http://git-wip-us.apache.org/repos/asf/spark/blob/d2e44d7d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 1f4ff9c..ac2cefa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import scala.collection.generic.Growable import scala.collection.mutable +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -107,6 +108,14 @@ case class CollectSet( def this(child: Expression) = this(child, 0, 0) + override def checkInputDataTypes(): TypeCheckResult = { + if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data") 
+ } + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) http://git-wip-us.apache.org/repos/asf/spark/blob/d2e44d7d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 69a9907..92aa7b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -457,6 +457,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("collect_set functions cannot have maps") { + val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1)) + .toDF("a", "x", "y") + .select($"a", map($"x", $"y").as("b")) + val error = intercept[AnalysisException] { + df.select(collect_set($"a"), collect_set($"b")) + } + assert(error.message.contains("collect_set() cannot have map type data")) + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org