Repository: spark Updated Branches: refs/heads/master 0874ff3aa -> 05f652d6c
[SPARK-13957][SQL] Support Group By Ordinal in SQL #### What changes were proposed in this pull request? This PR adds support for group by position (ordinal) in SQL. For example, when users input the following query ```SQL select c1 as a, c2, c3, sum(*) from tbl group by 1, 3, c4 ``` The ordinals are recognized as the positions in the select list. Thus, `Analyzer` converts it to ```SQL select c1, c2, c3, sum(*) from tbl group by c1, c3, c4 ``` This is controlled by the config option `spark.sql.groupByOrdinal`. - When true, the ordinal numbers in group by clauses are treated as the position in the select list. - When false, the ordinal numbers are ignored. - Only integer literals are converted (not foldable expressions). If foldable expressions are found, they are ignored. - When the positions specified in the group by clauses correspond to aggregate functions in the select list, an exception is thrown. - Star (`*`) is not allowed in the select list when users specify ordinals in group by. Note: This PR is taken from https://github.com/apache/spark/pull/10731. When merging this PR, please give the credit to zhichao-li. Also cc all the people who were involved in the previous discussion: rxin cloud-fan marmbrus yhuai hvanhovell adrian-wang chenghao-intel tejasapatil #### How was this patch tested? Added test cases covering both positive and negative scenarios. Author: gatorsmile <gatorsm...@gmail.com> Author: xiaoli <lixiao1...@gmail.com> Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local> Closes #11846 from gatorsmile/groupByOrdinal. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/05f652d6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/05f652d6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/05f652d6 Branch: refs/heads/master Commit: 05f652d6c2bbd764a1dd5a45301811e14519486f Parents: 0874ff3 Author: gatorsmile <gatorsm...@gmail.com> Authored: Fri Mar 25 12:55:58 2016 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Mar 25 12:55:58 2016 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/CatalystConf.scala | 8 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 72 +++++++++++---- .../spark/sql/catalyst/planning/patterns.scala | 3 +- .../org/apache/spark/sql/internal/SQLConf.scala | 6 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 92 ++++++++++++++++++-- 5 files changed, 156 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index e10ab97..d5ac015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -23,6 +23,7 @@ private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean def orderByOrdinal: Boolean + def groupByOrdinal: Boolean /** * Returns the [[Resolver]] for the current configuration, which can be used to determin if two @@ -48,11 +49,16 @@ object EmptyConf extends CatalystConf { override def orderByOrdinal: Boolean = { throw new UnsupportedOperationException } + override def 
groupByOrdinal: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. */ case class SimpleCatalystConf( caseSensitiveAnalysis: Boolean, - orderByOrdinal: Boolean = true) + orderByOrdinal: Boolean = true, + groupByOrdinal: Boolean = true) + extends CatalystConf { } http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 07b0f5e..d0a31e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -85,6 +85,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveUpCast :: + ResolveOrdinalInOrderByAndGroupBy :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -385,7 +386,13 @@ class Analyzer( p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { + failAnalysis( + "Group by position: star is not allowed to use in the select list " + + "when using ordinals in group by") + } else { + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + } // If the script transformation input contains Stars, expand it. 
case t: ScriptTransformation if containsStar(t.input) => t.copy( @@ -634,21 +641,23 @@ class Analyzer( } } - /** - * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT - * clause. This rule detects such queries and adds the required attributes to the original - * projection, so that they will be available during sorting. Another projection is added to - * remove these attributes after sorting. - * - * This rule also resolves the position number in sort references. This support is introduced - * in Spark 2.0. Before Spark 2.0, the integers in Order By has no effect on output sorting. - * - When the sort references are not integer but foldable expressions, ignore them. - * - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too. - */ - object ResolveSortReferences extends Rule[LogicalPlan] { + /** + * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by + * clauses. This rule is to convert ordinal positions to the corresponding expressions in the + * select list. This support is introduced in Spark 2.0. + * + * - When the sort references or group by expressions are not integer but foldable expressions, + * just ignore them. + * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position + * numbers too. + * + * Before the release of Spark 2.0, the literals in order/sort by and group by clauses + * have no effect on the results. + */ + object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s: Sort if !s.child.resolved => s - // Replace the index with the related attribute for ORDER BY + case p if !p.childrenResolved => p + // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. 
case s @ Sort(orders, global, child) if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => @@ -665,10 +674,41 @@ class Analyzer( } Sort(newOrders, global, child) + // Replace the index with the corresponding expression in aggregateExpressions. The index is + // a 1-base position of aggregateExpressions, which is output columns (select expression) + case a @ Aggregate(groups, aggs, child) + if conf.groupByOrdinal && aggs.forall(_.resolved) && + groups.exists(IntegerIndex.unapply(_).nonEmpty) => + val newGroups = groups.map { + case IntegerIndex(index) if index > 0 && index <= aggs.size => + aggs(index - 1) match { + case e if ResolveAggregateFunctions.containsAggregate(e) => + throw new UnresolvedException(a, + s"Group by position: the '$index'th column in the select contains an " + + s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY") + case o => o + } + case IntegerIndex(index) => + throw new UnresolvedException(a, + s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.") + case o => o + } + Aggregate(newGroups, aggs, child) + } + } + + /** + * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. Another projection is added to + * remove these attributes after sorting. + */ + object ResolveSortReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved => + case s @ Sort(order, _, child) if !s.resolved && child.resolved => try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ada8424..9c92707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -210,7 +210,8 @@ object Unions { object IntegerIndex { def unapply(a: Any): Option[Int] = a match { case Literal(a: Int, IntegerType) => Some(a) - // When resolving ordinal in Sort, negative values are extracted for issuing error messages. + // When resolving ordinal in Sort and Group By, negative values are extracted + // for issuing error messages. 
case UnaryMinus(IntegerLiteral(v)) => Some(-v) case _ => None } http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 863a876..77af0e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -445,6 +445,11 @@ object SQLConf { doc = "When true, the ordinal numbers are treated as the position in the select list. " + "When false, the ordinal numbers in order/sort By clause are ignored.") + val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal", + defaultValue = Some(true), + doc = "When true, the ordinal numbers in group by clauses are treated as the position " + + "in the select list. When false, the ordinal numbers are ignored.") + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // @@ -668,6 +673,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index eb486a1..61358fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } - test("literal in agg grouping expressions") { + test("Group By Ordinal - basic") { checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"), + sql("SELECT a, sum(b) FROM testData2 GROUP BY a")) + // duplicate group-by columns checkAnswer( sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + } + + test("Group By Ordinal - non aggregate expressions") { + checkAnswer( + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY 
a, b + 2")) + + checkAnswer( + sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + } + + test("Group By Ordinal - non-foldable constant expression") { + checkAnswer( + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"), + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) + checkAnswer( sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + } + + test("Group By Ordinal - alias") { + checkAnswer( + sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + + checkAnswer( + sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) + } + + test("Group By Ordinal - constants") { checkAnswer( sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), sql("SELECT 1, 2, sum(b) FROM testData2")) } + test("Group By Ordinal - negative cases") { + intercept[UnresolvedException[Aggregate]] { + sql("SELECT a, b FROM testData2 GROUP BY -1") + } + + intercept[UnresolvedException[Aggregate]] { + sql("SELECT a, b FROM testData2 GROUP BY 3") + } + + var e = intercept[UnresolvedException[Aggregate]]( + sql("SELECT SUM(a) FROM testData2 GROUP BY 1")) + assert(e.getMessage contains + "Invalid call to Group by position: the '1'th column in the select contains " + + "an aggregate function") + + e = intercept[UnresolvedException[Aggregate]]( + sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1")) + assert(e.getMessage contains + "Invalid call to Group by position: the '1'th column in the select contains " + + "an aggregate function") + + var ae = intercept[AnalysisException]( + sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2")) + assert(ae.getMessage contains + "nondeterministic expression rand(0) should not appear in grouping expression") + + ae = 
intercept[AnalysisException]( + sql("SELECT * FROM testData2 GROUP BY a, b, 1")) + assert(ae.getMessage contains + "Group by position: star is not allowed to use in the select list " + + "when using ordinals in group by") + } + + test("Group By Ordinal: spark.sql.groupByOrdinal=false") { + withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { + // If spark.sql.groupByOrdinal=false, ignore the position number. + intercept[AnalysisException] { + sql("SELECT a, sum(b) FROM testData2 GROUP BY 1") + } + // '*' is not allowed to use in the select list when users specify ordinals in group by + checkAnswer( + sql("SELECT * FROM testData2 GROUP BY a, b, 1"), + sql("SELECT * FROM testData2 GROUP BY a, b")) + } + } + test("aggregates with nulls") { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + @@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"), sql("SELECT * FROM testData2 ORDER BY b ASC")) - checkAnswer( sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org