This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.2 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push: new a5ecf2a [SPARK-36352][SQL] Spark should check result plan's output schema name a5ecf2a is described below commit a5ecf2a490727fec97790b149f59bdc498b445be Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Mon Aug 9 16:47:56 2021 +0800 [SPARK-36352][SQL] Spark should check result plan's output schema name ### What changes were proposed in this pull request? Spark should check that the result plan's output schema, including its column names, matches the original plan's output schema. ### Why are the changes needed? In the current code, some optimizer rules may change a plan's output schema without this being detected: the code always uses semantic equality to check the output, and semantic equality does not catch a change in column names (for example, a change in letter case). For example, for SchemaPruning, if we have the plan ``` Project[a, B] |--Scan[A, b, c] ``` the original output schema is `a, B`; after SchemaPruning, it becomes ``` Project[A, b] |--Scan[A, b] ``` This changes the plan's schema. When we use CTAS, the table schema is the same as the query plan's output, so the changed schema is no longer consistent with the original SQL. Therefore, we need to check the final result plan's schema against the original plan's schema. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UTs, plus a new test case in SchemaPruningSuite. Closes #33583 from AngersZhuuuu/SPARK-36352. 
Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit e051a540a10cdda42dc86a6195c0357aea8900e4) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++-- .../spark/sql/catalyst/optimizer/Optimizer.scala | 22 +++++++++------------- .../spark/sql/catalyst/rules/RuleExecutor.scala | 6 +++--- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../org/apache/spark/sql/util/SchemaUtils.scala | 11 +++++++++++ .../sql/catalyst/trees/RuleExecutorSuite.scala | 8 ++++++-- .../sql/execution/adaptive/AQEOptimizer.scala | 12 ++++++++---- .../execution/datasources/DataSourceStrategy.scala | 2 +- .../sql/execution/datasources/SchemaPruning.scala | 10 ++++++---- .../datasources/v2/V2ScanRelationPushDown.scala | 3 ++- .../execution/datasources/SchemaPruningSuite.scala | 12 ++++++++++++ 11 files changed, 63 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 963b42b..b6228d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -174,8 +174,10 @@ class Analyzer(override val catalogManager: CatalogManager) private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog - override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan) + override protected def isPlanIntegral( + previousPlan: LogicalPlan, + currentPlan: LogicalPlan): Boolean = { + !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) } override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 369fb51..40b4c01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils._ import org.apache.spark.util.Utils /** @@ -46,10 +47,14 @@ abstract class Optimizer(catalogManager: CatalogManager) // - is still resolved // - only host special expressions in supported operators // - has globally-unique attribute IDs - override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || (plan.resolved && - plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && - LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) + // - optimized plan have same schema with previous plan. + override protected def isPlanIntegral( + previousPlan: LogicalPlan, + currentPlan: LogicalPlan): Boolean = { + !Utils.isTesting || (currentPlan.resolved && + currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) && + DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) } override protected val excludedOnceBatches: Set[String] = @@ -515,15 +520,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { * Remove no-op operators from the query plan that do not make any modifications. 
*/ object RemoveNoopOperators extends Rule[LogicalPlan] { - def restoreOriginalOutputNames( - projectList: Seq[NamedExpression], - originalNames: Seq[String]): Seq[NamedExpression] = { - projectList.zip(originalNames).map { - case (attr: Attribute, name) => attr.withName(name) - case (alias: Alias, name) => alias.withName(name) - case (other, _) => other - } - } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsAnyPattern(PROJECT, WINDOW), ruleId) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 17d7794..759eba6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -156,7 +156,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * `Optimizer`, so we can catch rules that return invalid plans. The check function returns * `false` if the given plan doesn't pass the structural integrity check. */ - protected def isPlanIntegral(plan: TreeType): Boolean = true + protected def isPlanIntegral(previousPlan: TreeType, currentPlan: TreeType): Boolean = true /** * Util method for checking whether a plan remains the same if re-optimized. 
@@ -192,7 +192,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val beforeMetrics = RuleExecutor.getCurrentMetrics() // Run the structural integrity checker against the initial input - if (!isPlanIntegral(plan)) { + if (!isPlanIntegral(plan, plan)) { throw QueryExecutionErrors.structuralIntegrityOfInputPlanIsBrokenInClassError( this.getClass.getName.stripSuffix("$")) } @@ -224,7 +224,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective)) // Run the structural integrity checker against the plan after each rule. - if (effective && !isPlanIntegral(result)) { + if (effective && !isPlanIntegral(plan, result)) { throw QueryExecutionErrors.structuralIntegrityIsBrokenAfterApplyingRuleError( rule.ruleName, batch.name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 585045d..ef1aeec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -292,7 +292,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. 
*/ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index da105af..63c1f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} @@ -273,6 +274,16 @@ private[spark] object SchemaUtils { field._1 } + def restoreOriginalOutputNames( + projectList: Seq[NamedExpression], + originalNames: Seq[String]): Seq[NamedExpression] = { + projectList.zip(originalNames).map { + case (attr: Attribute, name) => attr.withName(name) + case (alias: Alias, name) => alias.withName(name) + case (other, _) => other + } + } + /** * @param str The string to be escaped. * @return The escaped string. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 25352e2..b14686b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -73,7 +73,9 @@ class RuleExecutorSuite extends SparkFunSuite { test("structural integrity checker - verify initial input") { object WithSIChecker extends RuleExecutor[Expression] { - override protected def isPlanIntegral(expr: Expression): Boolean = expr match { + override protected def isPlanIntegral( + previousPlan: Expression, + currentPlan: Expression): Boolean = currentPlan match { case IntegerLiteral(_) => true case _ => false } @@ -91,7 +93,9 @@ class RuleExecutorSuite extends SparkFunSuite { test("structural integrity checker - verify rule execution result") { object WithSICheckerForPositiveLiteral extends RuleExecutor[Expression] { - override protected def isPlanIntegral(expr: Expression): Boolean = expr match { + override protected def isPlanIntegral( + previousPlan: Expression, + currentPlan: Expression): Boolean = currentPlan match { case IntegerLiteral(i) if i > 0 => true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index 0767039..f8cba90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper} import org.apache.spark.sql.catalyst.rules.RuleExecutor import 
org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -64,9 +65,12 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { } } - override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || (plan.resolved && - plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && - LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) + override protected def isPlanIntegral( + previousPlan: LogicalPlan, + currentPlan: LogicalPlan): Boolean = { + !Utils.isTesting || (currentPlan.resolved && + currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) && + DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 81ecb2c..11d23f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -464,7 +464,7 @@ object DataSourceStrategy */ protected[sql] def normalizeExprs( exprs: Seq[Expression], - attributes: Seq[AttributeReference]): Seq[Expression] = { + attributes: Seq[Attribute]): Seq[Expression] = { exprs.map { e => e transform { case a: AttributeReference => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index a197445..4f331c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -24,6 
+24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.util.SchemaUtils._ /** * Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation. @@ -82,8 +83,8 @@ object SchemaPruning extends Rule[LogicalPlan] { val prunedRelation = leafNodeBuilder(prunedDataSchema) val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) - Some(buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation, - projectionOverSchema)) + Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, + prunedRelation, projectionOverSchema)) } else { None } @@ -125,6 +126,7 @@ object SchemaPruning extends Rule[LogicalPlan] { */ private def buildNewProjection( projects: Seq[NamedExpression], + normalizedProjects: Seq[NamedExpression], filters: Seq[Expression], leafNode: LeafNode, projectionOverSchema: ProjectionOverSchema): Project = { @@ -143,7 +145,7 @@ object SchemaPruning extends Rule[LogicalPlan] { // Construct the new projections of our Project by // rewriting the original projections - val newProjects = projects.map(_.transformDown { + val newProjects = normalizedProjects.map(_.transformDown { case projectionOverSchema(expr) => expr }).map { case expr: NamedExpression => expr } @@ -151,7 +153,7 @@ object SchemaPruning extends Rule[LogicalPlan] { logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") } - Project(newProjects, projectionChild) + Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 
d05519b..ab5a0fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownA import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ @@ -207,7 +208,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val newProjects = normalizedProjects .map(projectionFunc) .asInstanceOf[Seq[NamedExpression]] - Project(newProjects, withFilter) + Project(restoreOriginalOutputNames(newProjects, project.map(_.name)), withFilter) } else { withFilter } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index ac5c289..395ee6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -870,4 +870,16 @@ abstract class SchemaPruningSuite checkAnswer(query, Row(1) :: Row(2) :: Nil) } } + + test("SPARK-36352: Spark should check result plan's output schema name") { + withMixedCaseData { + val query = sql("select cOL1, cOl2.B from mixedcase") + assert(query.queryExecution.executedPlan.schema.catalogString == + "struct<cOL1:string,B:int>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional 
commands, e-mail: commits-h...@spark.apache.org