This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 61e25e1cdcbb [SPARK-47071][SQL] Inline With expression if it contains special expression 61e25e1cdcbb is described below commit 61e25e1cdcbb867fca264fa444d30b20e27c5a00 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Fri Feb 16 00:28:29 2024 -0800 [SPARK-47071][SQL] Inline With expression if it contains special expression ### What changes were proposed in this pull request? This is a bug fix for the With expression. If the common expression contains special expression like aggregate expression, we cannot pull it out and put it in Project. We have to inline it. ### Why are the changes needed? bug fix. ### Does this PR introduce _any_ user-facing change? a failed query can run after this fix, but this bug is not released yet. ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #45134 from cloud-fan/with. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../catalyst/optimizer/RewriteWithExpression.scala | 23 ++++++++++++++++------ .../sql-compatibility-functions.sql.out | 7 +++++++ .../inputs/sql-compatibility-functions.sql | 3 +++ .../results/sql-compatibility-functions.sql.out | 8 ++++++++ 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index cf2c77069a19..342c7ad09574 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} @@ -41,10 +41,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] { rewriteWithExprAndInputPlans(expr, inputPlans) } newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq) - if (p.output == newPlan.output) { - newPlan - } else { + // Since we add extra Projects with extra columns to pre-evaluate the common expressions, + // the current operator may have extra columns if it inherits the output columns from its + // child, and we need to project away the extra columns to keep the plan schema unchanged. + assert(p.output.length <= newPlan.output.length) + if (p.output.length < newPlan.output.length) { + assert(p.outputSet.subsetOf(newPlan.outputSet)) Project(p.output, newPlan) + } else { + newPlan } } } @@ -85,8 +90,14 @@ object RewriteWithExpression extends Rule[LogicalPlan] { refToExpr(id) = child } else { val alias = Alias(child, s"_common_expr_$index")() - childProjections(childProjectionIndex) += alias - refToExpr(id) = alias.toAttribute + val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex)) + if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) { + // We have to inline the common expression if it cannot be put in a Project. 
+ refToExpr(id) = child + } else { + childProjections(childProjectionIndex) += alias + refToExpr(id) = alias.toAttribute + } } } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out index b713f4d50917..f80290c5ab34 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out @@ -116,3 +116,10 @@ Aggregate [nvl(st#x.col1, value)], [nvl(st#x.col1, value) AS nvl(st.col1, value) +- Project [cast(id#x as int) AS id#x, cast(st#x as struct<col1:string,col2:string>) AS st#x] +- SubqueryAlias T +- LocalRelation [id#x, st#x] + + +-- !query +SELECT nullif(SUM(id), 0) from range(5) +-- !query analysis +Aggregate [nullif(sum(id#xL), 0) AS nullif(sum(id), 0)#xL] ++- Range (0, 5, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index 1ae49c8bfc76..6c840154c618 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -22,3 +22,6 @@ SELECT string(1, 2); -- SPARK-21555: RuntimeReplaceable used in group by CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st); SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value"); + +-- aggregate function inside NULLIF +SELECT nullif(SUM(id), 0) from range(5); diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 1d3257fdaae3..0dd8c738d212 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -126,3 +126,11 @@ SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, struct<nvl(st.col1, value):string,FROM:bigint> -- !query output gamma 1 + + +-- !query +SELECT nullif(SUM(id), 0) from range(5) +-- !query schema +struct<nullif(sum(id), 0):bigint> +-- !query output +10 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org