Repository: spark
Updated Branches:
  refs/heads/master 37e52f879 -> 70c5549ee


[SPARK-18141][SQL] Fix to quote column names in the predicate clause of the
JDBC RDD generated SQL statement

## What changes were proposed in this pull request?

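The SQL query generated for the JDBC data source does not quote column names in
the predicate (WHERE) clause. When the source table has quoted, case-sensitive
column names, a Spark JDBC read incorrectly fails with a "column not found"
error. The database case-folds the unquoted name: in the example below, `Id` is
looked up as `ID`.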

Error:
org.h2.jdbc.JdbcSQLException: Column "ID" not found;
Source SQL statement:
SELECT "Name","Id" FROM TEST."mixedCaseCols" WHERE (Id < 1)

This PR fixes the problem by quoting column names with the JDBC dialect's
identifier quoting. The quoting is applied in the generated SQL predicate clause
whenever filters are pushed down to the data source.

Source SQL statement after the fix:
SELECT "Name","Id" FROM TEST."mixedCaseCols" WHERE ("Id" < 1)

## How was this patch tested?

Added a new test case to JDBCSuite and updated the existing "compile filters"
test to expect quoted column names.

Author: sureshthalamati <suresh.thalam...@gmail.com>

Closes #15662 from sureshthalamati/filter_quoted_cols-SPARK-18141.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70c5549e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70c5549e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70c5549e

Branch: refs/heads/master
Commit: 70c5549ee9588228d18a7b405c977cf591e2efd4
Parents: 37e52f8
Author: sureshthalamati <suresh.thalam...@gmail.com>
Authored: Thu Dec 1 19:13:38 2016 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Dec 1 19:13:38 2016 -0800

----------------------------------------------------------------------
 .../execution/datasources/jdbc/JDBCRDD.scala    | 45 ++++++------
 .../datasources/jdbc/JDBCRelation.scala         |  3 +-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 73 +++++++++++++++-----
 3 files changed, 82 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70c5549e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index a1e5dfd..37df283 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.CompletionIterator
@@ -105,37 +105,40 @@ object JDBCRDD extends Logging {
    * Turns a single Filter into a String representing a SQL expression.
    * Returns None for an unhandled filter.
    */
-  def compileFilter(f: Filter): Option[String] = {
+  def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = {
+    def quote(colName: String): String = dialect.quoteIdentifier(colName)
+
     Option(f match {
-      case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}"
       case EqualNullSafe(attr, value) =>
-        s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
-          s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND 
${compileValue(value)} IS NULL))"
-      case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"$attr >= 
${compileValue(value)}"
-      case IsNull(attr) => s"$attr IS NULL"
-      case IsNotNull(attr) => s"$attr IS NOT NULL"
-      case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'"
-      case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'"
-      case StringContains(attr, value) => s"${attr} LIKE '%${value}%'"
+        val col = quote(attr)
+        s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " +
+          s"${compileValue(value)} IS NULL) OR ($col IS NULL AND 
${compileValue(value)} IS NULL))"
+      case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}"
+      case GreaterThan(attr, value) => s"${quote(attr)} > 
${compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= 
${compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= 
${compileValue(value)}"
+      case IsNull(attr) => s"${quote(attr)} IS NULL"
+      case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
+      case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
+      case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
+      case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
       case In(attr, value) if value.isEmpty =>
-        s"CASE WHEN ${attr} IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"$attr IN (${compileValue(value)})"
-      case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null)
+        s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
+      case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})"
+      case Not(f) => compileFilter(f, dialect).map(p => s"(NOT 
($p))").getOrElse(null)
       case Or(f1, f2) =>
         // We can't compile Or filter unless both sub-filters are compiled 
successfully.
         // It applies too for the following And filter.
         // If we can make sure compileFilter supports all filters, we can 
remove this check.
-        val or = Seq(f1, f2).flatMap(compileFilter(_))
+        val or = Seq(f1, f2).flatMap(compileFilter(_, dialect))
         if (or.size == 2) {
           or.map(p => s"($p)").mkString(" OR ")
         } else {
           null
         }
       case And(f1, f2) =>
-        val and = Seq(f1, f2).flatMap(compileFilter(_))
+        val and = Seq(f1, f2).flatMap(compileFilter(_, dialect))
         if (and.size == 2) {
           and.map(p => s"($p)").mkString(" AND ")
         } else {
@@ -214,7 +217,9 @@ private[jdbc] class JDBCRDD(
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
    */
   private val filterWhereClause: String =
-    filters.flatMap(JDBCRDD.compileFilter).map(p => s"($p)").mkString(" AND ")
+    filters
+      .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
+      .map(p => s"($p)").mkString(" AND ")
 
   /**
    * A WHERE clause representing both `filters`, if any, and the current 
partition.

http://git-wip-us.apache.org/repos/asf/spark/blob/70c5549e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 30caa73a..5ca1c75 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, 
SQLContext}
+import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 
@@ -113,7 +114,7 @@ private[sql] case class JDBCRelation(
 
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
-    filters.filter(JDBCRDD.compileFilter(_).isEmpty)
+    filters.filter(JDBCRDD.compileFilter(_, 
JdbcDialects.get(jdbcOptions.url)).isEmpty)
   }
 
   override def buildScan(requiredColumns: Array[String], filters: 
Array[Filter]): RDD[Row] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/70c5549e/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 218ccf9..aa1ab14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -202,6 +202,21 @@ class JDBCSuite extends SparkFunSuite
          |partitionColumn '"Dept"', lowerBound '1', upperBound '4', 
numPartitions '4')
       """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement(
+      """create table test."mixedCaseCols" ("Name" TEXT(32), "Id" INTEGER NOT 
NULL)""")
+      .executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values ('fred', 
1)""").executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values ('mary', 
2)""").executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values (null, 
3)""").executeUpdate()
+    conn.commit()
+
+    sql(
+      s"""
+         |CREATE TEMPORARY TABLE mixedCaseCols
+         |USING org.apache.spark.sql.jdbc
+         |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 
'testUser', password 'testPass')
+      """.stripMargin.replaceAll("\n", " "))
+
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
   }
 
@@ -632,30 +647,32 @@ class JDBCSuite extends SparkFunSuite
 
   test("compile filters") {
     val compileFilter = PrivateMethod[Option[String]]('compileFilter)
-    def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate 
compileFilter(f) getOrElse("")
-    assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3")
-    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 
'abc'))")
+    def doCompileFilter(f: Filter): String =
+      JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) 
getOrElse("")
+    assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""")
+    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 
'abc'))""")
     assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def")))
-      === "(col0 = 0) AND (col1 = 'def')")
+      === """("col0" = 0) AND ("col1" = 'def')""")
     assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi")))
-      === "(col0 = 2) OR (col1 = 'ghi')")
-    assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5")
+      === """("col0" = 2) OR ("col1" = 'ghi')""")
+    assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""")
     assert(doCompileFilter(LessThan("col3",
-      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 
00:00:00.0'")
-    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === 
"col4 < '1983-08-04'")
-    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5")
-    assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3")
-    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3")
-    assert(doCompileFilter(In("col1", Array("jkl"))) === "col1 IN ('jkl')")
+      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 
00:00:00.0'""")
+    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04")))
+      === """"col4" < '1983-08-04'""")
+    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""")
+    assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""")
+    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 
3""")
+    assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN 
('jkl')""")
     assert(doCompileFilter(In("col1", Array.empty)) ===
-      "CASE WHEN col1 IS NULL THEN NULL ELSE FALSE END")
+      """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""")
     assert(doCompileFilter(Not(In("col1", Array("mno", "pqr"))))
-      === "(NOT (col1 IN ('mno', 'pqr')))")
-    assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL")
-    assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL")
+      === """(NOT ("col1" IN ('mno', 'pqr')))""")
+    assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""")
+    assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""")
     assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", 
"def")))
-      === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) "
-        + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 = 'def')")
+      === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """
+        + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""")
   }
 
   test("Dialect unregister") {
@@ -853,4 +870,24 @@ class JDBCSuite extends SparkFunSuite
     val schema = JdbcUtils.schemaString(df.schema, 
"jdbc:mysql://localhost:3306/temp")
     assert(schema.contains("`order` TEXT"))
   }
+
+  test("SPARK-18141: Predicates on quoted column names in the jdbc data 
source") {
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id < 1").collect().size == 0)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id <= 1").collect().size == 
1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id > 1").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id >= 1").collect().size == 
3)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id = 1").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id != 2").collect().size == 
2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id <=> 2").collect().size == 
1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE 
'fr%'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE 
'%ed'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE 
'%re%'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IS 
NULL").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IS NOT 
NULL").collect().size == 2)
+    assert(sql("SELECT * FROM 
mixedCaseCols").filter($"Name".isin()).collect().size == 0)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IN ('mary', 
'fred')").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name NOT IN 
('fred')").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id = 1 OR Name = 
'mary'").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name = 'mary' AND Id = 
2").collect().size == 1)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to