This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01adf10405fb [SPARK-50256][SQL] Add lightweight validation to check if 
a logical plan becomes unresolved after every optimizer rule
01adf10405fb is described below

commit 01adf10405fb2d2a528f39fec362da93dd6de55e
Author: Kelvin Jiang <[email protected]>
AuthorDate: Thu Nov 7 15:58:16 2024 -0800

    [SPARK-50256][SQL] Add lightweight validation to check if a logical plan 
becomes unresolved after every optimizer rule
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new "lightweight" plan change validation that runs after 
every optimizer rule and is enabled by default. Note that this is an 
extension to the existing validation logic, which is currently enabled for 
tests but disabled by default in production. The new validation is enabled by 
default but is skipped whenever regular validation is enabled (since regular 
validation is a superset of it).
    
    Right now, the lightweight validation only checks whether a previously 
resolved plan becomes unresolved (which is a cheap O(1) lookup).
    
    ### Why are the changes needed?
    
    If a query fails somewhere in optimization or physical planning due to an 
unresolved reference, it is likely caused by a bug somewhere in optimization. 
Adding this validation pinpoints exactly where in optimization the plan 
became invalid.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This should only fail queries that would otherwise fail elsewhere 
during query compilation; it effectively moves the failure closer to its cause.
    
    ### How was this patch tested?
    
    Added unit tests to `RuleExecutorSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48787 from kelvinjian-db/SPARK-50256-lightweight-validation.
    
    Authored-by: Kelvin Jiang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  8 ++-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   | 15 ++++-
 .../spark/sql/catalyst/rules/RuleExecutor.scala    | 26 ++++++--
 .../org/apache/spark/sql/internal/SQLConf.scala    | 10 ++-
 .../sql/catalyst/trees/RuleExecutorSuite.scala     | 74 +++++++++++++++++++++-
 .../sql/execution/adaptive/AQEOptimizer.scala      |  8 ++-
 6 files changed, 131 insertions(+), 10 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 76a0e90f5eb2..9a2aa82c25d5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -52,7 +52,13 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected def validatePlanChanges(
       previousPlan: LogicalPlan,
       currentPlan: LogicalPlan): Option[String] = {
-    LogicalPlanIntegrity.validateOptimizedPlan(previousPlan, currentPlan)
+    LogicalPlanIntegrity.validateOptimizedPlan(previousPlan, currentPlan, 
lightweight = false)
+  }
+
+  override protected def validatePlanChangesLightweight(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Option[String] = {
+    LogicalPlanIntegrity.validateOptimizedPlan(previousPlan, currentPlan, 
lightweight = true)
   }
 
   override protected val excludedOnceBatches: Set[String] =
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 556319920544..c236f7cf08e8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -433,10 +433,23 @@ object LogicalPlanIntegrity {
    * - has globally-unique attribute IDs
    * - has the same result schema as the previous plan
    * - has no dangling attribute references
+   * If `lightweight` is true, we only run the first check above.
    */
   def validateOptimizedPlan(
       previousPlan: LogicalPlan,
-      currentPlan: LogicalPlan): Option[String] = {
+      currentPlan: LogicalPlan,
+      lightweight: Boolean): Option[String] = {
+    // Lightweight validation logic. If `lightweight` is true, we only run 
this validation.
+    if (lightweight) {
+      val validation = if (previousPlan.resolved && !currentPlan.resolved) {
+        Some("The plan was previously resolved and now became unresolved.")
+      } else {
+        None
+      }
+      return validation
+    }
+
+    // Full validation logic.
     var validation = if (!currentPlan.resolved) {
       Some("The plan becomes unresolved: " + currentPlan.treeString + "\nThe 
previous plan: " +
         previousPlan.treeString)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index c8b3f224a312..935233d5c85d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -165,6 +165,15 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] 
extends Logging {
       previousPlan: TreeType,
       currentPlan: TreeType): Option[String] = None
 
+  /**
+   * Defines a validate function that validates the plan changes after the 
execution of each rule,
+   * to make sure these rules make valid changes to the plan. Since this is 
enabled by default,
+   * this should only consist of very lightweight checks.
+   */
+  protected def validatePlanChangesLightweight(
+      previousPlan: TreeType,
+      currentPlan: TreeType): Option[String] = None
+
   /**
    * Util method for checking whether a plan remains the same if re-optimized.
    */
@@ -198,9 +207,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] 
extends Logging {
     val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get
     val beforeMetrics = RuleExecutor.getCurrentMetrics()
 
-    val enableValidation = SQLConf.get.getConf(SQLConf.PLAN_CHANGE_VALIDATION)
+    val fullValidation = SQLConf.get.getConf(SQLConf.PLAN_CHANGE_VALIDATION)
+    lazy val lightweightValidation = 
SQLConf.get.getConf(SQLConf.LIGHTWEIGHT_PLAN_CHANGE_VALIDATION)
     // Validate the initial input.
-    if (Utils.isTesting || enableValidation) {
+    if (fullValidation) {
       validatePlanChanges(plan, plan) match {
         case Some(msg) =>
           val ruleExecutorName = this.getClass.getName.stripSuffix("$")
@@ -218,7 +228,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] 
extends Logging {
       var lastPlan = curPlan
       var continue = true
 
-      // Run until fix point (or the max number of iterations as specified in 
the strategy.
+      // Run until fix point or the max number of iterations as specified in 
the strategy.
       while (continue) {
         curPlan = batch.rules.foldLeft(curPlan) {
           case (plan, rule) =>
@@ -232,8 +242,14 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] 
extends Logging {
               queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, 
runTime)
               planChangeLogger.logRule(rule.ruleName, plan, result)
               // Run the plan changes validation after each rule.
-              if (Utils.isTesting || enableValidation) {
-                validatePlanChanges(plan, result) match {
+              if (fullValidation || lightweightValidation) {
+                // Only run the lightweight version of validation if full 
validation is disabled.
+                val validationResult = if (fullValidation) {
+                  validatePlanChanges(plan, result)
+                } else {
+                  validatePlanChangesLightweight(plan, result)
+                }
+                validationResult match {
                   case Some(msg) =>
                     throw new SparkException(
                       errorClass = "PLAN_VALIDATION_FAILED_RULE_IN_BATCH",
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 82e58b360488..d17ab656fe6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -326,7 +326,15 @@ object SQLConf {
       "catalyst rules, to make sure every rule returns a valid plan")
     .version("3.4.0")
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(Utils.isTesting)
+
+  val LIGHTWEIGHT_PLAN_CHANGE_VALIDATION = 
buildConf("spark.sql.lightweightPlanChangeValidation")
+    .internal()
+    .doc(s"Similar to ${PLAN_CHANGE_VALIDATION.key}, this validates plan 
changes and runs after " +
+      s"every rule, however it is enabled by default and so it should be 
lightweight.")
+    .version("4.0.0")
+    .booleanConf
+    .createWithDefault(true)
 
   val ALLOW_NAMED_FUNCTION_ARGUMENTS = 
buildConf("spark.sql.allowNamedFunctionArguments")
     .doc("If true, Spark will turn on support for named parameters for all 
functions that has" +
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index df2a4db6bb15..ae1941021044 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -18,10 +18,12 @@
 package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, 
Literal}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.internal.SQLConf
 
-class RuleExecutorSuite extends SparkFunSuite {
+class RuleExecutorSuite extends SparkFunSuite with SQLConfHelper {
   object DecrementLiterals extends Rule[Expression] {
     def apply(e: Expression): Expression = e transform {
       case IntegerLiteral(i) if i > 0 => Literal(i - 1)
@@ -113,6 +115,76 @@ class RuleExecutorSuite extends SparkFunSuite {
     assert(e.getMessage.contains("not positive integer"))
   }
 
+  private object OptimizerWithLightweightValidation extends 
RuleExecutor[Expression] {
+    override protected def validatePlanChanges(
+        previousPlan: Expression,
+        currentPlan: Expression): Option[String] = {
+      (previousPlan, currentPlan) match {
+        case (IntegerLiteral(i), IntegerLiteral(j)) if i == j => None
+        case _ => Some("value changed")
+      }
+    }
+    override protected def validatePlanChangesLightweight(
+        previousPlan: Expression,
+        currentPlan: Expression): Option[String] = previousPlan match {
+      case IntegerLiteral(i) if i < 0 => None
+      case _ => Some("input is non-negative")
+    }
+    override val batches: Seq[Batch] = Batch("once", FixedPoint(1), 
DecrementLiterals) :: Nil
+  }
+
+  test("lightweight optimizer validation disabled") {
+    withSQLConf(SQLConf.LIGHTWEIGHT_PLAN_CHANGE_VALIDATION.key -> "false") {
+      // Test when full plan validation is both enabled and disabled.
+      Seq("true", "false").foreach { fullValidation =>
+        withSQLConf(SQLConf.PLAN_CHANGE_VALIDATION.key -> fullValidation) {
+          // Input passes validation
+          assert(OptimizerWithLightweightValidation.execute(Literal(0)) === 
Literal(0))
+
+          // Input does not pass validation
+          if (fullValidation == "false") {
+            // no validation runs
+            assert(OptimizerWithLightweightValidation.execute(Literal(1)) === 
Literal(0))
+          } else {
+            // full validation runs, taking the place of lightweight validation
+            val e = intercept[SparkException] {
+              OptimizerWithLightweightValidation.execute(Literal(1))
+            }
+            val ruleName = DecrementLiterals.ruleName
+            assert(e.getMessage.contains(s"Rule $ruleName in batch once 
generated an invalid plan"))
+            assert(e.getMessage.contains("value changed"))
+          }
+        }
+      }
+    }
+  }
+
+  test("lightweight optimizer validation enabled") {
+    withSQLConf(SQLConf.LIGHTWEIGHT_PLAN_CHANGE_VALIDATION.key -> "true") {
+      // Test when full plan validation is both enabled and disabled.
+      Seq("true", "false").foreach { fullValidation =>
+        withSQLConf(SQLConf.PLAN_CHANGE_VALIDATION.key -> fullValidation) {
+          // Input passes validation
+          assert(OptimizerWithLightweightValidation.execute(Literal(0)) === 
Literal(0))
+
+          // Input does not pass validation
+          val e = intercept[SparkException] {
+            OptimizerWithLightweightValidation.execute(Literal(1))
+          }
+          val ruleName = DecrementLiterals.ruleName
+          assert(e.getMessage.contains(s"Rule $ruleName in batch once 
generated an invalid plan"))
+          if (fullValidation == "false") {
+            // only lightweight validation runs
+            assert(e.getMessage.contains("input is non-negative"))
+          } else {
+            // full validation runs, taking the place of lightweight validation
+            assert(e.getMessage.contains("value changed"))
+          }
+        }
+      }
+    }
+  }
+
   test("SPARK-27243: dumpTimeSpent when no rule has run") {
     RuleExecutor.resetMetrics()
     // This should not throw an exception
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index 014d23f2f410..0f1743eeaacf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -74,6 +74,12 @@ class AQEOptimizer(conf: SQLConf, 
extendedRuntimeOptimizerRules: Seq[Rule[Logica
   override protected def validatePlanChanges(
       previousPlan: LogicalPlan,
       currentPlan: LogicalPlan): Option[String] = {
-    LogicalPlanIntegrity.validateOptimizedPlan(previousPlan, currentPlan)
+    LogicalPlanIntegrity.validateOptimizedPlan(previousPlan, currentPlan, 
lightweight = false)
+  }
+
+  override protected def validatePlanChangesLightweight(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Option[String] = {
+    LogicalPlanIntegrity.validateOptimizedPlan(previousPlan, currentPlan, 
lightweight = true)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to