This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 290502d1c9f8 [SPARK-51820][SQL] Move `UnresolvedOrdinal` construction before analysis to avoid issue with group by ordinal
290502d1c9f8 is described below

commit 290502d1c9f8d5a9ce6594fc7aced31cb672ce26
Author: Mihailo Timotic <mihailo.timo...@databricks.com>
AuthorDate: Fri Apr 18 08:34:16 2025 +0800

    [SPARK-51820][SQL] Move `UnresolvedOrdinal` construction before analysis to avoid issue with group by ordinal

    ### What changes were proposed in this pull request?

    This is a follow-up to https://github.com/apache/spark/pull/43797 and https://github.com/apache/spark/pull/50461, where hacks were introduced to solve the issue of repeated analysis on plans that contain a group-by ordinal. The latter PR caused a regression, so the hack needs to be removed. This PR instead moves `UnresolvedOrdinal` construction to before the Analyzer runs.

    ### Why are the changes needed?

    We revert the hacks introduced in the previous PRs to improve the behavior of group-by ordinals, and additionally fix the issue that https://github.com/apache/spark/pull/50461 was trying to solve.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50606 from mihailotim-db/mihailotim-db/new_group_by_ordinal.

    Authored-by: Mihailo Timotic <mihailo.timo...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 18 +---
 .../analysis/ResolveReferencesInAggregate.scala    | 26 +-----
 .../analysis/SubstituteUnresolvedOrdinals.scala    | 64 --------------
 .../apache/spark/sql/catalyst/dsl/package.scala    | 10 ++-
 .../spark/sql/catalyst/parser/AstBuilder.scala     | 97 ++++++++++++++++++----
 ... => GroupByOrdinalsRepeatedAnalysisSuite.scala} | 65 +++++----------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 66 ++++++++++++---
 .../org/apache/spark/sql/classic/Dataset.scala     | 13 ++-
 .../sql/classic/RelationalGroupedDataset.scala     | 12 ++-
 .../analyzer-results/group-by-all.sql.out          |  6 +-
 .../analyzer-results/group-by-ordinal.sql.out      | 42 +++++-----
 .../postgreSQL/select_implicit.sql.out             |  4 +-
 .../udf/postgreSQL/udf-select_implicit.sql.out     |  4 +-
 .../sql-tests/results/group-by-ordinal.sql.out     | 40 ++++-----
 .../results/postgreSQL/select_implicit.sql.out     |  4 +-
 .../results/udaf/udaf-group-by-ordinal.sql.out     | 28 +++----
 .../udf/postgreSQL/udf-select_implicit.sql.out     |  4 +-
 .../analysis/resolver/AggregateResolverSuite.scala | 65 ++++++-------
 .../sql/analysis/resolver/ResolverGuardSuite.scala | 20 +++--
 19 files changed, 304 insertions(+), 284 deletions(-)
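Before reading the hunks, the shape of the change is easiest to see in isolation. The sketch below mirrors the `replaceIntegerLiteralWithOrdinal` helper this PR adds to `AstBuilder` (simplified and freestanding here; only top-level integer literals are rewritten, and only when the corresponding flag is passed in):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.IntegerType

// At plan-construction time (parser, Catalyst DSL, Connect planner), a bare
// top-level integer literal in GROUP BY / ORDER BY becomes an explicit
// UnresolvedOrdinal. The Analyzer then never has to guess whether an
// IntegerType literal means "position" or "constant", so re-analyzing an
// already-resolved plan cannot accidentally re-run ordinal resolution.
def replaceIntegerLiteralWithOrdinal(
    expression: Expression,
    canReplaceWithOrdinal: Boolean = true): Expression = expression match {
  case Literal(value: Int, IntegerType) if canReplaceWithOrdinal => UnresolvedOrdinal(value)
  case other => other
}
```

Nested literals such as `GROUP BY 1 + 1` are deliberately left alone, matching the existing ordinal semantics.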
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 5e356fba96a6..0c9e537eda6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -348,7 +348,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       CTESubstitution,
       WindowsSubstitution,
       EliminateUnions,
-      SubstituteUnresolvedOrdinals,
       EliminateLazyExpression),
     Batch("Disable Hints", Once,
       new ResolveHints.DisableHints),
@@ -1975,24 +1974,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       withPosition(ordinal) {
         if (index > 0 && index <= aggs.size) {
           val ordinalExpr = aggs(index - 1)
+
           if (ordinalExpr.exists(_.isInstanceOf[AggregateExpression])) {
             throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError(
               index, ordinalExpr)
-          } else {
-            trimAliases(ordinalExpr) match {
-              // HACK ALERT: If the ordinal expression is also an integer literal, don't use it
-              //             but still keep the ordinal literal. The reason is we may repeatedly
-              //             analyze the plan. Using a different integer literal may lead to
-              //             a repeat GROUP BY ordinal resolution which is wrong. GROUP BY
-              //             constant is meaningless so whatever value does not matter here.
-              // TODO: (SPARK-45932) GROUP BY ordinal should pull out grouping expressions to
-              //       a Project, then the resolved ordinal expression is always
-              //       `AttributeReference`.
-              case Literal(_: Int, IntegerType) =>
-                Literal(index)
-              case _ => ordinalExpr
-            }
           }
+
+          ordinalExpr
         } else {
           throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size)
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
index 7ea90854932e..f01c00f9fa75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, IntegerLiteral, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE}
@@ -129,27 +129,9 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
       groupExprs: Seq[Expression]): Seq[Expression] = {
     assert(selectList.forall(_.resolved))
     if (isGroupByAll(groupExprs)) {
-      val expandedGroupExprs = expandGroupByAll(selectList)
-      if (expandedGroupExprs.isEmpty) {
-        // Don't replace the ALL when we fail to infer the grouping columns. We will eventually
-        // tell the user in checkAnalysis that we cannot resolve the all in group by.
-        groupExprs
-      } else {
-        // This is a valid GROUP BY ALL aggregate.
-        expandedGroupExprs.get.zipWithIndex.map { case (expr, index) =>
-          trimAliases(expr) match {
-            // HACK ALERT: If the expanded grouping expression is an integer literal, don't use it
-            //             but use an integer literal of the index. The reason is we may repeatedly
-            //             analyze the plan, and the original integer literal may cause failures
-            //             with a later GROUP BY ordinal resolution. GROUP BY constant is
-            //             meaningless so whatever value does not matter here.
-            case IntegerLiteral(_) =>
-              // GROUP BY ordinal uses 1-based index.
-              Literal(index + 1)
-            case _ => expr
-          }
-        }
-      }
+      // Don't replace the ALL when we fail to infer the grouping columns. We will eventually tell
+      // the user in checkAnalysis that we cannot resolve the all in group by.
+      expandGroupByAll(selectList).getOrElse(groupExprs)
     } else {
       groupExprs
     }
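Both removed `HACK ALERT` blocks guarded against the same failure mode. A simplified sketch of that hazard, using the query shape from the SPARK-45920 regression test further down (not the actual rule code; `testRelation` comes from the catalyst test utilities):

```scala
import org.apache.spark.sql.catalyst.analysis.{TestRelations, UnresolvedOrdinal}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal

// SELECT 100 AS a FROM t GROUP BY 1
//
// Pass 1: ordinal 1 resolves to the first aggregate expression, so the
// grouping list becomes Seq(Literal(100)).
// Pass 2 (re-analysis of the already-resolved plan): a rule that rewrites any
// bare IntegerType literal into an ordinal would read Literal(100) back as
// "group by column 100" and fail with GROUP_BY_POS_OUT_OF_RANGE. The removed
// hacks dodged this by keeping the original ordinal literal around.
//
// With construction-time substitution the ordinal exists before analysis, and
// the resolved literal can never be re-interpreted on a later pass:
val plan = TestRelations.testRelation
  .groupBy(UnresolvedOrdinal(1))(Literal(100).as("a"))
```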
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
deleted file mode 100644
index fa08ae61daec..000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import org.apache.spark.sql.catalyst.expressions.{BaseGroupingSets, Expression, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
-import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.types.IntegerType
-
-/**
- * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression.
- */
-object SubstituteUnresolvedOrdinals extends Rule[LogicalPlan] {
-  private def containIntLiteral(e: Expression): Boolean = e match {
-    case Literal(_, IntegerType) => true
-    case gs: BaseGroupingSets => gs.children.exists(containIntLiteral)
-    case _ => false
-  }
-
-  private def substituteUnresolvedOrdinal(expression: Expression): Expression = expression match {
-    case ordinal @ Literal(index: Int, IntegerType) =>
-      withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
-    case e => e
-  }
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
-    t => t.containsPattern(LITERAL) && t.containsAnyPattern(AGGREGATE, SORT), ruleId) {
-    case s: Sort if conf.orderByOrdinal && s.order.exists(o => containIntLiteral(o.child)) =>
-      val newOrders = s.order.map {
-        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
-          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
-          withOrigin(order.origin)(order.copy(child = newOrdinal))
-        case other => other
-      }
-      withOrigin(s.origin)(s.copy(order = newOrders))
-
-    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(containIntLiteral) =>
-      val newGroups = a.groupingExpressions.map {
-        case ordinal @ Literal(index: Int, IntegerType) =>
-          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
-        case gs: BaseGroupingSets =>
-          withOrigin(gs.origin)(gs.withNewChildren(gs.children.map(substituteUnresolvedOrdinal)))
-        case other => other
-      }
-      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index ccb60278f492..00bde9f8c1f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -24,6 +24,7 @@ import scala.language.implicitConversions
 
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -64,7 +65,7 @@ import org.apache.spark.unsafe.types.UTF8String
  * LocalRelation [key#2,value#3], []
  * }}}
  */
-package object dsl {
+package object dsl extends SQLConfHelper {
   trait ImplicitOperators {
     def expr: Expression
@@ -446,11 +447,16 @@ package object dsl {
     def sortBy(sortExprs:
SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { + // Replace top-level integer literals with ordinals, if `groupByOrdinal` is enabled. + val groupingExpressionsWithOrdinals = groupingExprs.map { + case Literal(value: Int, IntegerType) if conf.groupByOrdinal => UnresolvedOrdinal(value) + case other => other + } val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne case e => UnresolvedAlias(e) } - Aggregate(groupingExprs, aliasedExprs, logicalPlan) + Aggregate(groupingExpressionsWithOrdinals, aliasedExprs, logicalPlan) } def having( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cccab1b9b0e0..0456f0e57d2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.parser -import java.util.Locale +import java.util.{List, Locale} import java.util.concurrent.TimeUnit import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Set} @@ -1286,17 +1286,17 @@ class AstBuilder extends DataTypeAstBuilder val withOrder = if ( !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { clause = PipeOperators.orderByClause - Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) + Sort(order.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = true, query) } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { clause = PipeOperators.sortByClause - Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) + Sort(sort.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = false, query) } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { clause = PipeOperators.distributeByClause withRepartitionByExpression(ctx, expressionList(distributeBy), query) } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { clause = PipeOperators.sortByDistributeByClause Sort( - sort.asScala.map(visitSortItem).toSeq, + sort.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = false, withRepartitionByExpression(ctx, expressionList(distributeBy), query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { @@ -1825,24 +1825,27 @@ class AstBuilder extends DataTypeAstBuilder } visitNamedExpression(n) }.toSeq + val groupByExpressionsWithOrdinals = + replaceOrdinalsInGroupingExpressions(groupByExpressions) if (ctx.GROUPING != null) { // GROUP BY ... GROUPING SETS (...) // `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping // expressions that do not exist in GROUPING SETS (...), and the value is always null. // For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output // of column `c` is always null. - val groupingSets = - ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) - Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)), - selectExpressions, query) + val groupingSetsWithOrdinals = visitGroupingSetAndReplaceOrdinals(ctx.groupingSet) + Aggregate( + Seq(GroupingSets(groupingSetsWithOrdinals, groupByExpressionsWithOrdinals)), + selectExpressions, query + ) } else { // GROUP BY .... 
(WITH CUBE | WITH ROLLUP)? val mappedGroupByExpressions = if (ctx.CUBE != null) { - Seq(Cube(groupByExpressions.map(Seq(_)))) + Seq(Cube(groupByExpressionsWithOrdinals.map(Seq(_)))) } else if (ctx.ROLLUP != null) { - Seq(Rollup(groupByExpressions.map(Seq(_)))) + Seq(Rollup(groupByExpressionsWithOrdinals.map(Seq(_)))) } else { - groupByExpressions + groupByExpressionsWithOrdinals } Aggregate(mappedGroupByExpressions, selectExpressions, query) } @@ -1856,8 +1859,12 @@ class AstBuilder extends DataTypeAstBuilder } else { expression(groupByExpr.expression) } - }) - Aggregate(groupByExpressions.toSeq, selectExpressions, query) + }).toSeq + Aggregate( + groupingExpressions = replaceOrdinalsInGroupingExpressions(groupByExpressions), + aggregateExpressions = selectExpressions, + child = query + ) } } @@ -1865,7 +1872,7 @@ class AstBuilder extends DataTypeAstBuilder groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = { val groupingSets = groupingAnalytics.groupingSet.asScala .map(_.expression.asScala.map(e => expression(e)).toSeq) - if (groupingAnalytics.CUBE != null) { + val baseGroupingSet = if (groupingAnalytics.CUBE != null) { // CUBE(A, B, (A, B), ()) is not supported. if (groupingSets.exists(_.isEmpty)) { throw QueryParsingErrors.invalidGroupingSetError("CUBE", groupingAnalytics) @@ -1889,6 +1896,9 @@ class AstBuilder extends DataTypeAstBuilder } GroupingSets(groupingSets.toSeq) } + baseGroupingSet.withNewChildren( + newChildren = replaceOrdinalsInGroupingExpressions(baseGroupingSet.children) + ).asInstanceOf[BaseGroupingSets] } /** @@ -6532,12 +6542,12 @@ class AstBuilder extends DataTypeAstBuilder case n: NamedExpression => newGroupingExpressions += n newAggregateExpressions += n - // If the grouping expression is an integer literal, create [[UnresolvedOrdinal]] and - // [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final grouping - // and aggregate expressions, respectively. This will let the + // If the grouping expression is an [[UnresolvedOrdinal]], replace the ordinal value and + // create [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final + // grouping and aggregate expressions, respectively. This will let the // [[ResolveOrdinalInOrderByAndGroupBy]] rule detect the ordinal in the aggregate list // and replace it with the corresponding attribute from the child operator. - case Literal(v: Int, IntegerType) if conf.groupByOrdinal => + case UnresolvedOrdinal(v: Int) => newGroupingExpressions += UnresolvedOrdinal(newAggregateExpressions.length + 1) newAggregateExpressions += UnresolvedAlias(UnresolvedPipeAggregateOrdinal(v), None) case e: Expression => @@ -6558,6 +6568,57 @@ class AstBuilder extends DataTypeAstBuilder } } + /** + * Visits [[SortItemContext]] and replaces top-level [[Literal]]s with [[UnresolvedOrdinal]] in + * resulting expression, if `orderByOrdinal` is enabled. + */ + private def visitSortItemAndReplaceOrdinals(sortItemContext: SortItemContext) = { + val visitedSortItem = visitSortItem(sortItemContext) + visitedSortItem.withNewChildren( + newChildren = Seq(replaceIntegerLiteralWithOrdinal( + expression = visitedSortItem.child, + canReplaceWithOrdinal = conf.orderByOrdinal + )) + ).asInstanceOf[SortOrder] + } + + /** + * Replaces top-level integer [[Literal]]s with [[UnresolvedOrdinal]] in grouping expressions, if + * `groupByOrdinal` is enabled. 
+ */ + private def replaceOrdinalsInGroupingExpressions(groupingExpressions: Seq[Expression]) = + groupingExpressions.map(groupByExpression => + replaceIntegerLiteralWithOrdinal( + expression = groupByExpression, + canReplaceWithOrdinal = conf.groupByOrdinal + ) + ).toSeq + + /** + * Visits grouping expressions in a [[GroupingSetContext]] and replaces top-level integer + * [[Literal]]s with [[UnresolvedOrdinal]]s in resulting expressions, if `groupByOrdinal` is + * enabled. + */ + private def visitGroupingSetAndReplaceOrdinals(groupingSet: List[GroupingSetContext]) = { + groupingSet.asScala.map(_.expression.asScala.map(e => { + val visitedExpression = expression(e) + replaceIntegerLiteralWithOrdinal( + expression = visitedExpression, + canReplaceWithOrdinal = conf.groupByOrdinal + ) + }).toSeq).toSeq + } + + /** + * Replaces integer [[Literal]] with [[UnresolvedOrdinal]] if `canReplaceWithOrdinal` is true. + */ + private def replaceIntegerLiteralWithOrdinal( + expression: Expression, + canReplaceWithOrdinal: Boolean = true) = expression match { + case Literal(value: Int, IntegerType) if canReplaceWithOrdinal => UnresolvedOrdinal(value) + case other => other + } + /** * Check plan for any parameters. * If it finds any throws UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/GroupByOrdinalsRepeatedAnalysisSuite.scala similarity index 59% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/GroupByOrdinalsRepeatedAnalysisSuite.scala index 39cf298aec43..ac120e80d51c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/GroupByOrdinalsRepeatedAnalysisSuite.scala @@ -17,63 +17,42 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.TestRelations.{testRelation, testRelation2} +import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.internal.SQLConf -class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { - private lazy val a = testRelation2.output(0) - private lazy val b = testRelation2.output(1) +class GroupByOrdinalsRepeatedAnalysisSuite extends AnalysisTest { test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) } - test("order by ordinal") { - // Tests order by ordinal, apply single rule. 
- val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) + test("SPARK-45920: group by ordinal repeated analysis") { + val plan = testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze comparePlans( - SubstituteUnresolvedOrdinals.apply(plan), - testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) - - // Tests order by ordinal, do full analysis - checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) + plan, + testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze + ) - // order by ordinal can be turned off by config - withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") { - comparePlans( - SubstituteUnresolvedOrdinals.apply(plan), - testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) + val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any)))) + // Copy the plan to reset its `analyzed` flag, so that analyzer rules will re-apply. + val copiedPlan = plan.transform { + case _: LocalRelation => testRelationWithData } - } - - test("group by ordinal") { - // Tests group by ordinal, apply single rule. - val plan2 = testRelation2.groupBy(Literal(1), Literal(2))($"a", $"b") comparePlans( - SubstituteUnresolvedOrdinals.apply(plan2), - testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))($"a", $"b")) - - // Tests group by ordinal, do full analysis - checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) - - // group by ordinal can be turned off by config - withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { - comparePlans( - SubstituteUnresolvedOrdinals.apply(plan2), - testRelation2.groupBy(Literal(1), Literal(2))($"a", $"b")) - } + copiedPlan.analyze, // repeated analysis + testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")).analyze + ) } - test("SPARK-45920: group by ordinal repeated analysis") { - val plan = testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze + test("SPARK-47895: group by all repeated analysis") { + val plan = testRelation.groupBy($"all")(Literal(100).as("a")).analyze comparePlans( plan, - testRelation.groupBy(Literal(1))(Literal(100).as("a")) + testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze ) val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any)))) @@ -83,15 +62,15 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { } comparePlans( copiedPlan.analyze, // repeated analysis - testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")) + testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")).analyze ) } - test("SPARK-47895: group by all repeated analysis") { - val plan = testRelation.groupBy($"all")(Literal(100).as("a")).analyze + test("SPARK-47895: group by alias repeated analysis") { + val plan = testRelation.groupBy($"b")(Literal(100).as("b")).analyze comparePlans( plan, - testRelation.groupBy(Literal(1))(Literal(100).as("a")) + testRelation.groupBy(Literal(1))(Literal(100).as("b")).analyze ) val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any)))) @@ -101,7 +80,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { } comparePlans( copiedPlan.analyze, // repeated analysis - testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")) + testRelationWithData.groupBy(Literal(1))(Literal(100).as("b")).analyze ) } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 
a7a76e334e47..911d79ecdb12 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -46,7 +46,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} import org.apache.spark.sql.{Column, Encoders, ForeachWriter, Observation, Row} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, Unresolved [...] +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedOrdinal, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedF [...] import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, UnboundRowEncoder} import org.apache.spark.sql.catalyst.expressions._ @@ -2373,7 +2373,7 @@ class SparkConnectPlanner( private def transformSortOrder(order: proto.Expression.SortOrder) = { expressions.SortOrder( - child = transformExpression(order.getChild), + child = transformSortOrderAndReplaceOrdinals(order.getChild), direction = order.getDirection match { case proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING => expressions.Ascending @@ -2387,6 +2387,19 @@ class SparkConnectPlanner( sameOrderExpressions = Seq.empty) } + /** + * Transforms an input protobuf sort order expression into the Catalyst expression and converts + * top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `orderByOrdinal` is enabled. 
+ */ + private def transformSortOrderAndReplaceOrdinals(sortItem: proto.Expression) = { + val transformedSortItem = transformExpression(sortItem) + if (session.sessionState.conf.orderByOrdinal) { + replaceIntegerLiteralWithOrdinal(transformedSortItem) + } else { + transformedSortItem + } + } + private def transformDrop(rel: proto.Drop): LogicalPlan = { var output = Dataset.ofRows(session, transformRelation(rel.getInput)) if (rel.getColumnsCount > 0) { @@ -2439,27 +2452,28 @@ class SparkConnectPlanner( input } - val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression) + val groupingExpressionsWithOrdinals = rel.getGroupingExpressionsList.asScala.toSeq + .map(transformGroupingExpressionAndReplaceOrdinals) val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq .map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan)) - val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression) + val aliasedAgg = (groupingExpressionsWithOrdinals ++ aggExprs).map(toNamedExpression) rel.getGroupType match { case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => logical.Aggregate( - groupingExpressions = groupingExprs, + groupingExpressions = groupingExpressionsWithOrdinals, aggregateExpressions = aliasedAgg, child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP => logical.Aggregate( - groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))), + groupingExpressions = Seq(Rollup(groupingExpressionsWithOrdinals.map(Seq(_)))), aggregateExpressions = aliasedAgg, child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_CUBE => logical.Aggregate( - groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))), + groupingExpressions = Seq(Cube(groupingExpressionsWithOrdinals.map(Seq(_)))), aggregateExpressions = aliasedAgg, child = logicalPlan) @@ -2477,21 +2491,23 @@ class SparkConnectPlanner( .map(expressions.Literal.apply) } logical.Pivot( - groupByExprsOpt = Some(groupingExprs.map(toNamedExpression)), + groupByExprsOpt = Some(groupingExpressionsWithOrdinals.map(toNamedExpression)), pivotColumn = pivotExpr, pivotValues = valueExprs, aggregates = aggExprs, child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS => - val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets => - getGroupingSets.getGroupingSetList.asScala.toSeq.map(transformExpression) - } + val groupingSetsExpressionsWithOrdinals = + rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets => + getGroupingSets.getGroupingSetList.asScala.toSeq + .map(transformGroupingExpressionAndReplaceOrdinals) + } logical.Aggregate( groupingExpressions = Seq( GroupingSets( - groupingSets = groupingSetsExprs, - userGivenGroupByExprs = groupingExprs)), + groupingSets = groupingSetsExpressionsWithOrdinals, + userGivenGroupByExprs = groupingExpressionsWithOrdinals)), aggregateExpressions = aliasedAgg, child = logicalPlan) @@ -2499,6 +2515,20 @@ class SparkConnectPlanner( } } + /** + * Transforms an input protobuf grouping expression into the Catalyst expression and converts + * top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `groupByOrdinal` is enabled. 
+ */ + private def transformGroupingExpressionAndReplaceOrdinals( + groupingExpression: proto.Expression) = { + val transformedGroupingExpression = transformExpression(groupingExpression) + if (session.sessionState.conf.groupByOrdinal) { + replaceIntegerLiteralWithOrdinal(transformedGroupingExpression) + } else { + transformedGroupingExpression + } + } + @deprecated("TypedReduce is now implemented using a normal UDAF aggregator.", "4.0.0") private def transformTypedReduceExpression( fun: proto.Expression.UnresolvedFunction, @@ -4072,6 +4102,16 @@ class SparkConnectPlanner( } } + /** + * Replaces a top-level integer [[Literal]] in a grouping expression with [[UnresolvedOrdinal]] + * that has the same index. + */ + private def replaceIntegerLiteralWithOrdinal(groupingExpression: Expression) = + groupingExpression match { + case Literal(value: Int, IntegerType) => UnresolvedOrdinal(value) + case other => other + } + private def assertPlan(assertion: Boolean, message: => String = ""): Unit = { if (!assertion) throw InvalidPlanInput(message) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index 8765e1bfc7c6..5c3ebb32b36a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -929,7 +929,18 @@ class Dataset[T] private[sql]( /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { - RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) + // Replace top-level integer literals in grouping expressions with ordinals, if + // `groupByOrdinal` is enabled. + val groupingExpressionsWithOrdinals = cols.map { col => col.expr match { + case Literal(value: Int, IntegerType) if sparkSession.sessionState.conf.groupByOrdinal => + UnresolvedOrdinal(value) + case other => other + }} + RelationalGroupedDataset( + df = toDF(), + groupingExprs = groupingExpressionsWithOrdinals, + groupType = RelationalGroupedDataset.GroupByType + ) } /** @inheritdoc */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala index 082292145e85..8f05aff7e90f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala @@ -23,7 +23,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql import org.apache.spark.sql.{AnalysisException, Column, Encoder} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedOrdinal} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -36,7 +36,7 @@ import org.apache.spark.sql.classic.TypedAggUtils.withInputType import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.{IntegerType, NumericType, StructType} import org.apache.spark.util.ArrayImplicits._ 
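`transformGroupingExpressionAndReplaceOrdinals` above and the `Dataset.groupBy` change below apply the same conf-guarded rewrite. A condensed, freestanding sketch of that shared pattern (`toGroupingExpression` is a hypothetical name for illustration; the real code reads the config from `session.sessionState.conf` or `sparkSession.sessionState.conf`):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

// Rewrite only when spark.sql.groupByOrdinal is enabled; otherwise the
// literal keeps its "GROUP BY constant" meaning and is left untouched.
def toGroupingExpression(e: Expression, conf: SQLConf): Expression = e match {
  case Literal(value: Int, IntegerType) if conf.groupByOrdinal => UnresolvedOrdinal(value)
  case other => other
}
```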
/** @@ -67,7 +67,13 @@ class RelationalGroupedDataset protected[sql]( @scala.annotation.nowarn("cat=deprecation") val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { - groupingExprs match { + // We need to unwrap ordinals from grouping expressions in order to add grouping columns to + // aggregate expressions. + val groupingExpressionsWithUnwrappedOrdinals = groupingExprs.map { + case UnresolvedOrdinal(value) => Literal(value, IntegerType) + case other => other + } + groupingExpressionsWithUnwrappedOrdinals match { // call `toList` because `Stream` and `LazyList` can't serialize in scala 2.13 case s: LazyList[Expression] => s.toList ++ aggExprs case s: Stream[Expression] => s.toList ++ aggExprs diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-all.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-all.sql.out index b1a4b85d3ae5..7837ba426d95 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-all.sql.out @@ -365,7 +365,7 @@ Aggregate [country#x, city#x, name#x, id#x, power#x], [country#x, city#x, name#x -- !query SELECT 4 AS a, 5 AS b, 6 AS c, * FROM data GROUP BY ALL -- !query analysis -Aggregate [1, 2, 3, country#x, city#x, name#x, id#x, power#x], [4 AS a#x, 5 AS b#x, 6 AS c#x, country#x, city#x, name#x, id#x, power#x] +Aggregate [4, 5, 6, country#x, city#x, name#x, id#x, power#x], [4 AS a#x, 5 AS b#x, 6 AS c#x, country#x, city#x, name#x, id#x, power#x] +- SubqueryAlias data +- View (`data`, [country#x, city#x, name#x, id#x, power#x]) +- Project [cast(country#x as string) AS country#x, cast(city#x as string) AS city#x, cast(name#x as string) AS name#x, cast(id#x as int) AS id#x, cast(power#x as decimal(3,1)) AS power#x] @@ -377,7 +377,7 @@ Aggregate [1, 2, 3, country#x, city#x, name#x, id#x, power#x], [4 AS a#x, 5 AS b -- !query SELECT *, 4 AS a, 5 AS b, 6 AS c FROM data GROUP BY ALL -- !query analysis -Aggregate [country#x, city#x, name#x, id#x, power#x, 6, 7, 8], [country#x, city#x, name#x, id#x, power#x, 4 AS a#x, 5 AS b#x, 6 AS c#x] +Aggregate [country#x, city#x, name#x, id#x, power#x, 4, 5, 6], [country#x, city#x, name#x, id#x, power#x, 4 AS a#x, 5 AS b#x, 6 AS c#x] +- SubqueryAlias data +- View (`data`, [country#x, city#x, name#x, id#x, power#x]) +- Project [cast(country#x as string) AS country#x, cast(city#x as string) AS city#x, cast(name#x as string) AS name#x, cast(id#x as int) AS id#x, cast(power#x as decimal(3,1)) AS power#x] @@ -389,7 +389,7 @@ Aggregate [country#x, city#x, name#x, id#x, power#x, 6, 7, 8], [country#x, city# -- !query SELECT 4 AS a, 5 AS b, *, 6 AS c FROM data GROUP BY ALL -- !query analysis -Aggregate [1, 2, country#x, city#x, name#x, id#x, power#x, 8], [4 AS a#x, 5 AS b#x, country#x, city#x, name#x, id#x, power#x, 6 AS c#x] +Aggregate [4, 5, country#x, city#x, name#x, id#x, power#x, 6], [4 AS a#x, 5 AS b#x, country#x, city#x, name#x, id#x, power#x, 6 AS c#x] +- SubqueryAlias data +- View (`data`, [country#x, city#x, name#x, id#x, power#x]) +- Project [cast(country#x as string) AS country#x, cast(city#x as string) AS city#x, cast(name#x as string) AS name#x, cast(id#x as int) AS id#x, cast(power#x as decimal(3,1)) AS power#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out index 7ffd5bf22baf..c659fa2bcbb4 100644 --- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out @@ -61,7 +61,7 @@ Aggregate [a#x, a#x], [a#x, 1 AS 1#x, sum(b#x) AS sum(b)#xL] -- !query select a, 1, sum(b) from data group by 1, 2 -- !query analysis -Aggregate [a#x, 2], [a#x, 1 AS 1#x, sum(b#x) AS sum(b)#xL] +Aggregate [a#x, 1], [a#x, 1 AS 1#x, sum(b#x) AS sum(b)#xL] +- SubqueryAlias data +- View (`data`, [a#x, b#x]) +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x] @@ -120,9 +120,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 32, + "startIndex" : 23, "stopIndex" : 33, - "fragment" : "-1" + "fragment" : "group by -1" } ] } @@ -141,9 +141,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 32, + "startIndex" : 23, "stopIndex" : 32, - "fragment" : "0" + "fragment" : "group by 0" } ] } @@ -162,9 +162,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 32, + "startIndex" : 23, "stopIndex" : 32, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -183,9 +183,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 40, + "startIndex" : 31, "stopIndex" : 40, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -204,9 +204,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 44, + "startIndex" : 35, "stopIndex" : 44, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -401,9 +401,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 45, + "startIndex" : 33, "stopIndex" : 46, - "fragment" : "-1" + "fragment" : "group by a, -1" } ] } @@ -422,9 +422,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 45, + "startIndex" : 33, "stopIndex" : 45, - "fragment" : "3" + "fragment" : "group by a, 3" } ] } @@ -443,9 +443,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 47, - "stopIndex" : 48, - "fragment" : "-1" + "startIndex" : 33, + "stopIndex" : 52, + "fragment" : "group by cube(-1, 2)" } ] } @@ -464,9 +464,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 50, - "stopIndex" : 50, - "fragment" : "3" + "startIndex" : 33, + "stopIndex" : 51, + "fragment" : "group by cube(1, 3)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/select_implicit.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/select_implicit.sql.out index 83b10b3cb67c..73a818a676a4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/select_implicit.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/select_implicit.sql.out @@ -197,9 +197,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 54, + "startIndex" : 45, "stopIndex" : 54, - "fragment" : "3" + "fragment" : "GROUP BY 3" } ] } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/postgreSQL/udf-select_implicit.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/postgreSQL/udf-select_implicit.sql.out index 05f47aeace9f..f9752344379c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/postgreSQL/udf-select_implicit.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/postgreSQL/udf-select_implicit.sql.out @@ -200,9 +200,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 64, + "startIndex" : 55, "stopIndex" : 64, - "fragment" : "3" + "fragment" : "GROUP BY 3" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index e13c4b13899a..64bfac15a814 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -102,9 +102,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 32, + "startIndex" : 23, "stopIndex" : 33, - "fragment" : "-1" + "fragment" : "group by -1" } ] } @@ -125,9 +125,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 32, + "startIndex" : 23, "stopIndex" : 32, - "fragment" : "0" + "fragment" : "group by 0" } ] } @@ -148,9 +148,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 32, + "startIndex" : 23, "stopIndex" : 32, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -171,9 +171,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 40, + "startIndex" : 31, "stopIndex" : 40, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -194,9 +194,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 44, + "startIndex" : 35, "stopIndex" : 44, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -432,9 +432,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 45, + "startIndex" : 33, "stopIndex" : 46, - "fragment" : "-1" + "fragment" : "group by a, -1" } ] } @@ -455,9 +455,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 45, + "startIndex" : 33, "stopIndex" : 45, - "fragment" : "3" + "fragment" : "group by a, 3" } ] } @@ -478,9 +478,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 47, - "stopIndex" : 48, - "fragment" : "-1" + "startIndex" : 33, + "stopIndex" : 52, + "fragment" : "group by cube(-1, 2)" } ] } @@ -501,9 +501,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 50, - "stopIndex" : 50, - "fragment" : "3" + "startIndex" : 33, + "stopIndex" : 51, + "fragment" : "group by cube(1, 3)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/select_implicit.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/select_implicit.sql.out index f0c283cb4036..46c278e66d87 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/select_implicit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/select_implicit.sql.out @@ -224,9 +224,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { 
"objectType" : "", "objectName" : "", - "startIndex" : 54, + "startIndex" : 45, "stopIndex" : 54, - "fragment" : "3" + "fragment" : "GROUP BY 3" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/udaf/udaf-group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf/udaf-group-by-ordinal.sql.out index 45a19ba2c3f1..83374fc741c2 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf/udaf-group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf/udaf-group-by-ordinal.sql.out @@ -102,9 +102,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 41, + "startIndex" : 32, "stopIndex" : 41, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -125,9 +125,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 45, + "startIndex" : 36, "stopIndex" : 45, - "fragment" : "3" + "fragment" : "group by 3" } ] } @@ -351,9 +351,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 44, + "startIndex" : 32, "stopIndex" : 45, - "fragment" : "-1" + "fragment" : "group by a, -1" } ] } @@ -374,9 +374,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 44, + "startIndex" : 32, "stopIndex" : 44, - "fragment" : "3" + "fragment" : "group by a, 3" } ] } @@ -397,9 +397,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 46, - "stopIndex" : 47, - "fragment" : "-1" + "startIndex" : 32, + "stopIndex" : 51, + "fragment" : "group by cube(-1, 2)" } ] } @@ -420,9 +420,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 49, - "stopIndex" : 49, - "fragment" : "3" + "startIndex" : 32, + "stopIndex" : 50, + "fragment" : "group by cube(1, 3)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-select_implicit.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-select_implicit.sql.out index a3a7cee4eaa7..9381678d3594 100755 --- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-select_implicit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-select_implicit.sql.out @@ -227,9 +227,9 @@ org.apache.spark.sql.AnalysisException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 64, + "startIndex" : 55, "stopIndex" : 64, - "fragment" : "3" + "fragment" : "GROUP BY 3" } ] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateResolverSuite.scala index bcd474f330de..02cfb35b0d62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateResolverSuite.scala @@ -44,12 +44,6 @@ class AggregateResolverSuite extends QueryTest with SharedSparkSession { resolverRunner.resolve(query) } - test("Valid group by ordinal") { - val resolverRunner = createResolverRunner() - val query = table.groupBy(intToLiteral(1))(intToLiteral(1)) - resolverRunner.resolve(query) - } - test("Group by aggregate function") { val resolverRunner = createResolverRunner() val query = table.groupBy($"count".function(intToLiteral(1)))(intToLiteral(1)) 
@@ -62,31 +56,6 @@ class AggregateResolverSuite extends QueryTest with SharedSparkSession { ) } - test("Group by ordinal which refers to aggregate function") { - val resolverRunner = createResolverRunner() - val query = table.groupBy(intToLiteral(1))($"count".function(intToLiteral(1))) - checkError( - exception = intercept[AnalysisException] { - resolverRunner.resolve(query) - }, - condition = "GROUP_BY_POS_AGGREGATE", - parameters = - Map("index" -> "1", "aggExpr" -> $"count".function(intToLiteral(1)).as("count(1)").sql) - ) - } - - test("Group by ordinal out of range") { - val resolverRunner = createResolverRunner() - val query = table.groupBy(intToLiteral(100))(intToLiteral(1)) - checkError( - exception = intercept[AnalysisException] { - resolverRunner.resolve(query) - }, - condition = "GROUP_BY_POS_OUT_OF_RANGE", - parameters = Map("index" -> "100", "size" -> "1") - ) - } - test("Select a column which is not in the group by clause") { val resolverRunner = createResolverRunner() val query = table.groupBy("b".attr)("a".attr) @@ -127,6 +96,39 @@ class AggregateResolverSuite extends QueryTest with SharedSparkSession { ) } + // Disabling following tests until SPARK-51820 is handled in single-pass analyzer. + /* + test("Valid group by ordinal") { + val resolverRunner = createResolverRunner() + val query = table.groupBy(intToLiteral(1))(intToLiteral(1)) + resolverRunner.resolve(query) + } + + test("Group by ordinal which refers to aggregate function") { + val resolverRunner = createResolverRunner() + val query = table.groupBy(intToLiteral(1))($"count".function(intToLiteral(1))) + checkError( + exception = intercept[AnalysisException] { + resolverRunner.resolve(query) + }, + condition = "GROUP_BY_POS_AGGREGATE", + parameters = + Map("index" -> "1", "aggExpr" -> $"count".function(intToLiteral(1)).as("count(1)").sql) + ) + } + + test("Group by ordinal out of range") { + val resolverRunner = createResolverRunner() + val query = table.groupBy(intToLiteral(100))(intToLiteral(1)) + checkError( + exception = intercept[AnalysisException] { + resolverRunner.resolve(query) + }, + condition = "GROUP_BY_POS_OUT_OF_RANGE", + parameters = Map("index" -> "100", "size" -> "1") + ) + } + test("Group by ordinal with a star in the aggregate expression list") { val resolverRunner = createResolverRunner() val query = table.groupBy(intToLiteral(1))(star()) @@ -138,6 +140,7 @@ class AggregateResolverSuite extends QueryTest with SharedSparkSession { parameters = Map.empty ) } + */ private def createResolverRunner(): ResolverRunner = { val resolver = new Resolver( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala index 457c7e504ab0..ab44c354afe0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala @@ -230,16 +230,12 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { checkResolverGuard("SELECT current_timestamp", shouldPass = true) } - test("Group by") { - checkResolverGuard("SELECT col1 FROM VALUES(1) GROUP BY 1", shouldPass = true) + test("Group by all") { checkResolverGuard("SELECT col1, count(col1) FROM VALUES(1) GROUP BY ALL", shouldPass = true) - checkResolverGuard("SELECT col1, col1 + 1 FROM VALUES(1) GROUP BY 1, col1", shouldPass = true) } - test("Order by") { - checkResolverGuard("SELECT col1 FROM VALUES(1) 
ORDER BY 1", shouldPass = true) + test("Order by all") { checkResolverGuard("SELECT col1 FROM VALUES(1) ORDER BY ALL", shouldPass = true) - checkResolverGuard("SELECT col1, col1 + 1 FROM VALUES(1) ORDER BY 1, col1", shouldPass = true) } test("Scalar subquery") { @@ -272,6 +268,18 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { // Queries that shouldn't pass the OperatorResolverGuard + // We temporarily disable group by ordinal until we address SPARK-51820 in single-pass. + test("Group by ordinal") { + checkResolverGuard("SELECT col1 FROM VALUES(1) GROUP BY 1", shouldPass = false) + checkResolverGuard("SELECT col1, col1 + 1 FROM VALUES(1) GROUP BY 1, col1", shouldPass = false) + } + + // We temporarily disable order by ordinal until we address SPARK-51820 in single-pass. + test("Order by ordinal") { + checkResolverGuard("SELECT col1 FROM VALUES(1) ORDER BY 1", shouldPass = false) + checkResolverGuard("SELECT col1, col1 + 1 FROM VALUES(1) ORDER BY 1, col1", shouldPass = false) + } + test("Unsupported literal functions") { checkResolverGuard("SELECT current_user", shouldPass = false) checkResolverGuard("SELECT session_user", shouldPass = false) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org