Repository: spark
Updated Branches:
  refs/heads/master b29bc3f51 -> 145433f1a


[SPARK-14369] [SQL] Locality support for FileScanRDD

(This PR is a rebased version of PR #12153.)

## What changes were proposed in this pull request?

This PR adds preliminary locality support for `FileFormat` data sources by 
overriding `FileScanRDD.preferredLocations()`. The strategy can be divided into 
two parts:

1.  Block location lookup

    Unlike `HadoopRDD` or `NewHadoopRDD`, `FileScanRDD` doesn't have access to 
the underlying `InputFormat` or `InputSplit`, and thus can't rely on 
`InputSplit.getLocations()` to gather locality information. Instead, this PR 
queries block locations using `FileSystem.getFileBlockLocations()` after listing 
all `FileStatus`es in `HDFSFileCatalog`, converting each `FileStatus` into a 
`LocatedFileStatus`.

    Note that although S3/S3A/S3N file systems don't provide valid locality 
information, their `getFileBlockLocations()` implementations don't issue remote 
calls either, so there's no need to special-case these file systems.

2.  Selecting preferred locations

    For each `FilePartition`, we pick the top 3 hosts that contain the most 
data to be retrieved. This isn't necessarily the best algorithm; further 
improvements may come in follow-up PRs.

## How was this patch tested?

Tested by overriding the default `FileSystem` implementation for `file:///` with a 
mock that returns fake block locations.

Author: Cheng Lian <l...@databricks.com>

Closes #12527 from liancheng/spark-14369-locality-rebased.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/145433f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/145433f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/145433f1

Branch: refs/heads/master
Commit: 145433f1aaf4a58f484f98c2f1d32abd8cc95b48
Parents: b29bc3f
Author: Cheng Lian <l...@databricks.com>
Authored: Thu Apr 21 21:48:09 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu Apr 21 21:48:09 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/datasources/FileScanRDD.scala |  24 ++++-
 .../datasources/FileSourceStrategy.scala        |  53 +++++++++-
 .../datasources/fileSourceInterfaces.scala      |  84 +++++++++++----
 .../datasources/FileSourceStrategySuite.scala   | 106 +++++++++++++++++--
 .../apache/spark/sql/test/SQLTestUtils.scala    |  21 ++++
 .../sql/sources/hadoopFsRelationSuites.scala    |  40 ++++++-
 6 files changed, 291 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/145433f1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 90694d9..60238bd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import scala.collection.mutable
+
 import org.apache.spark.{Partition => RDDPartition, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
@@ -33,7 +35,8 @@ case class PartitionedFile(
     partitionValues: InternalRow,
     filePath: String,
     start: Long,
-    length: Long) {
+    length: Long,
+    locations: Array[String] = Array.empty) {
   override def toString: String = {
     s"path: $filePath, range: $start-${start + length}, partition values: 
$partitionValues"
   }
@@ -131,4 +134,23 @@ class FileScanRDD(
   }
 
   override protected def getPartitions: Array[RDDPartition] = 
filePartitions.toArray
+
+  override protected def getPreferredLocations(split: RDDPartition): 
Seq[String] = {
+    val files = split.asInstanceOf[FilePartition].files
+
+    // Computes total number of bytes can be retrieved from each host.
+    val hostToNumBytes = mutable.HashMap.empty[String, Long]
+    files.foreach { file =>
+      file.locations.filter(_ != "localhost").foreach { host =>
+        hostToNumBytes(host) = hostToNumBytes.getOrElse(host, 0L) + file.length
+      }
+    }
+
+    // Takes the first 3 hosts with the most data to be retrieved
+    hostToNumBytes.toSeq.sortBy {
+      case (host, numBytes) => numBytes
+    }.reverse.take(3).map {
+      case (host, numBytes) => host
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/145433f1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 80a9156..ee48a7b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, 
Path}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan}
-import org.apache.spark.sql.sources._
 
 /**
  * A strategy for planning scans over collections of files that might be 
partitioned or bucketed
@@ -120,7 +119,10 @@ private[sql] object FileSourceStrategy extends Strategy 
with Logging {
           logInfo(s"Planning with ${bucketing.numBuckets} buckets")
           val bucketed =
             selectedPartitions.flatMap { p =>
-              p.files.map(f => PartitionedFile(p.values, 
f.getPath.toUri.toString, 0, f.getLen))
+              p.files.map { f =>
+                val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
+                PartitionedFile(p.values, f.getPath.toUri.toString, 0, 
f.getLen, hosts)
+              }
             }.groupBy { f =>
               BucketingUtils
                 .getBucketId(new Path(f.filePath).getName)
@@ -139,10 +141,12 @@ private[sql] object FileSourceStrategy extends Strategy 
with Logging {
 
           val splitFiles = selectedPartitions.flatMap { partition =>
             partition.files.flatMap { file =>
-              (0L to file.getLen by maxSplitBytes).map { offset =>
+              val blockLocations = getBlockLocations(file)
+              (0L until file.getLen by maxSplitBytes).map { offset =>
                 val remaining = file.getLen - offset
                 val size = if (remaining > maxSplitBytes) maxSplitBytes else 
remaining
-                PartitionedFile(partition.values, file.getPath.toUri.toString, 
offset, size)
+                val hosts = getBlockHosts(blockLocations, offset, size)
+                PartitionedFile(partition.values, file.getPath.toUri.toString, 
offset, size, hosts)
               }
             }
           }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
@@ -207,4 +211,43 @@ private[sql] object FileSourceStrategy extends Strategy 
with Logging {
 
     case _ => Nil
   }
+
+  /** Returns the block locations attached to `file`, or an empty array if it isn't located. */
+  private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match {
+    case located: LocatedFileStatus => located.getBlockLocations
+    case _ => Array.empty[BlockLocation]
+  }
+
+  // Given the locations of all blocks of a single file (`blockLocations`) and an `(offset, length)`
+  // pair describing a segment of that file, finds the block that contains the largest fraction of
+  // the segment and returns that block's hosts. Returns an empty array if no block overlaps the
+  // segment.
+  private def getBlockHosts(
+      blockLocations: Array[BlockLocation], offset: Long, length: Long): Array[String] = {
+    val candidates = blockLocations.map {
+      // The fragment starts from a position within this block
+      case b if b.getOffset <= offset && offset < b.getOffset + b.getLength =>
+        b.getHosts -> (b.getOffset + b.getLength - offset).min(length)
+
+      // The fragment ends at a position within this block. `offset + length` is an absolute file
+      // position, so it must be compared with the block's end position (`getOffset + getLength`),
+      // not with the block's length.
+      case b if b.getOffset < offset + length && offset + length < b.getOffset + b.getLength =>
+        b.getHosts -> (offset + length - b.getOffset)
+
+      // The fragment fully contains this block
+      case b if offset <= b.getOffset && b.getOffset + b.getLength <= offset + length =>
+        b.getHosts -> b.getLength
+
+      // The fragment doesn't intersect with this block
+      case b =>
+        b.getHosts -> 0L
+    }.filter { case (hosts, size) =>
+      size > 0L
+    }
+
+    if (candidates.isEmpty) {
+      Array.empty[String]
+    } else {
+      val (hosts, _) = candidates.maxBy { case (_, size) => size }
+      hosts
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/145433f1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index d37a939..ed24bdd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs._
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
@@ -342,16 +342,31 @@ class HDFSFileCatalog(
     if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) {
       HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, 
sqlContext.sparkContext)
     } else {
-      val statuses = paths.flatMap { path =>
+      val statuses: Seq[FileStatus] = paths.flatMap { path =>
         val fs = path.getFileSystem(hadoopConf)
         logInfo(s"Listing $path on driver")
         // Dummy jobconf to get to the pathFilter defined in configuration
-        val jobConf = new JobConf(hadoopConf, this.getClass())
+        val jobConf = new JobConf(hadoopConf, this.getClass)
         val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
-        if (pathFilter != null) {
-          Try(fs.listStatus(path, pathFilter)).getOrElse(Array.empty)
-        } else {
-          Try(fs.listStatus(path)).getOrElse(Array.empty)
+
+        val statuses = {
+          val stats = 
Try(fs.listStatus(path)).getOrElse(Array.empty[FileStatus])
+          if (pathFilter != null) stats.filter(f => 
pathFilter.accept(f.getPath)) else stats
+        }
+
+        statuses.map {
+          case f: LocatedFileStatus => f
+
+          // NOTE:
+          //
+          // - Although S3/S3A/S3N file systems can be quite slow for remote file metadata
+          //   operations, calling `getFileBlockLocations` does no harm here since these file system
+          //   implementations don't actually issue RPC for this method.
+          //
+          // - Here we are calling `getFileBlockLocations` in a sequential manner, but it shouldn't
+          //   be a big deal since we always use `listLeafFilesInParallel` when the number of paths
+          //   reaches the threshold.
+          case f => new LocatedFileStatus(f, fs.getFileBlockLocations(f, 0, 
f.getLen))
         }
       }.filterNot { status =>
         val name = status.getPath.getName
@@ -369,7 +384,7 @@ class HDFSFileCatalog(
     }
   }
 
-   def inferPartitioning(schema: Option[StructType]): PartitionSpec = {
+  def inferPartitioning(schema: Option[StructType]): PartitionSpec = {
     // We use leaf dirs containing data files to discover the schema.
     val leafDirs = leafDirToChildrenFiles.keys.toSeq
     schema match {
@@ -473,15 +488,15 @@ private[sql] object HadoopFsRelation extends Logging {
       // Dummy jobconf to get to the pathFilter defined in configuration
       val jobConf = new JobConf(fs.getConf, this.getClass())
       val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
-      val statuses =
-        if (pathFilter != null) {
-          val (dirs, files) = fs.listStatus(status.getPath, 
pathFilter).partition(_.isDirectory)
-          files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
-        } else {
-          val (dirs, files) = 
fs.listStatus(status.getPath).partition(_.isDirectory)
-          files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
-        }
-      statuses.filterNot(status => shouldFilterOut(status.getPath.getName))
+      val statuses = {
+        val (dirs, files) = 
fs.listStatus(status.getPath).partition(_.isDirectory)
+        val stats = files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
+        if (pathFilter != null) stats.filter(f => 
pathFilter.accept(f.getPath)) else stats
+      }
+      statuses.filterNot(status => 
shouldFilterOut(status.getPath.getName)).map {
+        case f: LocatedFileStatus => f
+        case f => new LocatedFileStatus(f, fs.getFileBlockLocations(f, 0, 
f.getLen))
+      }
     }
   }
 
@@ -489,6 +504,12 @@ private[sql] object HadoopFsRelation extends Logging {
   // well with `SerializableWritable`.  So there seems to be no way to 
serialize a `FileStatus`.
   // Here we use `FakeFileStatus` to extract key components of a `FileStatus` 
to serialize it from
   // executor side and reconstruct it on driver side.
+  // Serializable stand-in for Hadoop's `BlockLocation`. It keeps a block's names, hosts, offset and
+  // length so that `listLeafFilesInParallel` can carry block locations from the executor side
+  // back to the driver and rebuild `BlockLocation`s there.
+  case class FakeBlockLocation(
+      names: Array[String],
+      hosts: Array[String],
+      offset: Long,
+      length: Long)
+
   case class FakeFileStatus(
       path: String,
       length: Long,
@@ -496,7 +517,8 @@ private[sql] object HadoopFsRelation extends Logging {
       blockReplication: Short,
       blockSize: Long,
       modificationTime: Long,
-      accessTime: Long)
+      accessTime: Long,
+      blockLocations: Array[FakeBlockLocation])
 
   def listLeafFilesInParallel(
       paths: Seq[Path],
@@ -511,6 +533,20 @@ private[sql] object HadoopFsRelation extends Logging {
       val fs = path.getFileSystem(serializableConfiguration.value)
       Try(listLeafFiles(fs, fs.getFileStatus(path))).getOrElse(Array.empty)
     }.map { status =>
+      val blockLocations = status match {
+        case f: LocatedFileStatus =>
+          f.getBlockLocations.map { loc =>
+            FakeBlockLocation(
+              loc.getNames,
+              loc.getHosts,
+              loc.getOffset,
+              loc.getLength)
+          }
+
+        case _ =>
+          Array.empty[FakeBlockLocation]
+      }
+
       FakeFileStatus(
         status.getPath.toString,
         status.getLen,
@@ -518,12 +554,18 @@ private[sql] object HadoopFsRelation extends Logging {
         status.getReplication,
         status.getBlockSize,
         status.getModificationTime,
-        status.getAccessTime)
+        status.getAccessTime,
+        blockLocations)
     }.collect()
 
     val hadoopFakeStatuses = fakeStatuses.map { f =>
-      new FileStatus(
-        f.length, f.isDir, f.blockReplication, f.blockSize, 
f.modificationTime, new Path(f.path))
+      val blockLocations = f.blockLocations.map { loc =>
+        new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length)
+      }
+      new LocatedFileStatus(
+        new FileStatus(
+          f.length, f.isDir, f.blockReplication, f.blockSize, 
f.modificationTime, new Path(f.path)),
+        blockLocations)
     }
     mutable.LinkedHashSet(hadoopFakeStatuses: _*)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/145433f1/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index dac56d3..4699c48 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.execution.datasources
 
 import java.io.File
+import java.util.concurrent.atomic.AtomicInteger
 
-import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.{BlockLocation, FileStatus, RawLocalFileSystem}
 import org.apache.hadoop.mapreduce.Job
 
 import org.apache.spark.sql._
@@ -267,6 +268,80 @@ class FileSourceStrategySuite extends QueryTest with 
SharedSQLContext with Predi
     }
   }
 
+  test("Locality support for FileScanRDD") {
+    def fakeFile(path: String, start: Long, length: Long, hosts: String*): PartitionedFile =
+      PartitionedFile(InternalRow.empty, path, start, length, hosts.toArray)
+
+    // Bytes per host: host1 = 30, host2 = 20, host0 = 10, host3 = host4 = 5.
+    val partition = FilePartition(0, Seq(
+      fakeFile("fakePath0", 0, 10, "host0", "host1"),
+      fakeFile("fakePath0", 10, 20, "host1", "host2"),
+      fakeFile("fakePath1", 0, 5, "host3"),
+      fakeFile("fakePath2", 0, 5, "host4")))
+
+    val fakeRDD = new FileScanRDD(
+      sqlContext, (_: PartitionedFile) => Iterator.empty, Seq(partition))
+
+    // Only the three hosts with the most bytes should be preferred.
+    assert(fakeRDD.preferredLocations(partition).toSet === Set("host0", "host1", "host2"))
+  }
+
+  test("Locality support for FileScanRDD - one file per partition") {
+    // Route `file://` through LocalityTestFileSystem, with FileSystem caching disabled.
+    withHadoopConf(
+        "fs.file.impl" -> classOf[LocalityTestFileSystem].getName,
+        "fs.file.impl.disable.cache" -> "true") {
+      withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+        val table = createTable(files = Seq("file1" -> 10, "file2" -> 10))
+
+        checkScan(table) { partitions =>
+          // Expect two partitions, each holding one file that reports one block host.
+          val Seq(first, second) = partitions
+          for (p <- Seq(first, second)) {
+            assert(p.files.length == 1)
+            assert(p.files.flatMap(_.locations).length == 1)
+          }
+
+          val rdd = getFileScanRDD(table)
+          val hosts = partitions.flatMap(p => rdd.preferredLocations(p))
+          assert(hosts.length == 2)
+        }
+      }
+    }
+  }
+
+  test("Locality support for FileScanRDD - large file") {
+    // Route `file://` through LocalityTestFileSystem, with FileSystem caching disabled.
+    withHadoopConf(
+        "fs.file.impl" -> classOf[LocalityTestFileSystem].getName,
+        "fs.file.impl.disable.cache" -> "true") {
+      withSQLConf(
+          SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10",
+          SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") {
+        val table = createTable(files = Seq("file1" -> 15, "file2" -> 5))
+
+        checkScan(table) { partitions =>
+          // Expect two partitions: the first with one file split, the second with two; every
+          // split reports exactly one block host.
+          val Seq(p1, p2) = partitions
+          for ((p, expectedFiles) <- Seq(p1 -> 1, p2 -> 2)) {
+            assert(p.files.length == expectedFiles)
+            assert(p.files.flatMap(_.locations).length == expectedFiles)
+          }
+
+          val rdd = getFileScanRDD(table)
+          val hosts = partitions.flatMap(p => rdd.preferredLocations(p))
+          assert(hosts.length == 3)
+        }
+      }
+    }
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
@@ -303,14 +378,7 @@ class FileSourceStrategySuite extends QueryTest with 
SharedSQLContext with Predi
 
  /**
   * Plans the query and calls the provided validation function with the planned partitioning,
   * i.e. the `filePartitions` of the [[FileScanRDD]] found in the query's executed plan.
   */
  def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
-    val fileScan = df.queryExecution.executedPlan.collect {
-      case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] =>
-        scan.rdd.asInstanceOf[FileScanRDD]
-    }.headOption.getOrElse {
-      fail(s"No FileScan in query\n${df.queryExecution}")
-    }
-
-    func(fileScan.filePartitions)
+    // `getFileScanRDD` fails the test when the executed plan contains no FileScanRDD.
+    func(getFileScanRDD(df).filePartitions)
  }
 
   /**
@@ -348,6 +416,15 @@ class FileSourceStrategySuite extends QueryTest with 
SharedSQLContext with Predi
       df
     }
   }
+
+  /**
+   * Returns the [[FileScanRDD]] of the first [[DataSourceScan]] found in `df`'s executed plan,
+   * failing the test if the plan contains no such scan.
+   */
+  def getFileScanRDD(df: DataFrame): FileScanRDD = {
+    // `collectFirst` stops at the first match instead of collecting every matching scan.
+    df.queryExecution.executedPlan.collectFirst {
+      case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] =>
+        scan.rdd.asInstanceOf[FileScanRDD]
+    }.getOrElse {
+      fail(s"No FileScan in query\n${df.queryExecution}")
+    }
+  }
 }
 
 /** Holds the last arguments passed to [[TestFileFormat]]. */
@@ -407,3 +484,14 @@ class TestFileFormat extends FileFormat {
     (file: PartitionedFile) => { Iterator.empty }
   }
 }
+
+
+/**
+ * Local file system whose `getFileBlockLocations` reports a single block covering `[0, len)` on a
+ * fake host. Each call on the same instance uses the next host id: `host0`, `host1`, ...
+ */
+class LocalityTestFileSystem extends RawLocalFileSystem {
+  private val invocations = new AtomicInteger(0)
+
+  override def getFileBlockLocations(
+      file: FileStatus, start: Long, len: Long): Array[BlockLocation] = {
+    val host = s"host${invocations.getAndIncrement()}"
+    Array(new BlockLocation(Array(s"$host:50010"), Array(host), 0, len))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/145433f1/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 7844d1b..f615019 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -92,6 +92,27 @@ private[sql] trait SQLTestUtils
   }
 
   /**
+   * Temporarily applies the Hadoop configuration entries in `pairs` to `hadoopConfiguration`, runs
+   * `f`, and afterwards (even if `f` throws) restores every key to its previous value, unsetting
+   * keys that had no value before.
+   */
+  protected def withHadoopConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val keys = pairs.map(_._1)
+    val previous = keys.map(k => k -> Option(hadoopConfiguration.get(k)))
+
+    try {
+      for ((k, v) <- pairs) hadoopConfiguration.set(k, v)
+      f
+    } finally {
+      previous.foreach {
+        case (k, Some(v)) => hadoopConfiguration.set(k, v)
+        case (k, None) => hadoopConfiguration.unset(k)
+      }
+    }
+  }
+
+  /**
    * Sets all SQL configurations specified in `pairs`, calls `f`, and then 
restore all SQL
    * configurations.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/145433f1/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 368fe62..089cef6 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -28,7 +28,8 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
 
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql._
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, 
LogicalRelation}
+import org.apache.spark.sql.execution.DataSourceScan
+import org.apache.spark.sql.execution.datasources.{FileScanRDD, 
HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
@@ -668,6 +669,43 @@ abstract class HadoopFsRelationTest extends QueryTest with 
SQLTestUtils with Tes
       df.write.format(dataSourceName).partitionBy("c", "d", 
"e").saveAsTable("t")
     }
   }
+
+  test("Locality support for FileScanRDD") {
+    // Route `file://` through LocalityTestFileSystem, with FileSystem caching disabled.
+    withHadoopConf(
+      "fs.file.impl" -> classOf[LocalityTestFileSystem].getName,
+      "fs.file.impl.disable.cache" -> "true"
+    ) {
+      withTempPath { dir =>
+        val path = "file://" + dir.getCanonicalPath
+        val df1 = sqlContext.range(4)
+        df1.coalesce(1).write.mode("overwrite").format(dataSourceName).save(path)
+        df1.coalesce(1).write.mode("append").format(dataSourceName).save(path)
+
+        def checkLocality(): Unit = {
+          val df2 = sqlContext.read
+            .format(dataSourceName)
+            .option("dataSchema", df1.schema.json)
+            .load(path)
+
+          // Fail with a descriptive message (rather than a MatchError) if no FileScanRDD exists.
+          val fileScanRDD = df2.queryExecution.executedPlan.collectFirst {
+            case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] =>
+              scan.rdd.asInstanceOf[FileScanRDD]
+          }.getOrElse {
+            fail(s"No FileScan in query\n${df2.queryExecution}")
+          }
+
+          val partitions = fileScanRDD.partitions
+          val preferredLocations = partitions.flatMap(fileScanRDD.preferredLocations)
+
+          // Assumes the two writes above produce two data files, each of which gets a distinct
+          // fake host from LocalityTestFileSystem.
+          assert(preferredLocations.distinct.length == 2)
+        }
+
+        // Default threshold: files are listed on the driver.
+        checkLocality()
+
+        // Threshold 0: files are listed via `listLeafFilesInParallel`.
+        withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "0") {
+          checkLocality()
+        }
+      }
+    }
+  }
 }
 
 // This class is used to test SPARK-8578. We should not use any custom output 
committer when


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to