This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2a10c8d93aa [SPARK-45069][SQL] SQL variable should always be resolved after outer reference 2a10c8d93aa is described below commit 2a10c8d93aa9033842471e4f676fddb3b3f90940 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Mon Sep 11 22:57:47 2023 +0800 [SPARK-45069][SQL] SQL variable should always be resolved after outer reference ### What changes were proposed in this pull request? This is a bug fix for the recently added SQL variable feature. Resolving a column to a SQL variable is designed to be the last resort, but for columns in Aggregate, a column could be resolved to a SQL variable before it was tried as an outer reference. ### Why are the changes needed? bug fix ### Does this PR introduce _any_ user-facing change? yes, the query result can be wrong before this fix ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #42803 from cloud-fan/meta-col. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 50 +++++++++++----------- .../catalyst/analysis/ColumnResolutionHelper.scala | 26 ++++++++--- .../analysis/ResolveReferencesInAggregate.scala | 24 +++++------ .../analysis/ResolveReferencesInSort.scala | 13 +++--- .../analyzer-results/sql-session-variables.sql.out | 25 +++++++++-- .../sql-tests/inputs/sql-session-variables.sql | 3 ++ .../results/sql-session-variables.sql.out | 19 +++++++- 7 files changed, 105 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a8c99075cdb..da983ff0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -683,7 +683,7 @@ class 
Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // of the analysis phase. val colResolved = h.mapExpressions { e => resolveExpressionByPlanOutput( - resolveColWithAgg(e, aggForResolving), aggForResolving, allowOuter = true) + resolveColWithAgg(e, aggForResolving), aggForResolving, includeLastResort = true) } val cond = if (SubqueryExpression.hasSubquery(colResolved.havingCondition)) { val fake = Project(Alias(colResolved.havingCondition, "fake")() :: Nil, aggregate.child) @@ -1450,6 +1450,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * e.g. `SELECT col, current_date FROM t`. * 4. Resolves the columns to outer references with the outer plan if we are resolving subquery * expressions. + * 5. Resolves the columns to SQL variables. * * Some plan nodes have special column reference resolution logic, please read these sub-rules for * details: @@ -1568,7 +1569,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g case g @ Generate(generator, join, outer, qualifier, output, child) => - val newG = resolveExpressionByPlanOutput(generator, child, throws = true, allowOuter = true) + val newG = resolveExpressionByPlanOutput( + generator, child, throws = true, includeLastResort = true) if (newG.fastEquals(generator)) { g } else { @@ -1584,7 +1586,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case plan => plan } val resolvedOrder = mg.dataOrder - .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder]) + .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder]) mg.copy(dataOrder = resolvedOrder) // Left and right sort expression have to be resolved against the respective child plan only @@ -1614,13 +1616,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Special case for Project as it supports 
lateral column alias. case p: Project => - val resolvedNoOuter = p.projectList - .map(resolveExpressionByPlanChildren(_, p, allowOuter = false)) + val resolvedBasic = p.projectList.map(resolveExpressionByPlanChildren(_, p)) // Lateral column alias has higher priority than outer reference. - val resolvedWithLCA = resolveLateralColumnAlias(resolvedNoOuter) - val resolvedWithOuter = resolvedWithLCA.map(resolveOuterRef) - val resolvedWithVariables = resolvedWithOuter.map(p => resolveVariables(p)) - p.copy(projectList = resolvedWithVariables.map(_.asInstanceOf[NamedExpression])) + val resolvedWithLCA = resolveLateralColumnAlias(resolvedBasic) + val resolvedFinal = resolvedWithLCA.map(resolveColsLastResort) + p.copy(projectList = resolvedFinal.map(_.asInstanceOf[NamedExpression])) case o: OverwriteByExpression if o.table.resolved => // The delete condition of `OverwriteByExpression` will be passed to the table @@ -1714,7 +1714,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Columns in HAVING should be resolved with `agg.child.output` first, to follow the SQL // standard. See more details in SPARK-31519. val resolvedWithAgg = resolveColWithAgg(e, agg) - resolveExpressionByPlanChildren(resolvedWithAgg, u, allowOuter = true) + resolveExpressionByPlanChildren(resolvedWithAgg, u, includeLastResort = true) } // RepartitionByExpression can host missing attributes that are from a descendant node. @@ -1724,32 +1724,32 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // node, and project them way at the end via an extra Project. case r @ RepartitionByExpression(partitionExprs, child, _, _) if !r.resolved || r.missingInput.nonEmpty => - val resolvedNoOuter = partitionExprs.map(resolveExpressionByPlanChildren(_, r)) - val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedNoOuter, child) - // Outer reference has lower priority than this. See the doc of `ResolveReferences`. 
- val resolvedWithOuter = newPartitionExprs.map(resolveOuterRef) - val finalPartitionExprs = resolvedWithOuter.map(e => resolveVariables(e)) + val resolvedBasic = partitionExprs.map(resolveExpressionByPlanChildren(_, r)) + val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedBasic, child) + // Missing columns should be resolved right after basic column resolution. + // See the doc of `ResolveReferences`. + val resolvedFinal = newPartitionExprs.map(resolveColsLastResort) if (child.output == newChild.output) { - r.copy(finalPartitionExprs, newChild) + r.copy(resolvedFinal, newChild) } else { - Project(child.output, r.copy(finalPartitionExprs, newChild)) + Project(child.output, r.copy(resolvedFinal, newChild)) } // Filter can host both grouping expressions/aggregate functions and missing attributes. // The grouping expressions/aggregate functions resolution takes precedence over missing // attributes. See the classdoc of `ResolveReferences` for details. case f @ Filter(cond, child) if !cond.resolved || f.missingInput.nonEmpty => - val resolvedNoOuter = resolveExpressionByPlanChildren(cond, f) - val resolvedWithAgg = resolveColWithAgg(resolvedNoOuter, child) + val resolvedBasic = resolveExpressionByPlanChildren(cond, f) + val resolvedWithAgg = resolveColWithAgg(resolvedBasic, child) val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child) - // Outer reference has lowermost priority. See the doc of `ResolveReferences`. - val resolvedWithOuter = resolveOuterRef(newCond.head) - val finalCond = resolveVariables(resolvedWithOuter) + // Missing columns should be resolved right after basic column resolution. + // See the doc of `ResolveReferences`. + val resolvedFinal = resolveColsLastResort(newCond.head) if (child.output == newChild.output) { - f.copy(condition = finalCond) + f.copy(condition = resolvedFinal) } else { // Add missing attributes and then project them away. 
- val newFilter = Filter(finalCond, newChild) + val newFilter = Filter(resolvedFinal, newChild) Project(child.output, newFilter) } @@ -1758,7 +1758,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") - q.mapExpressions(resolveExpressionByPlanChildren(_, q, allowOuter = true)) + q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true)) } private object MergeResolvePolicy extends Enumeration { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index d7b1f99f1ed..54a9c6ca018 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -135,7 +135,7 @@ trait ColumnResolutionHelper extends Logging { resolveColumnByName: Seq[String] => Option[Expression], getAttrCandidates: () => Seq[Attribute], throws: Boolean, - allowOuter: Boolean): Expression = { + includeLastResort: Boolean): Expression = { def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { if (e.resolved) return e val resolved = e match { @@ -196,8 +196,11 @@ trait ColumnResolutionHelper extends Logging { try { val resolved = innerResolve(expr, isTopLevel = true) - val withOuterResolved = if (allowOuter) resolveOuterRef(resolved) else resolved - resolveVariables(withOuterResolved) + if (includeLastResort) { + resolveColsLastResort(resolved) + } else { + resolved + } } catch { case ae: AnalysisException if !throws => logDebug(ae.getMessage) @@ -421,7 +424,7 @@ trait ColumnResolutionHelper extends Logging { expr: Expression, plan: LogicalPlan, throws: Boolean = false, - allowOuter: Boolean = false): Expression = { + 
includeLastResort: Boolean = false): Expression = { resolveExpression( expr, resolveColumnByName = nameParts => { @@ -429,7 +432,7 @@ trait ColumnResolutionHelper extends Logging { }, getAttrCandidates = () => plan.output, throws = throws, - allowOuter = allowOuter) + includeLastResort = includeLastResort) } /** @@ -443,7 +446,7 @@ trait ColumnResolutionHelper extends Logging { def resolveExpressionByPlanChildren( e: Expression, q: LogicalPlan, - allowOuter: Boolean = false): Expression = { + includeLastResort: Boolean = false): Expression = { val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) { // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and // expression are from Spark Connect, and need to be resolved in this way: @@ -467,7 +470,16 @@ trait ColumnResolutionHelper extends Logging { q.children.head.output }, throws = true, - allowOuter = allowOuter) + includeLastResort = includeLastResort) + } + + /** + * The last resort to resolve columns. 
Currently it does two things: + * - Try to resolve column names as outer references + * - Try to resolve column names as SQL variable + */ + protected def resolveColsLastResort(e: Expression): Expression = { + resolveVariables(resolveOuterRef(e)) } def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index 6bc1949a4e0..4f5a11835c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -59,23 +59,23 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S case _ => a } - val resolvedGroupExprsNoOuter = a.groupingExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) - val resolvedAggExprsNoOuter = a.aggregateExpressions.map( - resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) - val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) - val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) + val resolvedGroupExprsBasic = a.groupingExpressions + .map(resolveExpressionByPlanChildren(_, planForResolve)) + val resolvedAggExprsBasic = a.aggregateExpressions.map( + resolveExpressionByPlanChildren(_, planForResolve)) + val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsBasic) + val resolvedAggExprsFinal = resolvedAggExprsWithLCA.map(resolveColsLastResort) .map(_.asInstanceOf[NamedExpression]) // `groupingExpressions` may rely on `aggregateExpressions`, due to features like GROUP BY alias // and GROUP BY ALL. 
We only do basic resolution for `groupingExpressions`, and will further // resolve it after `aggregateExpressions` are all resolved. Note: the basic resolution is // needed as `aggregateExpressions` may rely on `groupingExpressions` as well, for the session // window feature. See the rule `SessionWindowing` for more details. - val resolvedGroupExprs = if (resolvedAggExprsWithOuter.forall(_.resolved)) { + val resolvedGroupExprs = if (resolvedAggExprsFinal.forall(_.resolved)) { val resolved = resolveGroupByAll( - resolvedAggExprsWithOuter, - resolveGroupByAlias(resolvedAggExprsWithOuter, resolvedGroupExprsNoOuter) - ).map(resolveOuterRef) + resolvedAggExprsFinal, + resolveGroupByAlias(resolvedAggExprsFinal, resolvedGroupExprsBasic) + ).map(resolveColsLastResort) // TODO: currently we don't support LCA in `groupingExpressions` yet. if (resolved.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))) { throw new AnalysisException( @@ -89,7 +89,7 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S // alias/ALL in the next iteration. If aggregate expressions end up as unresolved, we don't // need to resolve grouping expressions at all, as `CheckAnalysis` will report error for // aggregate expressions first. - resolvedGroupExprsNoOuter + resolvedGroupExprsBasic } a.copy( // The aliases in grouping expressions are useless and will be removed at the end of analysis @@ -105,7 +105,7 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S // GROUP BY will be removed eventually, by following iterations. 
if (e.resolved) trimAliases(e) else e }, - aggregateExpressions = resolvedAggExprsWithOuter) + aggregateExpressions = resolvedAggExprsFinal) } private def resolveGroupByAlias( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala index e4e9188662a..02583ebb8f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala @@ -50,19 +50,18 @@ class ResolveReferencesInSort(val catalogManager: CatalogManager) extends SQLConfHelper with ColumnResolutionHelper { def apply(s: Sort): LogicalPlan = { - val resolvedNoOuter = s.order.map(resolveExpressionByPlanOutput(_, s.child)) - val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, s.child)) + val resolvedBasic = s.order.map(resolveExpressionByPlanOutput(_, s.child)) + val resolvedWithAgg = resolvedBasic.map(resolveColWithAgg(_, s.child)) val (missingAttrResolved, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, s.child) val orderByAllResolved = resolveOrderByAll( s.global, newChild, missingAttrResolved.map(_.asInstanceOf[SortOrder])) - val resolvedWithOuter = orderByAllResolved.map(e => resolveOuterRef(e)) - val finalOrdering = resolvedWithOuter.map(e => resolveVariables(e) - .asInstanceOf[SortOrder]) + val resolvedFinal = orderByAllResolved + .map(e => resolveColsLastResort(e).asInstanceOf[SortOrder]) if (s.child.output == newChild.output) { - s.copy(order = finalOrdering) + s.copy(order = resolvedFinal) } else { // Add missing attributes and then project them away. 
- val newSort = s.copy(order = finalOrdering, child = newChild) + val newSort = s.copy(order = resolvedFinal, child = newChild) Project(s.child.output, newSort) } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index 45bfbf69db3..ff645867415 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -485,9 +485,28 @@ org.apache.spark.sql.AnalysisException -- !query -SET VARIABLE title = 'Test qualifiers - fail' +SET VARIABLE title = 'Test variable in aggregate' -- !query analysis SetVariable [variablereference(system.session.title='Test qualifiers - success')] ++- Project [Test variable in aggregate AS title#x] + +- OneRowRelation + + +-- !query +SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title) +-- !query analysis +Project [scalar-subquery#x [title#x] AS scalarsubquery(title)#xL] +: +- Aggregate [max(id#xL) AS max(id)#xL] +: +- Filter (id#xL < cast(outer(title#x) as bigint)) +: +- Range (0, 10, step=1, splits=None) ++- SubqueryAlias t + +- LocalRelation [title#x] + + +-- !query +SET VARIABLE title = 'Test qualifiers - fail' +-- !query analysis +SetVariable [variablereference(system.session.title='Test variable in aggregate')] +- Project [Test qualifiers - fail AS title#x] +- OneRowRelation @@ -1881,10 +1900,10 @@ Project [var1#x AS 2#x] SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1) -- !query analysis Project [c1#x AS 2#x] -+- LateralJoin lateral-subquery#x [], Inner ++- LateralJoin lateral-subquery#x [var1#x], Inner : +- SubqueryAlias TT : +- Project [var1#x AS c1#x] - : +- Project [variablereference(system.session.var1=1) AS var1#x] + : +- Project [outer(var1#x)] : +- OneRowRelation +- SubqueryAlias T +- LocalRelation [var1#x] diff 
--git a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql index 4992453603c..53149a5e37b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql @@ -80,6 +80,9 @@ DECLARE OR REPLACE VARIABLE var1 INT; DROP TEMPORARY VARIABLE sysTem.sesSion.vAr1; DROP TEMPORARY VARIABLE var1; +SET VARIABLE title = 'Test variable in aggregate'; +SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title); + SET VARIABLE title = 'Test qualifiers - fail'; DECLARE OR REPLACE VARIABLE builtin.var1 INT; DECLARE OR REPLACE VARIABLE system.sesion.var1 INT; diff --git a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out index b3146e645c5..0297a8a11a9 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out @@ -544,6 +544,23 @@ org.apache.spark.sql.AnalysisException } +-- !query +SET VARIABLE title = 'Test variable in aggregate' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title) +-- !query schema +struct<scalarsubquery(title):bigint> +-- !query output +0 +1 + + -- !query SET VARIABLE title = 'Test qualifiers - fail' -- !query schema @@ -2058,7 +2075,7 @@ SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1) -- !query schema struct<2:int> -- !query output -1 +2 -- !query --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org