Repository: spark
Updated Branches:
  refs/heads/master 2d96d35dc -> 6ce1b675e


http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index 9ee3d62..569a9c1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -172,15 +172,24 @@ private[hive] trait HiveClient {
    * Returns the partitions for the given table that match the supplied partition spec.
    * If no partition spec is specified, all partitions are returned.
    */
-  def getPartitions(
+  final def getPartitions(
       db: String,
       table: String,
+      partialSpec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = {
+    getPartitions(getTable(db, table), partialSpec)
+  }
+
+  /**
+   * Returns the partitions for the given table that match the supplied partition spec.
+   * If no partition spec is specified, all partitions are returned.
+   */
+  def getPartitions(
+      catalogTable: CatalogTable,
+      partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition]
 
   /** Returns partitions filtered by predicates for the given table. */
   def getPartitionsByFilter(
-      db: String,
-      table: String,
+      catalogTable: CatalogTable,
       predicates: Seq[Expression]): Seq[CatalogTablePartition]
 
   /** Loads a static partition into an existing table. */
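
For reference, a minimal sketch of how a caller might use the new CatalogTable-based
overloads; the wrapper method and the database/table names below are illustrative
assumptions, not part of this patch:

    import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition}
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Sketch: resolve the table once, then reuse the CatalogTable for both
    // partition lookups so the client never has to call getTable() again.
    def listPartitions(
        client: HiveClient,
        predicates: Seq[Expression]): Seq[CatalogTablePartition] = {
      val table: CatalogTable = client.getTable("default", "src_part")
      if (predicates.isEmpty) {
        client.getPartitions(table)                      // all partitions
      } else {
        client.getPartitionsByFilter(table, predicates)  // filtered fetch
      }
    }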

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 5c8f7ff..e745a8c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -37,6 +37,7 @@ import org.apache.hadoop.security.UserGroupInformation
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
+import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException}
@@ -525,22 +526,24 @@ private[hive] class HiveClientImpl(
    * If no partition spec is specified, all partitions are returned.
    */
   override def getPartitions(
-      db: String,
-      table: String,
+      table: CatalogTable,
       spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState {
-    val hiveTable = toHiveTable(getTable(db, table))
-    spec match {
+    val hiveTable = toHiveTable(table)
+    val parts = spec match {
       case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition)
       case Some(s) => client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition)
     }
+    HiveCatalogMetrics.incrementFetchedPartitions(parts.length)
+    parts
   }
 
   override def getPartitionsByFilter(
-      db: String,
-      table: String,
+      table: CatalogTable,
       predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState {
-    val hiveTable = toHiveTable(getTable(db, table))
-    shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition)
+    val hiveTable = toHiveTable(table)
+    val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition)
+    HiveCatalogMetrics.incrementFetchedPartitions(parts.length)
+    parts
   }
 
   override def listTables(dbName: String): Seq[String] = withHiveState {
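
The two HiveCatalogMetrics.incrementFetchedPartitions calls added above feed the
counters asserted on by the test suites further down. A minimal sketch of observing
that counter around a partition fetch; the `client` value (a HiveClient) and the
table name are assumptions:

    import org.apache.spark.metrics.source.HiveCatalogMetrics

    // Zero the counters, run a metastore-backed partition lookup, then read how
    // many partitions the client reported fetching.
    HiveCatalogMetrics.reset()
    val parts = client.getPartitions(client.getTable("default", "src_part"))
    assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == parts.length)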

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index e94f49e..1af3280 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -313,7 +313,17 @@ private[orc] object OrcRelation extends HiveInspectors {
 
   def setRequiredColumns(
       conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = {
-    val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer)
+    val caseInsensitiveFieldMap: Map[String, Int] = physicalSchema.fieldNames
+      .zipWithIndex
+      .map(f => (f._1.toLowerCase, f._2))
+      .toMap
+    val ids = requestedSchema.map { a =>
+      val exactMatch: Option[Int] = physicalSchema.getFieldIndex(a.name)
+      val res = exactMatch.getOrElse(
+        caseInsensitiveFieldMap.getOrElse(a.name,
+          throw new IllegalArgumentException(s"""Field "${a.name}" does not exist.""")))
+      res: Integer
+    }
     val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
     HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
   }
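
The change above lets ORC column resolution tolerate the lower-cased field names
that come back from the Hive metastore. A standalone sketch of the same lookup
order; the helper name is illustrative:

    import org.apache.spark.sql.types.StructType

    // Prefer an exact-case match against the physical (file) schema, then fall
    // back to a map keyed by lower-cased physical field names (the requested,
    // metastore-supplied name is already lower case), and fail otherwise.
    def resolveFieldIndex(physicalSchema: StructType, requestedName: String): Int = {
      val caseInsensitiveFieldMap: Map[String, Int] =
        physicalSchema.fieldNames.zipWithIndex.map { case (n, i) => (n.toLowerCase, i) }.toMap
      physicalSchema.getFieldIndex(requestedName).getOrElse(
        caseInsensitiveFieldMap.getOrElse(requestedName,
          throw new IllegalArgumentException(s"""Field "$requestedName" does not exist.""")))
    }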

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala
index 96e9054..f65e74d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql.hive
 
+import java.io.File
+
+import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.QueryTest
 
-class HiveDataFrameSuite extends QueryTest with TestHiveSingleton {
+class HiveDataFrameSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
   test("table name with schema") {
     // regression test for SPARK-11778
     spark.sql("create schema usrdb")
@@ -34,4 +38,107 @@ class HiveDataFrameSuite extends QueryTest with TestHiveSingleton {
     val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
     assert(hiveClient.getConf("hive.in.test", "") == "true")
   }
+
+  private def setupPartitionedTable(tableName: String, dir: File): Unit = {
+    spark.range(5).selectExpr("id", "id as partCol1", "id as partCol2").write
+      .partitionBy("partCol1", "partCol2")
+      .mode("overwrite")
+      .parquet(dir.getAbsolutePath)
+
+    spark.sql(s"""
+      |create external table $tableName (id long)
+      |partitioned by (partCol1 int, partCol2 int)
+      |stored as parquet
+      |location "${dir.getAbsolutePath}"""".stripMargin)
+    spark.sql(s"msck repair table $tableName")
+  }
+
+  test("partitioned pruned table reports only selected files") {
+    assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
+    withTable("test") {
+      withTempDir { dir =>
+        setupPartitionedTable("test", dir)
+        val df = spark.sql("select * from test")
+        assert(df.count() == 5)
+        assert(df.inputFiles.length == 5)  // unpruned
+
+        val df2 = spark.sql("select * from test where partCol1 = 3 or partCol2 = 4")
+        assert(df2.count() == 2)
+        assert(df2.inputFiles.length == 2)  // pruned, so we have less files
+
+        val df3 = spark.sql("select * from test where PARTCOL1 = 3 or partcol2 = 4")
+        assert(df3.count() == 2)
+        assert(df3.inputFiles.length == 2)
+
+        val df4 = spark.sql("select * from test where partCol1 = 999")
+        assert(df4.count() == 0)
+        assert(df4.inputFiles.length == 0)
+      }
+    }
+  }
+
+  test("lazy partition pruning reads only necessary partition data") {
+    withSQLConf("spark.sql.hive.filesourcePartitionPruning" -> "true") {
+      withTable("test") {
+        withTempDir { dir =>
+          setupPartitionedTable("test", dir)
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test where partCol1 = 999").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0)
+
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test where partCol1 < 2").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2)
+
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test where partCol1 < 3").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 3)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 3)
+
+          // should read all
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5)
+
+          // read all should be cached
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0)
+        }
+      }
+    }
+  }
+
+  test("all partitions read and cached when filesource partition pruning is 
off") {
+    withSQLConf("spark.sql.hive.filesourcePartitionPruning" -> "false") {
+      withTable("test") {
+        withTempDir { dir =>
+          setupPartitionedTable("test", dir)
+
+          // We actually query the partitions from hive each time the table is resolved in this
+          // mode. This is kind of terrible, but is needed to preserve the legacy behavior
+          // of doing plan cache validation based on the entire partition set.
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test where partCol1 = 999").count()
+          // 5 from table resolution, another 5 from ListingFileCatalog
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 10)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5)
+
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test where partCol1 < 2").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0)
+
+          HiveCatalogMetrics.reset()
+          spark.sql("select * from test").count()
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5)
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0)
+        }
+      }
+    }
+  }
 }
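
The tests above toggle the new spark.sql.hive.filesourcePartitionPruning flag via
withSQLConf. Outside of a test the same flag could presumably be set on the session
configuration; a sketch, assuming a running SparkSession named `spark` and the
partitioned table created by setupPartitionedTable:

    // With lazy filesource partition pruning enabled, only the partitions that
    // match the filter are fetched from the metastore and listed on disk.
    spark.conf.set("spark.sql.hive.filesourcePartitionPruning", "true")
    spark.sql("select * from test where partCol1 < 2").count()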

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
index 3414f5e..7af81a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
@@ -59,4 +59,45 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi
       }
     }
   }
+
+  def testCaching(pruningEnabled: Boolean): Unit = {
+    test(s"partitioned table is cached when partition pruning is 
$pruningEnabled") {
+      withSQLConf("spark.sql.hive.filesourcePartitionPruning" -> 
pruningEnabled.toString) {
+        withTable("test") {
+          withTempDir { dir =>
+            spark.range(5).selectExpr("id", "id as f1", "id as f2").write
+              .partitionBy("f1", "f2")
+              .mode("overwrite")
+              .parquet(dir.getAbsolutePath)
+
+            spark.sql(s"""
+              |create external table test (id long)
+              |partitioned by (f1 int, f2 int)
+              |stored as parquet
+              |location "${dir.getAbsolutePath}"""".stripMargin)
+            spark.sql("msck repair table test")
+
+            val df = spark.sql("select * from test")
+            assert(sql("select * from test").count() == 5)
+
+            // Delete a file, then assert that we tried to read it. This means the table was cached.
+            val p = new Path(spark.table("test").inputFiles.head)
+            assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, true))
+            val e = intercept[SparkException] {
+              sql("select * from test").count()
+            }
+            assert(e.getMessage.contains("FileNotFoundException"))
+
+            // Test refreshing the cache.
+            spark.catalog.refreshTable("test")
+            assert(sql("select * from test").count() == 4)
+          }
+        }
+      }
+    }
+  }
+
+  for (pruningEnabled <- Seq(true, false)) {
+    testCaching(pruningEnabled)
+  }
 }
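
The caching test relies on spark.catalog.refreshTable to recover once files change
underneath a cached table; the same pattern outside a test would look roughly like
this (the table name is assumed):

    // Files under the table location were removed out-of-band, so the cached
    // file listing is stale; refreshTable drops the cached entry and the next
    // scan re-lists only the surviving files.
    spark.catalog.refreshTable("test")
    spark.sql("select * from test").count()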

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index c158bf1..9a10957 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -295,12 +295,12 @@ class VersionsSuite extends SparkFunSuite with Logging {
     }
 
     test(s"$version: getPartitions(catalogTable)") {
-      assert(2 == client.getPartitions("default", "src_part").size)
+      assert(2 == client.getPartitions(client.getTable("default", "src_part")).size)
     }
 
     test(s"$version: getPartitionsByFilter") {
       // Only one partition [1, 1] for key2 == 1
-      val result = client.getPartitionsByFilter("default", "src_part",
+      val result = client.getPartitionsByFilter(client.getTable("default", "src_part"),
         Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1))))
 
       // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition.

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index b2ee49c..ecb5972 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -474,6 +474,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
     }
   }
 
+  test("converted ORC table supports resolving mixed case field") {
+    withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") {
+      withTable("dummy_orc") {
+        withTempPath { dir =>
+          val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue")
+          df.write
+            .partitionBy("partitionValue")
+            .mode("overwrite")
+            .orc(dir.getAbsolutePath)
+
+          spark.sql(s"""
+            |create external table dummy_orc (id long, valueField long)
+            |partitioned by (partitionValue int)
+            |stored as orc
+            |location "${dir.getAbsolutePath}"""".stripMargin)
+          spark.sql(s"msck repair table dummy_orc")
+          checkAnswer(spark.sql("select * from dummy_orc"), df)
+        }
+      }
+    }
+  }
+
   test("SPARK-14962 Produce correct results on array type with isnotnull") {
     withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
       val data = (0 until 10).map(i => Tuple1(Array(i)))

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce1b675/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 2f6d9fb..9fc62a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -175,7 +175,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
     (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a")
       .createOrReplaceTempView("jt_array")
 
-    setConf(HiveUtils.CONVERT_METASTORE_PARQUET, true)
+    assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
   }
 
   override def afterAll(): Unit = {
@@ -187,7 +187,6 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
       "jt",
       "jt_array",
        "test_parquet")
-    setConf(HiveUtils.CONVERT_METASTORE_PARQUET, false)
   }
 
   test(s"conversion is working") {
@@ -586,6 +585,23 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
         checkAnswer(
           sql("SELECT * FROM test_added_partitions"),
           Seq(("foo", 0), ("bar", 0), ("baz", 1)).toDF("a", "b"))
+
+        // Check it with pruning predicates
+        checkAnswer(
+          sql("SELECT * FROM test_added_partitions where b = 0"),
+          Seq(("foo", 0), ("bar", 0)).toDF("a", "b"))
+        checkAnswer(
+          sql("SELECT * FROM test_added_partitions where b = 1"),
+          Seq(("baz", 1)).toDF("a", "b"))
+        checkAnswer(
+          sql("SELECT * FROM test_added_partitions where b = 2"),
+          Seq[(String, Int)]().toDF("a", "b"))
+
+        // Also verify the inputFiles implementation
+        assert(sql("select * from test_added_partitions").inputFiles.length == 
2)
+        assert(sql("select * from test_added_partitions where b = 
0").inputFiles.length == 1)
+        assert(sql("select * from test_added_partitions where b = 
1").inputFiles.length == 1)
+        assert(sql("select * from test_added_partitions where b = 
2").inputFiles.length == 0)
       }
     }
   }

