This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6d86d41b53c [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn 6d86d41b53c is described below commit 6d86d41b53c338a1897b27668eb22623383828bb Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Thu Jun 16 14:25:17 2022 +0800 [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/35404 and https://github.com/apache/spark/pull/36746 , to simplify the error handling of `TempResolvedColumn`. The idea is: 1. The rule `ResolveAggregationFunctions` in the main resolution batch creates `TempResolvedColumn` and only removes it if the aggregate expression is fully resolved. It either strips `TempResolvedColumn` if it's inside an aggregate function or group expression, or restores `TempResolvedColumn` to `UnresolvedAttribute` otherwise, hoping other rules can resolve it. 2. The rule `RemoveTempResolvedColumn` in a later batch can still hit `TempResolvedColumn` if the aggregate expression is unresolved (due to input type mismatch for example, e.g. `avg(bool_col)`, `date_add(int_group_col, 1)`). At this stage, there is no way to restore `TempResolvedColumn` to `UnresolvedAttribute` and resolve it differently. The query will fail and we should blindly strip `TempResolvedColumn` to provide a better error message. ### Why are the changes needed? code cleanup ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #36809 from cloud-fan/error. 
Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 65 ++++++++++++---------- .../sql/catalyst/analysis/CheckAnalysis.scala | 17 +----- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../src/test/resources/sql-tests/inputs/having.sql | 3 + .../resources/sql-tests/results/having.sql.out | 9 +++ 5 files changed, 49 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 446bc46d9b1..9fe9d490539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,7 +28,6 @@ import scala.util.{Failure, Random, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.DATA_TYPE_MISMATCH_ERROR_MESSAGE import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _} @@ -2647,10 +2646,6 @@ class Analyzer(override val catalogManager: CatalogManager) (extraAggExprs.toSeq, transformed) } - private def trimTempResolvedField(input: Expression): Expression = input.transform { - case t: TempResolvedColumn => t.child - } - private def buildAggExprList( expr: Expression, agg: Aggregate, @@ -2666,12 +2661,12 @@ class Analyzer(override val catalogManager: CatalogManager) } else { expr match { case ae: AggregateExpression => - val cleaned = trimTempResolvedField(ae) + val cleaned = RemoveTempResolvedColumn.trimTempResolvedColumn(ae) val alias = Alias(cleaned, cleaned.toString)() aggExprList += alias 
alias.toAttribute case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) => - trimTempResolvedField(grouping) match { + RemoveTempResolvedColumn.trimTempResolvedColumn(grouping) match { case ne: NamedExpression => aggExprList += ne ne.toAttribute @@ -2683,7 +2678,7 @@ class Analyzer(override val catalogManager: CatalogManager) case t: TempResolvedColumn => // Undo the resolution as this column is neither inside aggregate functions nor a // grouping column. It shouldn't be resolved with `agg.child.output`. - CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts)) + RemoveTempResolvedColumn.restoreTempResolvedColumn(t) case other => other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList))) } @@ -4345,32 +4340,42 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } /** - * Removes all [[TempResolvedColumn]]s in the query plan. This is the last resort, in case some - * rules in the main resolution batch miss to remove [[TempResolvedColumn]]s. We should run this - * rule right after the main resolution batch. + * The rule `ResolveAggregationFunctions` in the main resolution batch creates + * [[TempResolvedColumn]] in filter conditions and sort expressions to hold the temporarily resolved + * column with `agg.child`. When filter conditions or sort expressions are resolved, + * `ResolveAggregationFunctions` will replace [[TempResolvedColumn]] with [[AttributeReference]] if + * it's inside aggregate functions or group expressions, or with [[UnresolvedAttribute]] otherwise, + * hoping other rules can resolve it. + * + * This rule runs after the main resolution batch, and can still hit [[TempResolvedColumn]] if + * filter conditions or sort expressions are not resolved. When this happens, there is no point in + * turning [[TempResolvedColumn]] into [[UnresolvedAttribute]], as we can't resolve the column + * differently, and the query will fail. 
This rule strips all [[TempResolvedColumn]]s in Filter/Sort and + * turns them to [[AttributeReference]] so that the error message can tell users why the filter + * conditions or sort expressions were not resolved. */ object RemoveTempResolvedColumn extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - plan.foreachUp { - // HAVING clause will be resolved as a Filter. When having func(column with wrong data type), - // the column could be wrapped by a TempResolvedColumn, e.g. mean(tempresolvedcolumn(t.c)). - // Because TempResolvedColumn can still preserve column data type, here is a chance to check - // if the data type matches with the required data type of the function. We can throw an error - // when data types mismatches. - case operator: Filter => - operator.expressions.foreach(_.foreachUp { - case e: Expression if e.childrenResolved && e.checkInputDataTypes().isFailure => - e.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckFailure(message) => - e.setTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE, message) - } - case _ => - }) - case _ => + plan.resolveOperatorsUp { + case f @ Filter(cond, agg: Aggregate) if agg.resolved => + withOrigin(f.origin)(f.copy(condition = trimTempResolvedColumn(cond))) + case s @ Sort(sortOrder, _, agg: Aggregate) if agg.resolved => + val newSortOrder = sortOrder.map { order => + trimTempResolvedColumn(order).asInstanceOf[SortOrder] + } + withOrigin(s.origin)(s.copy(order = newSortOrder)) + case other => other.transformExpressionsUp { + // This should not happen. We restore TempResolvedColumn to UnresolvedAttribute to be safe. 
+ case t: TempResolvedColumn => restoreTempResolvedColumn(t) + } } + } - plan.resolveExpressions { - case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts) - } + def trimTempResolvedColumn(input: Expression): Expression = input.transform { + case t: TempResolvedColumn => t.child + } + + def restoreTempResolvedColumn(t: TempResolvedColumn): Expression = { + CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 45e70bdcb6c..416e3a2b834 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -50,8 +50,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError") - val DATA_TYPE_MISMATCH_ERROR_MESSAGE = TreeNodeTag[String]("dataTypeMismatchError") - protected def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } @@ -176,20 +174,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } } - val expressions = getAllExpressions(operator) - - expressions.foreach(_.foreachUp { - case e: Expression => - e.getTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE) match { - case Some(message) => - e.failAnalysis(s"cannot resolve '${e.sql}' due to data type mismatch: $message" + - extraHintForAnsiTypeCoercionExpression(operator)) - case _ => - } - case _ => - }) - - expressions.foreach(_.foreachUp { + getAllExpressions(operator).foreach(_.foreachUp { case a: Attribute if !a.resolved => val missingCol = a.sql val candidates = operator.inputSet.toSeq.map(_.qualifiedName) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a6e952fd865..5c3f4b5f558 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1172,7 +1172,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t |GROUP BY t.c, t.d |HAVING ${func}(c) > 0d""".stripMargin), - Seq(s"cannot resolve '$func(c)' due to data type mismatch"), + Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"), false) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 2799b1a94d0..056b99e363d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -11,6 +11,9 @@ SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2; -- having condition contains grouping column SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; +-- invalid having condition contains grouping column +SELECT count(k) FROM hav GROUP BY v HAVING v = array(1); + -- SPARK-11032: resolve having correctly SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index fff470b3d81..e9e24562d1b 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -29,6 +29,15 @@ struct<count(k):bigint> 1 +-- !query +SELECT count(k) FROM hav GROUP BY v HAVING v = array(1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve '(hav.v = array(1))' due to data type mismatch: differing types in '(hav.v = array(1))' (int and array<int>).; line 1 pos 43 + + -- !query SELECT MIN(t.v) FROM (SELECT * FROM hav 
WHERE v > 0) t HAVING(COUNT(1) > 0) -- !query schema --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org