This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 35ecb32e479 [SPARK-46029][SQL] Escape the single quote, `_` and `%` for DS V2 pushdown 35ecb32e479 is described below commit 35ecb32e479a33a1454709d133c48295d6774f3b Author: Jiaan Geng <belie...@163.com> AuthorDate: Wed Nov 29 01:37:35 2023 +0100 [SPARK-46029][SQL] Escape the single quote, `_` and `%` for DS V2 pushdown ### What changes were proposed in this pull request? Spark supports pushing down `startsWith`, `endsWith` and `contains` to JDBC databases with DS V2 pushdown. However, `V2ExpressionSQLBuilder` didn't escape the single quote, `_` and `%`, which can cause unexpected results. ### Why are the changes needed? Escape the single quote, `_` and `%` for DS V2 pushdown ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? Existing test cases. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #43801 from beliefer/SPARK-38432_followup3. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit d2cd98bdd32446b4106e66eb099efd8fb47acf40) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/connector/util/V2ExpressionSQLBuilder.java | 35 ++++++- .../org/apache/spark/sql/jdbc/H2Dialect.scala | 8 ++ .../datasources/v2/V2PredicateSuite.scala | 6 +- .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 113 ++++++++++++++++++++- 4 files changed, 151 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 9ca0fe4787f..dcb3c706946 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -48,6 +48,35 @@ import org.apache.spark.sql.types.DataType; */ 
public class V2ExpressionSQLBuilder { + /** + * Escape the special chars for like pattern. + * + * Note: This method adopts the escape representation within Spark and is not bound to any JDBC + * dialect. JDBC dialect should overwrite this API if the underlying database have more special + * chars other than _ and %. + */ + protected String escapeSpecialCharsForLikePattern(String str) { + StringBuilder builder = new StringBuilder(); + + for (char c : str.toCharArray()) { + switch (c) { + case '_': + builder.append("\\_"); + break; + case '%': + builder.append("\\%"); + break; + case '\'': + builder.append("\\\'"); + break; + default: + builder.append(c); + } + } + + return builder.toString(); + } + public String build(Expression expr) { if (expr instanceof Literal) { return visitLiteral((Literal<?>) expr); @@ -247,21 +276,21 @@ public class V2ExpressionSQLBuilder { // Remove quotes at the beginning and end. // e.g. converts "'str'" to "str". String value = r.substring(1, r.length() - 1); - return l + " LIKE '" + value + "%'"; + return l + " LIKE '" + escapeSpecialCharsForLikePattern(value) + "%' ESCAPE '\\'"; } protected String visitEndsWith(String l, String r) { // Remove quotes at the beginning and end. // e.g. converts "'str'" to "str". String value = r.substring(1, r.length() - 1); - return l + " LIKE '%" + value + "'"; + return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "' ESCAPE '\\'"; } protected String visitContains(String l, String r) { // Remove quotes at the beginning and end. // e.g. converts "'str'" to "str". 
String value = r.substring(1, r.length() - 1); - return l + " LIKE '%" + value + "%'"; + return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "%' ESCAPE '\\'"; } private String inputToSQL(Expression input) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index c246b50f4e1..8471a49153f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -240,6 +240,14 @@ private[sql] object H2Dialect extends JdbcDialect { } class H2SQLBuilder extends JDBCSQLBuilder { + override def escapeSpecialCharsForLikePattern(str: String): String = { + str.map { + case '_' => "\\_" + case '%' => "\\%" + case c => c.toString + }.mkString + } + override def visitAggregateFunction( funcName: String, isDistinct: Boolean, inputs: Array[String]): String = if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index a5fee51dc91..4a8a231cc54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -315,7 +315,7 @@ class V2PredicateSuite extends SparkFunSuite { Array[Expression](ref("a"), literal)) assert(predicate1.equals(predicate2)) assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) - assert(predicate1.describe.equals("a LIKE 'str%'")) + assert(predicate1.describe.equals(raw"a LIKE 'str%' ESCAPE '\'")) val v1Filter = StringStartsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) @@ -332,7 +332,7 @@ class V2PredicateSuite extends SparkFunSuite { Array[Expression](ref("a"), literal)) 
assert(predicate1.equals(predicate2)) assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) - assert(predicate1.describe.equals("a LIKE '%str'")) + assert(predicate1.describe.equals(raw"a LIKE '%str' ESCAPE '\'")) val v1Filter = StringEndsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) @@ -349,7 +349,7 @@ class V2PredicateSuite extends SparkFunSuite { Array[Expression](ref("a"), literal)) assert(predicate1.equals(predicate2)) assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) - assert(predicate1.describe.equals("a LIKE '%str%'")) + assert(predicate1.describe.equals(raw"a LIKE '%str%' ESCAPE '\'")) val v1Filter = StringContains("a", "str") assert(v1Filter.toV2.equals(predicate1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index ae0cfe17b11..51a15881088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -185,6 +185,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"address\" (email TEXT(32) NOT NULL)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc_...@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc%...@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc%_...@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc_%...@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc_''%d...@gmail.com')").executeUpdate() + 
conn.prepareStatement("CREATE TABLE \"test\".\"binary1\" (name TEXT(32),b BINARY(20))") .executeUpdate() val stmt = conn.prepareStatement("INSERT INTO \"test\".\"binary1\" VALUES (?, ?)") @@ -1096,7 +1109,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a")) checkFiltersRemoved(df3) - checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']") + checkPushedInfo(df3, raw"PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%' ESCAPE '\']") checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false))) val df4 = spark.table("h2.test.employee").filter($"is_manager") @@ -1240,6 +1253,94 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df17, Seq(Row(6, "jen", 12000, 1200, true))) } + test("SPARK-38432: escape the single quote, _ and % for DS V2 pushdown") { + val df1 = spark.table("h2.test.address").filter($"email".startsWith("abc_")) + checkFiltersRemoved(df1) + checkPushedInfo(df1, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_%' ESCAPE '\']") + checkAnswer(df1, + Seq(Row("abc_%...@gmail.com"), Row("abc_'%d...@gmail.com"), Row("abc_...@gmail.com"))) + + val df2 = spark.table("h2.test.address").filter($"email".startsWith("abc%")) + checkFiltersRemoved(df2) + checkPushedInfo(df2, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\%%' ESCAPE '\']") + checkAnswer(df2, Seq(Row("abc%_...@gmail.com"), Row("abc%...@gmail.com"))) + + val df3 = spark.table("h2.test.address").filter($"email".startsWith("abc%_")) + checkFiltersRemoved(df3) + checkPushedInfo(df3, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\%\_%' ESCAPE '\']") + checkAnswer(df3, Seq(Row("abc%_...@gmail.com"))) + + val df4 = spark.table("h2.test.address").filter($"email".startsWith("abc_%")) + checkFiltersRemoved(df4) + checkPushedInfo(df4, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\%%' 
ESCAPE '\']") + checkAnswer(df4, Seq(Row("abc_%...@gmail.com"))) + + val df5 = spark.table("h2.test.address").filter($"email".startsWith("abc_'%")) + checkFiltersRemoved(df5) + checkPushedInfo(df5, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\'\%%' ESCAPE '\']") + checkAnswer(df5, Seq(Row("abc_'%d...@gmail.com"))) + + val df6 = spark.table("h2.test.address").filter($"email".endsWith("_...@gmail.com")) + checkFiltersRemoved(df6) + checkPushedInfo(df6, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_...@gmail.com' ESCAPE '\']") + checkAnswer(df6, Seq(Row("abc%_...@gmail.com"), Row("abc_...@gmail.com"))) + + val df7 = spark.table("h2.test.address").filter($"email".endsWith("%d...@gmail.com")) + checkFiltersRemoved(df7) + checkPushedInfo(df7, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\%d...@gmail.com' ESCAPE '\']") + checkAnswer(df7, + Seq(Row("abc%...@gmail.com"), Row("abc_%...@gmail.com"), Row("abc_'%d...@gmail.com"))) + + val df8 = spark.table("h2.test.address").filter($"email".endsWith("%_...@gmail.com")) + checkFiltersRemoved(df8) + checkPushedInfo(df8, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\%\_...@gmail.com' ESCAPE '\']") + checkAnswer(df8, Seq(Row("abc%_...@gmail.com"))) + + val df9 = spark.table("h2.test.address").filter($"email".endsWith("_%...@gmail.com")) + checkFiltersRemoved(df9) + checkPushedInfo(df9, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_\%d...@gmail.com' ESCAPE '\']") + checkAnswer(df9, Seq(Row("abc_%...@gmail.com"))) + + val df10 = spark.table("h2.test.address").filter($"email".endsWith("_'%d...@gmail.com")) + checkFiltersRemoved(df10) + checkPushedInfo(df10, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_\'\%d...@gmail.com' ESCAPE '\']") + checkAnswer(df10, Seq(Row("abc_'%d...@gmail.com"))) + + val df11 = spark.table("h2.test.address").filter($"email".contains("c_d")) + checkFiltersRemoved(df11) + checkPushedInfo(df11, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'%c\_d%' ESCAPE '\']") + checkAnswer(df11, Seq(Row("abc_...@gmail.com"))) + + val df12 = spark.table("h2.test.address").filter($"email".contains("c%d")) + checkFiltersRemoved(df12) + checkPushedInfo(df12, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\%d%' ESCAPE '\']") + checkAnswer(df12, Seq(Row("abc%...@gmail.com"))) + + val df13 = spark.table("h2.test.address").filter($"email".contains("c%_d")) + checkFiltersRemoved(df13) + checkPushedInfo(df13, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\%\_d%' ESCAPE '\']") + checkAnswer(df13, Seq(Row("abc%_...@gmail.com"))) + + val df14 = spark.table("h2.test.address").filter($"email".contains("c_%d")) + checkFiltersRemoved(df14) + checkPushedInfo(df14, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\%d%' ESCAPE '\']") + checkAnswer(df14, Seq(Row("abc_%...@gmail.com"))) + + val df15 = spark.table("h2.test.address").filter($"email".contains("c_'%d")) + checkFiltersRemoved(df15) + checkPushedInfo(df15, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\'\%d%' ESCAPE '\']") + checkAnswer(df15, Seq(Row("abc_'%d...@gmail.com"))) + } + test("scan with filter push-down with ansi mode") { Seq(false, true).foreach { ansiMode => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { @@ -1325,10 +1426,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df6, ansiMode) val expectedPlanFragment6 = if (ansiMode) { "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, " + - "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " + + raw"CAST(BONUS AS string) LIKE '%30%' ESCAPE '\', CAST(DEPT AS byte) > 1, " + "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]" } else { - "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']" + "PushedFilters: [BONUS IS NOT NULL, " + + raw"DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%' ESCAPE '\']" } checkPushedInfo(df6, expectedPlanFragment6) 
checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true))) @@ -1538,8 +1640,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), - Seq(Row("test", "people", false), Row("test", "empty_table", false), - Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false), + Seq(Row("test", "address", false), Row("test", "people", false), + Row("test", "empty_table", false), Row("test", "employee", false), + Row("test", "item", false), Row("test", "dept", false), Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false), Row("test", "datetime", false), Row("test", "binary1", false))) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org