Repository: spark
Updated Branches:
  refs/heads/master 515708d5f -> 8d9495a8f


[SPARK-25207][SQL] Case-insensitive field resolution for filter pushdown when 
reading Parquet

## What changes were proposed in this pull request?

Currently, filter pushdown does not work if the Parquet schema and the Hive 
metastore schema differ in letter case, even when spark.sql.caseSensitive is false.

For example:
```scala
spark.sparkContext.hadoopConfiguration.setInt("parquet.block.size", 8 * 1024 * 
1024)
spark.range(1, 40 * 1024 * 1024, 1, 
1).sortWithinPartitions("id").write.parquet("/tmp/t")
sql("CREATE TABLE t (ID LONG) USING parquet LOCATION '/tmp/t'")
sql("select * from t where id < 100L").write.csv("/tmp/id")
```

Although the filter "ID < 100L" is generated by Spark, it is not actually pushed 
down to Parquet, so Spark still performs a full table scan when reading.
This PR adds case-insensitive field resolution so that the pushdown works.

Before - "ID < 100L" fails to be pushed down:
<img width="273" alt="screen shot 2018-08-23 at 10 08 26 pm" 
src="https://user-images.githubusercontent.com/2989575/44530558-40ef8b00-a721-11e8-8abc-7f97671590d3.png">
After - "ID < 100L" is pushed down successfully:
<img width="267" alt="screen shot 2018-08-23 at 10 08 40 pm" 
src="https://user-images.githubusercontent.com/2989575/44530567-44831200-a721-11e8-8634-e9f664b33d39.png">

## How was this patch tested?

Added unit tests.

Closes #22197 from yucai/SPARK-25207.

Authored-by: yucai <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d9495a8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d9495a8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d9495a8

Branch: refs/heads/master
Commit: 8d9495a8f1e64dbc42c3741f9bcbd4893ce3f0e9
Parents: 515708d
Author: yucai <[email protected]>
Authored: Fri Aug 31 19:24:09 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Aug 31 19:24:09 2018 +0800

----------------------------------------------------------------------
 .../datasources/parquet/ParquetFileFormat.scala |   3 +-
 .../datasources/parquet/ParquetFilters.scala    |  90 ++++++++++-----
 .../parquet/ParquetFilterSuite.scala            | 115 ++++++++++++++++++-
 3 files changed, 179 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8d9495a8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index d7eb143..ea4f159 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -347,6 +347,7 @@ class ParquetFileFormat
     val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
     val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
     val pushDownInFilterThreshold = 
