This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3feddec3d9c [SPARK-40862][SQL] Support non-aggregated subqueries in RewriteCorrelatedScalarSubquery 3feddec3d9c is described below commit 3feddec3d9c0b2bd44610b20c9448445a6d761d3 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Fri Oct 28 12:25:28 2022 +0800 [SPARK-40862][SQL] Support non-aggregated subqueries in RewriteCorrelatedScalarSubquery ### What changes were proposed in this pull request? This PR updates the `splitSubquery` in `RewriteCorrelatedScalarSubquery` to support non-aggregated one-row subqueries. In CheckAnalysis, we allow three types of correlated scalar subquery patterns: 1. SubqueryAlias/Project + Aggregate 2. SubqueryAlias/Project + Filter + Aggregate 3. SubqueryAlias/Project + LogicalPlan (maxRows <= 1) https://github.com/apache/spark/blob/748fa2792e488a6b923b32e2898d9bb6e16fb4ca/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L851-L856 We should support the third case in `splitSubquery` to avoid `Unexpected operator` exceptions. ### Why are the changes needed? To fix an issue with correlated subquery rewrite. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New unit tests. Closes #38336 from allisonwang-db/spark-40862-split-subquery. 
Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/optimizer/subquery.scala | 142 +++++++++++---------- .../scala/org/apache/spark/sql/SubquerySuite.scala | 17 +++ 2 files changed, 95 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6665d885554..3c995573d53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -509,19 +509,21 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe /** * Split the plan for a scalar subquery into the parts above the innermost query block * (first part of returned value), the HAVING clause of the innermost query block - * (optional second part) and the parts below the HAVING CLAUSE (third part). + * (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part). + * When the third part is empty, it means the subquery is a non-aggregated single-row subquery. 
*/ - private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { + private def splitSubquery( + plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = { val topPart = ArrayBuffer.empty[LogicalPlan] var bottomPart: LogicalPlan = plan while (true) { bottomPart match { case havingPart @ Filter(_, aggPart: Aggregate) => - return (topPart.toSeq, Option(havingPart), aggPart) + return (topPart.toSeq, Option(havingPart), Some(aggPart)) case aggPart: Aggregate => // No HAVING clause - return (topPart.toSeq, None, aggPart) + return (topPart.toSeq, None, Some(aggPart)) case p @ Project(_, child) => topPart += p @@ -531,6 +533,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe topPart += s bottomPart = child + case p: LogicalPlan if p.maxRows.exists(_ <= 1) => + // Non-aggregated one row subquery. + return (topPart.toSeq, None, None) + case Filter(_, op) => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op, " below filter") @@ -561,72 +567,80 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) + lazy val planWithoutCountBug = Project( + currentChild.output :+ origOutput, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) + if (resultWithZeroTups.isEmpty) { // CASE 1: Subquery guaranteed not to have the COUNT bug - Project( - currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) + planWithoutCountBug } else { - // Subquery might have the COUNT bug. Add appropriate corrections. val (topPart, havingNode, aggNode) = splitSubquery(query) - - // The next two cases add a leading column to the outer join input to make it - // possible to distinguish between the case when no tuples join and the case - // when the tuple that joins contains null values. 
- // The leading column always has the value TRUE. - val alwaysTrueExprId = NamedExpression.newExprId - val alwaysTrueExpr = Alias(Literal.TrueLiteral, - ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) - val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, - BooleanType)(exprId = alwaysTrueExprId) - - val aggValRef = query.output.head - - if (havingNode.isEmpty) { - // CASE 2: Subquery with no HAVING clause - val subqueryResultExpr = - Alias(If(IsNull(alwaysTrueRef), - resultWithZeroTups.get, - aggValRef), origOutput.name)() - subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute)) - Project( - currentChild.output :+ subqueryResultExpr, - Join(currentChild, - Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) - + if (aggNode.isEmpty) { + // SPARK-40862: When the aggregate node is empty, it means the subquery produces + // at most one row and it is not subject to the COUNT bug. + planWithoutCountBug } else { - // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. - // Need to modify any operators below the join to pass through all columns - // referenced in the HAVING clause. - var subqueryRoot: UnaryNode = aggNode - val havingInputs: Seq[NamedExpression] = aggNode.output - - topPart.reverse.foreach { - case Project(projList, _) => - subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s @ SubqueryAlias(alias, _) => - subqueryRoot = SubqueryAlias(alias, subqueryRoot) - case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op) + // Subquery might have the COUNT bug. Add appropriate corrections. + val aggregate = aggNode.get + + // The next two cases add a leading column to the outer join input to make it + // possible to distinguish between the case when no tuples join and the case + // when the tuple that joins contains null values. + // The leading column always has the value TRUE. 
+ val alwaysTrueExprId = NamedExpression.newExprId + val alwaysTrueExpr = Alias(Literal.TrueLiteral, + ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) + val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, + BooleanType)(exprId = alwaysTrueExprId) + + val aggValRef = query.output.head + + if (havingNode.isEmpty) { + // CASE 2: Subquery with no HAVING clause + val subqueryResultExpr = + Alias(If(IsNull(alwaysTrueRef), + resultWithZeroTups.get, + aggValRef), origOutput.name)() + subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute)) + Project( + currentChild.output :+ subqueryResultExpr, + Join(currentChild, + Project(query.output :+ alwaysTrueExpr, query), + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) + + } else { + // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. + // Need to modify any operators below the join to pass through all columns + // referenced in the HAVING clause. + var subqueryRoot: UnaryNode = aggregate + val havingInputs: Seq[NamedExpression] = aggregate.output + + topPart.reverse.foreach { + case Project(projList, _) => + subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) + case s@SubqueryAlias(alias, _) => + subqueryRoot = SubqueryAlias(alias, subqueryRoot) + case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op) + } + + // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups + // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>) + // ELSE (aggregate value) END AS (original column name) + val caseExpr = Alias(CaseWhen(Seq( + (IsNull(alwaysTrueRef), resultWithZeroTups.get), + (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), + aggValRef), + origOutput.name)() + + subqueryAttrMapping += ((origOutput, caseExpr.toAttribute)) + + Project( + currentChild.output :+ caseExpr, + Join(currentChild, + Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), + LeftOuter, conditions.reduceOption(And), 
JoinHint.NONE)) } - - // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups - // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>) - // ELSE (aggregate value) END AS (original column name) - val caseExpr = Alias(CaseWhen(Seq( - (IsNull(alwaysTrueRef), resultWithZeroTups.get), - (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), - aggValRef), - origOutput.name)() - - subqueryAttrMapping += ((origOutput, caseExpr.toAttribute)) - - Project( - currentChild.output :+ caseExpr, - Join(currentChild, - Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) - } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4b586356367..7b67648d475 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2491,4 +2491,21 @@ class SubquerySuite extends QueryTest Row("a")) } } + + test("SPARK-40862: correlated one-row subquery with non-deterministic expressions") { + import org.apache.spark.sql.functions.udf + withTempView("t1") { + sql("CREATE TEMP VIEW t1 AS SELECT ARRAY('a', 'b') a") + val func = udf(() => "a") + spark.udf.register("func", func.asNondeterministic()) + checkAnswer(sql( + """ + |SELECT ( + | SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] || str AS sorted + | FROM (SELECT MAP('a', 1, 'b', 2) rank, func() AS str) + |) FROM t1 + |""".stripMargin), + Row("aa")) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org