This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2cd98bdd32 [SPARK-46029][SQL] Escape the single quote, `_` and `%` for DS V2 pushdown
d2cd98bdd32 is described below

commit d2cd98bdd32446b4106e66eb099efd8fb47acf40
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Nov 29 01:37:35 2023 +0100

    [SPARK-46029][SQL] Escape the single quote, `_` and `%` for DS V2 pushdown
    
    ### What changes were proposed in this pull request?
    Spark supports pushing down `startsWith`, `endsWith` and `contains` to JDBC databases with DS V2 pushdown.
    But `V2ExpressionSQLBuilder` did not escape the single quote, `_` and `%`, which could cause unexpected results.
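    
    For illustration, here is a minimal Scala sketch of the escaping rule this patch applies
    (`escapeLike` is an illustrative name, not the patch's API; the patch's method is
    `escapeSpecialCharsForLikePattern`):
    
        // Escape the LIKE wildcards `_` and `%`, and the single quote, with a
        // backslash so they match literally; the generated SQL then appends
        // "ESCAPE '\'" so the database treats the backslash as the escape char.
        def escapeLike(str: String): String = str.flatMap {
          case '_'  => "\\_"
          case '%'  => "\\%"
          case '\'' => "\\'"
          case c    => c.toString
        }
    
        escapeLike("abc_")  // returns the 5-char string abc\_
        // startsWith("abc_") is now pushed down as LIKE 'abc\_%' ESCAPE '\'
        // instead of LIKE 'abc_%', where the bare `_` also matches e.g. 'abcd'.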
    
    ### Why are the changes needed?
    Without escaping, `_` and `%` act as LIKE wildcards and an embedded single quote breaks the generated SQL literal, so pushed-down filters can return wrong results.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    
    ### How was this patch tested?
    Existing test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    'No'.
    
    Closes #43801 from beliefer/SPARK-38432_followup3.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connector/util/V2ExpressionSQLBuilder.java |  35 ++++++-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      |   8 ++
 .../datasources/v2/V2PredicateSuite.scala          |   6 +-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 113 ++++++++++++++++++++-
 4 files changed, 151 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 5c2523943dd..506b2c8782e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -48,6 +48,35 @@ import org.apache.spark.sql.types.DataType;
  */
 public class V2ExpressionSQLBuilder {
 
+  /**
+   * Escape the special chars for a LIKE pattern.
+   *
+   * Note: This method adopts the escape representation used within Spark and is not bound to
+   * any JDBC dialect. A JDBC dialect should override this method if the underlying database
+   * has special chars other than _ and %.
+   */
+  protected String escapeSpecialCharsForLikePattern(String str) {
+    StringBuilder builder = new StringBuilder();
+
+    for (char c : str.toCharArray()) {
+      switch (c) {
+        case '_':
+          builder.append("\\_");
+          break;
+        case '%':
+          builder.append("\\%");
+          break;
+        case '\'':
+          builder.append("\\\'");
+          break;
+        default:
+          builder.append(c);
+      }
+    }
+
+    return builder.toString();
+  }
+
   public String build(Expression expr) {
     if (expr instanceof Literal) {
       return visitLiteral((Literal<?>) expr);
@@ -238,21 +267,21 @@ public class V2ExpressionSQLBuilder {
     // Remove quotes at the beginning and end.
     // e.g. converts "'str'" to "str".
     String value = r.substring(1, r.length() - 1);
-    return l + " LIKE '" + value + "%'";
+    return l + " LIKE '" + escapeSpecialCharsForLikePattern(value) + "%' 
ESCAPE '\\'";
   }
 
   protected String visitEndsWith(String l, String r) {
     // Remove quotes at the beginning and end.
     // e.g. converts "'str'" to "str".
     String value = r.substring(1, r.length() - 1);
-    return l + " LIKE '%" + value + "'";
+    return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "' 
ESCAPE '\\'";
   }
 
   protected String visitContains(String l, String r) {
     // Remove quotes at the beginning and end.
     // e.g. converts "'str'" to "str".
     String value = r.substring(1, r.length() - 1);
-    return l + " LIKE '%" + value + "%'";
+    return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "%' 
ESCAPE '\\'";
   }
 
   private String inputToSQL(Expression input) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 9bed6a6f873..a42fe989b15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -242,6 +242,14 @@ private[sql] object H2Dialect extends JdbcDialect {
   }
 
   class H2SQLBuilder extends JDBCSQLBuilder {
+    override def escapeSpecialCharsForLikePattern(str: String): String = {
+      str.map {
+        case '_' => "\\_"
+        case '%' => "\\%"
+        case c => c.toString
+      }.mkString
+    }
+
     override def visitAggregateFunction(
         funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
       if (isDistinct && 
distinctUnsupportedAggregateFunctions.contains(funcName)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
index d2e04eab05c..92892a58399 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -315,7 +315,7 @@ class V2PredicateSuite extends SparkFunSuite {
       Array[Expression](ref("a"), literal))
     assert(predicate1.equals(predicate2))
     assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
-    assert(predicate1.describe.equals("a LIKE 'str%'"))
+    assert(predicate1.describe.equals(raw"a LIKE 'str%' ESCAPE '\'"))
 
     val v1Filter = StringStartsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
@@ -332,7 +332,7 @@ class V2PredicateSuite extends SparkFunSuite {
       Array[Expression](ref("a"), literal))
     assert(predicate1.equals(predicate2))
     assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
-    assert(predicate1.describe.equals("a LIKE '%str'"))
+    assert(predicate1.describe.equals(raw"a LIKE '%str' ESCAPE '\'"))
 
     val v1Filter = StringEndsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
@@ -349,7 +349,7 @@ class V2PredicateSuite extends SparkFunSuite {
       Array[Expression](ref("a"), literal))
     assert(predicate1.equals(predicate2))
     assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
-    assert(predicate1.describe.equals("a LIKE '%str%'"))
+    assert(predicate1.describe.equals(raw"a LIKE '%str%' ESCAPE '\'"))
 
     val v1Filter = StringContains("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index bcb366bbdda..a81501127a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -204,6 +204,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " +
         "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
 
+      conn.prepareStatement(
+        "CREATE TABLE \"test\".\"address\" (email TEXT(32) NOT 
NULL)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " +
+        "('abc_...@gmail.com')").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " +
+        "('abc%...@gmail.com')").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " +
+        "('abc%_...@gmail.com')").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " +
+        "('abc_%...@gmail.com')").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " +
+        "('abc_''%d...@gmail.com')").executeUpdate()
+
       conn.prepareStatement("CREATE TABLE \"test\".\"binary1\" (name 
TEXT(32),b BINARY(20))")
         .executeUpdate()
       val stmt = conn.prepareStatement("INSERT INTO \"test\".\"binary1\" VALUES (?, ?)")
@@ -1118,7 +1131,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
     val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a"))
     checkFiltersRemoved(df3)
-    checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']")
+    checkPushedInfo(df3, raw"PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%' 
ESCAPE '\']")
     checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false)))
 
     val df4 = spark.table("h2.test.employee").filter($"is_manager")
@@ -1262,6 +1275,94 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df17, Seq(Row(6, "jen", 12000, 1200, true)))
   }
 
+  test("SPARK-38432: escape the single quote, _ and % for DS V2 pushdown") {
+    val df1 = 
spark.table("h2.test.address").filter($"email".startsWith("abc_"))
+    checkFiltersRemoved(df1)
+    checkPushedInfo(df1, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'abc\_%' ESCAPE '\']")
+    checkAnswer(df1,
+      Seq(Row("abc_%...@gmail.com"), Row("abc_'%d...@gmail.com"), 
Row("abc_...@gmail.com")))
+
+    val df2 = 
spark.table("h2.test.address").filter($"email".startsWith("abc%"))
+    checkFiltersRemoved(df2)
+    checkPushedInfo(df2, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'abc\%%' ESCAPE '\']")
+    checkAnswer(df2, Seq(Row("abc%_...@gmail.com"), Row("abc%...@gmail.com")))
+
+    val df3 = 
spark.table("h2.test.address").filter($"email".startsWith("abc%_"))
+    checkFiltersRemoved(df3)
+    checkPushedInfo(df3, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'abc\%\_%' ESCAPE '\']")
+    checkAnswer(df3, Seq(Row("abc%_...@gmail.com")))
+
+    val df4 = 
spark.table("h2.test.address").filter($"email".startsWith("abc_%"))
+    checkFiltersRemoved(df4)
+    checkPushedInfo(df4, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'abc\_\%%' ESCAPE '\']")
+    checkAnswer(df4, Seq(Row("abc_%...@gmail.com")))
+
+    val df5 = 
spark.table("h2.test.address").filter($"email".startsWith("abc_'%"))
+    checkFiltersRemoved(df5)
+    checkPushedInfo(df5,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\'\%%' ESCAPE 
'\']")
+    checkAnswer(df5, Seq(Row("abc_'%d...@gmail.com")))
+
+    val df6 = 
spark.table("h2.test.address").filter($"email".endsWith("_...@gmail.com"))
+    checkFiltersRemoved(df6)
+    checkPushedInfo(df6,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_...@gmail.com' 
ESCAPE '\']")
+    checkAnswer(df6, Seq(Row("abc%_...@gmail.com"), Row("abc_...@gmail.com")))
+
+    val df7 = 
spark.table("h2.test.address").filter($"email".endsWith("%d...@gmail.com"))
+    checkFiltersRemoved(df7)
+    checkPushedInfo(df7,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\%d...@gmail.com' 
ESCAPE '\']")
+    checkAnswer(df7,
+      Seq(Row("abc%...@gmail.com"), Row("abc_%...@gmail.com"), 
Row("abc_'%d...@gmail.com")))
+
+    val df8 = 
spark.table("h2.test.address").filter($"email".endsWith("%_...@gmail.com"))
+    checkFiltersRemoved(df8)
+    checkPushedInfo(df8,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\%\_...@gmail.com' 
ESCAPE '\']")
+    checkAnswer(df8, Seq(Row("abc%_...@gmail.com")))
+
+    val df9 = 
spark.table("h2.test.address").filter($"email".endsWith("_%...@gmail.com"))
+    checkFiltersRemoved(df9)
+    checkPushedInfo(df9,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_\%d...@gmail.com' 
ESCAPE '\']")
+    checkAnswer(df9, Seq(Row("abc_%...@gmail.com")))
+
+    val df10 = 
spark.table("h2.test.address").filter($"email".endsWith("_'%d...@gmail.com"))
+    checkFiltersRemoved(df10)
+    checkPushedInfo(df10,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'%\_\'\%d...@gmail.com' ESCAPE '\']")
+    checkAnswer(df10, Seq(Row("abc_'%d...@gmail.com")))
+
+    val df11 = spark.table("h2.test.address").filter($"email".contains("c_d"))
+    checkFiltersRemoved(df11)
+    checkPushedInfo(df11, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'%c\_d%' ESCAPE '\']")
+    checkAnswer(df11, Seq(Row("abc_...@gmail.com")))
+
+    val df12 = spark.table("h2.test.address").filter($"email".contains("c%d"))
+    checkFiltersRemoved(df12)
+    checkPushedInfo(df12, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'%c\%d%' ESCAPE '\']")
+    checkAnswer(df12, Seq(Row("abc%...@gmail.com")))
+
+    val df13 = spark.table("h2.test.address").filter($"email".contains("c%_d"))
+    checkFiltersRemoved(df13)
+    checkPushedInfo(df13,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\%\_d%' ESCAPE 
'\']")
+    checkAnswer(df13, Seq(Row("abc%_...@gmail.com")))
+
+    val df14 = spark.table("h2.test.address").filter($"email".contains("c_%d"))
+    checkFiltersRemoved(df14)
+    checkPushedInfo(df14,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\%d%' ESCAPE 
'\']")
+    checkAnswer(df14, Seq(Row("abc_%...@gmail.com")))
+
+    val df15 = 
spark.table("h2.test.address").filter($"email".contains("c_'%d"))
+    checkFiltersRemoved(df15)
+    checkPushedInfo(df15,
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\'\%d%' ESCAPE 
'\']")
+    checkAnswer(df15, Seq(Row("abc_'%d...@gmail.com")))
+  }
+
   test("scan with filter push-down with ansi mode") {
     Seq(false, true).foreach { ansiMode =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
@@ -1347,10 +1448,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkFiltersRemoved(df6, ansiMode)
         val expectedPlanFragment6 = if (ansiMode) {
           "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, " +
-            "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " +
+            raw"CAST(BONUS AS string) LIKE '%30%' ESCAPE '\', CAST(DEPT AS 
byte) > 1, " +
             "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]"
         } else {
-          "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS 
string) LIKE '%30%']"
+          "PushedFilters: [BONUS IS NOT NULL, " +
+            raw"DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%' ESCAPE 
'\']"
         }
         checkPushedInfo(df6, expectedPlanFragment6)
         checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true)))
@@ -1612,8 +1714,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("show tables") {
     checkAnswer(sql("SHOW TABLES IN h2.test"),
-      Seq(Row("test", "people", false), Row("test", "empty_table", false),
-        Row("test", "employee", false), Row("test", "item", false), 
Row("test", "dept", false),
+      Seq(Row("test", "address", false), Row("test", "people", false),
+        Row("test", "empty_table", false), Row("test", "employee", false),
+        Row("test", "item", false), Row("test", "dept", false),
         Row("test", "person", false), Row("test", "view1", false), Row("test", 
"view2", false),
         Row("test", "datetime", false), Row("test", "binary1", false)))
   }