sqlConf.parquetFilterPushDownInFilterThreshold
+    val isCaseSensitive = sqlConf.caseSensitiveAnalysis
 
     (file: PartitionedFile) => {
       assert(file.partitionValues.numFields == partitionSchema.size)
@@ -372,7 +373,7 @@ class ParquetFileFormat
       val pushed = if (enableParquetFilterPushDown) {
         val parquetSchema = footerFileMetaData.getSchema
         val parquetFilters = new ParquetFilters(pushDownDate, 
pushDownTimestamp, pushDownDecimal,
-          pushDownStringStartWith, pushDownInFilterThreshold)
+          pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
         filters
           // Collects all converted Parquet filter predicates. Notice that not 
all predicates can be
           // converted (`ParquetFilters.createFilter` returns an `Option`). 
That's why a `flatMap`

http://git-wip-us.apache.org/repos/asf/spark/blob/8d9495a8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 58b4a76..0c286de 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, 
Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
+import java.util.Locale
 
 import scala.collection.JavaConverters.asScalaBufferConverter
 
@@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
 import org.apache.spark.sql.sources
 import org.apache.spark.unsafe.types.UTF8String
@@ -44,7 +45,18 @@ private[parquet] class ParquetFilters(
     pushDownTimestamp: Boolean,
     pushDownDecimal: Boolean,
     pushDownStartWith: Boolean,
-    pushDownInFilterThreshold: Int) {
+    pushDownInFilterThreshold: Int,
+    caseSensitive: Boolean) {
+
+  /**
+   * Holds a single field information stored in the underlying parquet file.
+   *
+   * @param fieldName field name in parquet file
+   * @param fieldType field type related info in parquet file
+   */
+  private case class ParquetField(
+      fieldName: String,
+      fieldType: ParquetSchemaType)
 
   private case class ParquetSchemaType(
       originalType: OriginalType,
@@ -350,25 +362,38 @@ private[parquet] class ParquetFilters(
   }
 
   /**
-   * Returns a map from name of the column to the data type, if predicate push 
down applies.
+   * Returns a map, which contains parquet field name and data type, if 
predicate push down applies.
    */
-  private def getFieldMap(dataType: MessageType): Map[String, 
ParquetSchemaType] = dataType match {
-    case m: MessageType =>
-      // Here we don't flatten the fields in the nested schema but just look 
up through
-      // root fields. Currently, accessing to nested fields does not push down 
filters
-      // and it does not support to create filters for them.
-      m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { 
f =>
-        f.getName -> ParquetSchemaType(
-          f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, 
f.getDecimalMetadata)
-      }.toMap
-    case _ => Map.empty[String, ParquetSchemaType]
+  private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
+    // Here we don't flatten the fields in the nested schema but just look up 
through
+    // root fields. Currently, accessing to nested fields does not push down 
filters
+    // and it does not support to create filters for them.
+    val primitiveFields =
+      
dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { 
f =>
+        f.getName -> ParquetField(f.getName,
+          ParquetSchemaType(f.getOriginalType,
+            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+      }
+    if (caseSensitive) {
+      primitiveFields.toMap
+    } else {
+      // Don't consider ambiguity here, i.e. more than one field is matched in 
case insensitive
+      // mode, just skip pushdown for these fields, they will trigger 
Exception when reading,
+      // See: SPARK-25132.
+      val dedupPrimitiveFields =
+        primitiveFields
+          .groupBy(_._1.toLowerCase(Locale.ROOT))
+          .filter(_._2.size == 1)
+          .mapValues(_.head._2)
+      CaseInsensitiveMap(dedupPrimitiveFields)
+    }
   }
 
   /**
    * Converts data sources filters to Parquet filter predicates.
    */
   def createFilter(schema: MessageType, predicate: sources.Filter): 
Option[FilterPredicate] = {
-    val nameToType = getFieldMap(schema)
+    val nameToParquetField = getFieldMap(schema)
 
     // Decimal type must make sure that filter value's scale matched the file.
     // If doesn't matched, which would cause data corruption.
@@ -381,7 +406,7 @@ private[parquet] class ParquetFilters(
     // Parquet's type in the given file should be matched to the value's type
     // in the pushed filter in order to push down the filter to Parquet.
     def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
-      value == null || (nameToType(name) match {
+      value == null || (nameToParquetField(name).fieldType match {
         case ParquetBooleanType => value.isInstanceOf[JBoolean]
         case ParquetByteType | ParquetShortType | ParquetIntegerType => 
value.isInstanceOf[Number]
         case ParquetLongType => value.isInstanceOf[JLong]
@@ -408,7 +433,7 @@ private[parquet] class ParquetFilters(
     // filters for the column having dots in the names. Thus, we do not push 
down such filters.
     // See SPARK-20364.
     def canMakeFilterOn(name: String, value: Any): Boolean = {
-      nameToType.contains(name) && !name.contains(".") && 
valueCanMakeFilterOn(name, value)
+      nameToParquetField.contains(name) && !name.contains(".") && 
valueCanMakeFilterOn(name, value)
     }
 
     // NOTE:
@@ -428,29 +453,39 @@ private[parquet] class ParquetFilters(
 
     predicate match {
       case sources.IsNull(name) if canMakeFilterOn(name, null) =>
-        makeEq.lift(nameToType(name)).map(_(name, null))
+        makeEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, null))
       case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
-        makeNotEq.lift(nameToType(name)).map(_(name, null))
+        makeNotEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, null))
 
       case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
-        makeEq.lift(nameToType(name)).map(_(name, value))
+        makeEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, 
value) =>
-        makeNotEq.lift(nameToType(name)).map(_(name, value))
+        makeNotEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) 
=>
-        makeEq.lift(nameToType(name)).map(_(name, value))
+        makeEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.Not(sources.EqualNullSafe(name, value)) if 
canMakeFilterOn(name, value) =>
-        makeNotEq.lift(nameToType(name)).map(_(name, value))
+        makeNotEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
-        makeLt.lift(nameToType(name)).map(_(name, value))
+        makeLt.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, 
value) =>
-        makeLtEq.lift(nameToType(name)).map(_(name, value))
+        makeLtEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
-        makeGt.lift(nameToType(name)).map(_(name, value))
+        makeGt.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, 
value) =>
-        makeGtEq.lift(nameToType(name)).map(_(name, value))
+        makeGtEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.And(lhs, rhs) =>
         // At here, it is not safe to just convert one side if we do not 
understand the
@@ -477,7 +512,8 @@ private[parquet] class ParquetFilters(
       case sources.In(name, values) if canMakeFilterOn(name, values.head)
         && values.distinct.length <= pushDownInFilterThreshold =>
         values.distinct.flatMap { v =>
-          makeEq.lift(nameToType(name)).map(_(name, v))
+          makeEq.lift(nameToParquetField(name).fieldType)
+            .map(_(nameToParquetField(name).fieldName, v))
         }.reduceLeftOption(FilterApi.or)
 
       case sources.StringStartsWith(name, prefix)

http://git-wip-us.apache.org/repos/asf/spark/blob/8d9495a8/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index be4f498..7ebb750 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -25,6 +25,7 @@ import org.apache.parquet.filter2.predicate.{FilterApi, 
FilterPredicate, Operato
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -60,7 +61,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest 
with SharedSQLContex
   private lazy val parquetFilters =
     new ParquetFilters(conf.parquetFilterPushDownDate, 
conf.parquetFilterPushDownTimestamp,
       conf.parquetFilterPushDownDecimal, 
conf.parquetFilterPushDownStringStartWith,
-      conf.parquetFilterPushDownInFilterThreshold)
+      conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis)
 
   override def beforeEach(): Unit = {
     super.beforeEach()
@@ -1021,6 +1022,118 @@ class ParquetFilterSuite extends QueryTest with 
ParquetTest with SharedSQLContex
       }
     }
   }
+
+  test("SPARK-25207: Case-insensitive field resolution for pushdown when 
reading parquet") {
+    def createParquetFilter(caseSensitive: Boolean): ParquetFilters = {
+      new ParquetFilters(conf.parquetFilterPushDownDate, 
conf.parquetFilterPushDownTimestamp,
+        conf.parquetFilterPushDownDecimal, 
conf.parquetFilterPushDownStringStartWith,
+        conf.parquetFilterPushDownInFilterThreshold, caseSensitive)
+    }
+    val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true)
+    val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = 
false)
+
+    def testCaseInsensitiveResolution(
+        schema: StructType,
+        expected: FilterPredicate,
+        filter: sources.Filter): Unit = {
+      val parquetSchema = new 
SparkToParquetSchemaConverter(conf).convert(schema)
+
+      assertResult(Some(expected)) {
+        caseInsensitiveParquetFilters.createFilter(parquetSchema, filter)
+      }
+      assertResult(None) {
+        caseSensitiveParquetFilters.createFilter(parquetSchema, filter)
+      }
+    }
+
+    val schema = StructType(Seq(StructField("cint", IntegerType)))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), 
sources.IsNull("CINT"))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]),
+      sources.IsNotNull("CINT"))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.eq(intColumn("cint"), 1000: Integer), 
sources.EqualTo("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.notEq(intColumn("cint"), 1000: Integer),
+      sources.Not(sources.EqualTo("CINT", 1000)))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.eq(intColumn("cint"), 1000: Integer), 
sources.EqualNullSafe("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.notEq(intColumn("cint"), 1000: Integer),
+      sources.Not(sources.EqualNullSafe("CINT", 1000)))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 
1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.ltEq(intColumn("cint"), 1000: Integer),
+      sources.LessThanOrEqual("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.gt(intColumn("cint"), 1000: Integer), 
sources.GreaterThan("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.gtEq(intColumn("cint"), 1000: Integer),
+      sources.GreaterThanOrEqual("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.or(
+        FilterApi.eq(intColumn("cint"), 10: Integer),
+        FilterApi.eq(intColumn("cint"), 20: Integer)),
+      sources.In("CINT", Array(10, 20)))
+
+    val dupFieldSchema = StructType(
+      Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType)))
+    val dupParquetSchema = new 
SparkToParquetSchemaConverter(conf).convert(dupFieldSchema)
+    assertResult(None) {
+      caseInsensitiveParquetFilters.createFilter(
+        dupParquetSchema, sources.EqualTo("CINT", 1000))
+    }
+  }
+
+  test("SPARK-25207: exception when duplicate fields in case-insensitive 
mode") {
+    withTempPath { dir =>
+      val count = 10
+      val tableName = "spark_25207"
+      val tableDir = dir.getAbsoluteFile + "/table"
+      withTable(tableName) {
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+          spark.range(count).selectExpr("id as A", "id as B", "id as b")
+            .write.mode("overwrite").parquet(tableDir)
+        }
+        sql(
+          s"""
+             |CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION 
'$tableDir'
+           """.stripMargin)
+
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+          val e = intercept[SparkException] {
+            sql(s"select a from $tableName where b > 0").collect()
+          }
+          assert(e.getCause.isInstanceOf[RuntimeException] && 
e.getCause.getMessage.contains(
+            """Found duplicate field(s) "B": [B, b] in case-insensitive 
mode"""))
+        }
+
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+          checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until 
count).map(Row(_)))
+        }
+      }
+    }
+  }
 }
 
 class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to