This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 172d68e3be1b [SPARK-55647][SQL] Improve `ConstantPropagation` for collated `AttributeReference`s
172d68e3be1b is described below

commit 172d68e3be1b95fe47ed7b75f1de697006d2b304
Author: ilicmarkodb <[email protected]>
AuthorDate: Tue Mar 3 23:32:41 2026 +0800

    [SPARK-55647][SQL] Improve `ConstantPropagation` for collated `AttributeReference`s
    
    ### What changes were proposed in this pull request?
    The previous change (https://github.com/apache/spark/pull/54435) completely
    blocked `ConstantPropagation` for non-binary-stable types, which can cause
    performance regressions. This PR improves the rule so that collated
    `AttributeReference`s are still replaced when it is safe to do so.
    
    ### Why are the changes needed?
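    Performance improvement: constant propagation is re-enabled for collated
    attributes in the cases where the substitution is safe.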
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #54515 from ilicmarkodb/improve_ConstantPropagation.
    
    Authored-by: ilicmarkodb <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/optimizer/expressions.scala |  49 +++--
 .../spark/sql/collation/CollationSuite.scala       | 230 ++++++++++++++++++++-
 2 files changed, 260 insertions(+), 19 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index e406c51e7f8a..53a5e0f7eccf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -218,28 +218,51 @@ object ConstantPropagation extends Rule[LogicalPlan] {
   // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable 
then the enclosing
   // NOT prevents us to do the substitution as NOT flips the context 
(`nullIsFalse`) of what a
   // null result of the enclosed expression means.
-  //
-  // Also, we shouldn't replace attributes with non-binary-stable data types, 
since this can lead
-  // to incorrect results. For example:
-  // `CREATE TABLE t (c STRING COLLATE UTF8_LCASE);`
-  // `INSERT INTO t VALUES ('HELLO'), ('hello');`
-  // `SELECT * FROM t WHERE c = 'hello' AND c = 'HELLO' COLLATE UNICODE;`
-  // If we replace `c` with `'hello'`, we get `'hello' = 'HELLO' COLLATE 
UNICODE` for the right
-  // condition, which is false, while the original `c = 'HELLO' COLLATE 
UNICODE` is true for
-  // 'HELLO' and false for 'hello'.
   private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
-    (!ar.nullable || nullIsFalse) && isBinaryStable(ar.dataType)
+    !ar.nullable || nullIsFalse
 
   private def replaceConstants(
       condition: Expression,
       equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): 
Expression = {
     val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, 
_)) => attr -> lit })
     val predicates = equalityPredicates.values.map(_._2).toSet
-    condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
-      case b: BinaryComparison if !predicates.contains(b) => b transform {
-        case a: AttributeReference => constantsMap.getOrElse(a, a)
+    def replaceInComparison(b: BinaryComparison): Expression = {
+      lazy val collationSafeReplacement = isSameCollationAttrRefComparison(b)
+      b transform {
+        case a: AttributeReference
+          if isBinaryStable(a.dataType) || collationSafeReplacement =>
+          constantsMap.getOrElse(a, a)
       }
     }
+    condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
+      case b: BinaryComparison if !predicates.contains(b) => 
replaceInComparison(b)
+    }
+  }
+
+  /**
+   * Binary-stable `AttributeReference`s can always be replaced safely. 
Non-binary-stable
+   * `AttributeReference`s (i.e., those with a non-`UTF8_BINARY` `StringType`) 
are only replaced
+   * when both sides of the comparison are `AttributeReference`s (or 
`CollationKey`-wrapped
+   * `AttributeReference`s) with the same `StringType`, preventing 
substitution inside
+   * expressions that change the effective collation (e.g., a `Cast`). For 
example, given a
+   * column `c STRING COLLATE UTF8_LCASE`:
+   *
+   *   `c = 'hello' AND c = 'HELLO' COLLATE UNICODE`
+   *
+   * `c` is added to `constantsMap`. In the right-hand comparison, `c` is 
wrapped with a
+   * `Cast` to `UNICODE`, so we don't have an `AttributeReference` vs. 
`AttributeReference`
+   * comparison and `c` is not replaced inside the `Cast`, preserving 
correctness.
+   */
+  private def isSameCollationAttrRefComparison(b: BinaryComparison): Boolean = 
{
+    (b.left, b.right) match {
+      case (AttributeReference(_, st1: StringType, _, _),
+      AttributeReference(_, st2: StringType, _, _)) =>
+        st1 == st2
+      case (CollationKey(AttributeReference(_, st1: StringType, _, _)),
+      CollationKey(AttributeReference(_, st2: StringType, _, _))) =>
+        st1 == st2
+      case _ => false
+    }
   }
 }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 734c1166c403..b446b29f7c68 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -2249,16 +2249,234 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  test("ConstantPropagation does not replace attributes with non-binary-stable 
collation") {
-    val tableName = "t1"
-    withTable(tableName) {
-      sql(s"CREATE TABLE $tableName (c STRING COLLATE UTF8_LCASE)")
-      sql(s"INSERT INTO $tableName VALUES ('hello'), ('HELLO')")
+  test("ConstantPropagation: does not replace attributes with 
non-binary-stable collation") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE 'hello' = c AND c = 'HELLO' COLLATE 
UNICODE"),
+        Row("HELLO")
+      )
+    }
+  }
+
+  test("ConstantPropagation: does not replace non-binary-stable attributes 
with EqualNullSafe") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), (NULL)")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE 'hello' <=> c AND c <=> 'HELLO' COLLATE 
UNICODE"),
+        Row("HELLO")
+      )
+    }
+  }
+
+  test("ConstantPropagation: replaces binary-stable attributes with 
contradicting predicates") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('world')")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE c = 'hello' AND c = 'world'"),
+        Seq.empty
+      )
+    }
+  }
+
+  test("ConstantPropagation: replaces binary-stable attributes across 
collation cast") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE c = 'hello' AND c COLLATE UTF8_LCASE = 
'HELLO'"),
+        Row("hello")
+      )
+    }
+  }
+
+  test("ConstantPropagation: does not replace non-binary-stable " +
+    "attributes with explicit CAST collation") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+      checkAnswer(
+        sql("""SELECT * FROM t1 WHERE c = 'hello'
+              |AND CAST(c AS STRING COLLATE UNICODE) = 'HELLO'""".stripMargin),
+        Row("HELLO")
+      )
+    }
+  }
+
+  test("ConstantPropagation: replaces non-binary-stable attributes in 
same-collation comparison") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (col1 STRING COLLATE UTF8_LCASE, col2 STRING 
COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello', 'hello'), ('HELLO', 'hello'), 
('hello', 'world')")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE col1 = 'hello' AND col1 = col2"),
+        Seq(Row("hello", "hello"), Row("HELLO", "hello"))
+      )
+    }
+  }
+
+  test("ConstantPropagation: does not replace non-binary-stable attribute " +
+    "in different-collation column comparison") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (col1 STRING COLLATE UTF8_LCASE, col2 STRING 
COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello', 'hello'), ('HELLO', 'HELLO')")
 
       checkAnswer(
-        sql(s"SELECT * FROM $tableName WHERE c = 'hello' AND c = 'HELLO' 
COLLATE UNICODE"),
+        sql("SELECT * FROM t1 WHERE col1 = 'hello' AND col1 COLLATE UNICODE = 
col2"),
+        Seq(Row("HELLO", "HELLO"), Row("hello", "hello"))
+      )
+    }
+  }
+
+  test("ConstantPropagation: attribute is not propagated from inside NOT") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE NOT(c = 'world') AND c = 'HELLO' COLLATE 
UNICODE"),
         Row("HELLO")
       )
     }
   }
+
+  test("ConstantPropagation: non-binary-stable attribute is not replaced 
inside NOT") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+      checkAnswer(
+        sql("SELECT * FROM t1 WHERE 'HELLO' = c AND NOT(c = 'HELLO' COLLATE 
UNICODE)"),
+        Row("hello")
+      )
+    }
+  }
+
+  test("ConstantPropagation: predicates do not propagate across OR branches") {
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+
+      checkAnswer(
+        sql("""SELECT * FROM t1 WHERE (c = 'hello' AND c = 'HELLO' COLLATE 
UNICODE)
+              |OR c = 'world'""".stripMargin),
+        Seq(Row("HELLO"), Row("world"))
+      )
+    }
+  }
+
+  test("ConstantPropagation: non-binary-stable join matches 
case-insensitively") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+      sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+      sql("INSERT INTO t2 VALUES ('hello')")
+
+      checkAnswer(
+        sql("SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b WHERE t1.a = 'hello'"),
+        Seq(Row("hello"), Row("HELLO"))
+      )
+    }
+  }
+
+  test("ConstantPropagation: does not replace non-binary-stable attribute " +
+    "in cross-collation join condition") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+      sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+      sql("INSERT INTO t2 VALUES ('hello')")
+
+      checkAnswer(
+        sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b COLLATE UNICODE
+              |WHERE t1.a = 'hello'""".stripMargin),
+        Row("hello")
+      )
+    }
+  }
+
+  test("ConstantPropagation: does not replace non-binary-stable attribute " +
+    "in cross-collation join filter") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+      sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+      sql("INSERT INTO t2 VALUES ('hello')")
+
+      checkAnswer(
+        sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b
+              |WHERE t1.a = 'hello' AND t1.a = 'HELLO' COLLATE 
UNICODE""".stripMargin),
+        Row("HELLO")
+      )
+    }
+  }
+
+  test("ConstantPropagation: binary-stable join correctly replaces 
attributes") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING)")
+      sql("CREATE TABLE t2 (b STRING)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+      sql("INSERT INTO t2 VALUES ('hello')")
+
+      checkAnswer(
+        sql("SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b WHERE t1.a = 'hello'"),
+        Row("hello")
+      )
+    }
+  }
+
+  test("ConstantPropagation: binary-stable join replaces attributes " +
+    "across collation cast in join condition") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING)")
+      sql("CREATE TABLE t2 (b STRING)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+      sql("INSERT INTO t2 VALUES ('hello')")
+
+      checkAnswer(
+        sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b COLLATE UTF8_LCASE
+              |WHERE t1.a = 'hello'""".stripMargin),
+        Row("hello")
+      )
+    }
+  }
+
+  test("ConstantPropagation: does not replace non-binary-stable attribute " +
+    "in null-safe join filter") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+      sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), (NULL)")
+      sql("INSERT INTO t2 VALUES ('hello'), (NULL)")
+
+      checkAnswer(
+        sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a <=> t2.b
+              |WHERE t1.a = 'hello' AND t1.a = 'HELLO' COLLATE 
UNICODE""".stripMargin),
+        Row("HELLO")
+      )
+    }
+  }
+
+  test("ConstantPropagation: non-binary-stable null-safe join condition " +
+    "matches case-insensitively") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+      sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+      sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+      sql("INSERT INTO t2 VALUES ('hello')")
+
+      checkAnswer(
+        sql("SELECT t1.a FROM t1 JOIN t2 ON t1.a <=> t2.b WHERE t1.a = 
'hello'"),
+        Seq(Row("hello"), Row("HELLO"))
+      )
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to