This is an automated email from the ASF dual-hosted git repository.

meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9793d9e  [SPARK-27473][SQL] Support filter push down for status fields in binary file data source
9793d9e is described below

commit 9793d9ec22ff7d9778554e4fa3f03ef4f93d473d
Author: WeichenXu <weichen...@databricks.com>
AuthorDate: Sun Apr 21 12:45:59 2019 -0700

    [SPARK-27473][SQL] Support filter push down for status fields in binary file data source
    
    ## What changes were proposed in this pull request?
    
    Support 4 kinds of filters:
    - LessThan
    - LessThanOrEqual
    - GreaterThan
    - GreaterThanOrEqual
    
    Support filters applied on 2 columns:
    - modificationTime
    - length
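    
    On the data source side, these predicates arrive as `org.apache.spark.sql.sources.Filter`
    instances. For illustration only (a sketch, not part of this patch; the threshold values
    are made up), a pushed-down pair could look like:
    ```
    import java.sql.Timestamp
    import org.apache.spark.sql.sources.{GreaterThanOrEqual, LessThan}
    
    val byLength = LessThan("length", 1024L * 1024L)  // files smaller than 1 MiB
    val byTime = GreaterThanOrEqual("modificationTime",
      Timestamp.valueOf("2019-01-01 00:00:00"))
    ```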
    
    Note:
    In order to support data source filter push-down, I flattened the schema to be:
    ```
    val schema = StructType(
        StructField("path", StringType, false) ::
        StructField("modificationTime", TimestampType, false) ::
        StructField("length", LongType, false) ::
        StructField("content", BinaryType, true) :: Nil)
    ```
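    
    With this flattened schema, a filter written against the `length` or `modificationTime`
    column of the DataFrame becomes a candidate for push-down. A minimal usage sketch
    (the input directory is hypothetical):
    ```
    import java.sql.Timestamp
    import org.apache.spark.sql.functions.{col, lit}
    
    val df = spark.read.format("binaryFile")
      .load("/tmp/binary-files")  // hypothetical input directory
      .filter(col("length") < 1024L * 1024L)  // candidate LessThan push-down on `length`
      .filter(col("modificationTime") >= lit(Timestamp.valueOf("2019-01-01 00:00:00")))
    ```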
    
    ## How was this patch tested?
    
    To be added.
    
    Please review http://spark.apache.org/contributing.html before opening a pull request.
    
    Closes #24387 from WeichenXu123/binary_ds_filter.
    
    Lead-authored-by: WeichenXu <weichen...@databricks.com>
    Co-authored-by: Xiangrui Meng <m...@databricks.com>
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
---
 .../datasources/binaryfile/BinaryFileFormat.scala  | 134 ++++++++++-----
 .../binaryfile/BinaryFileFormatSuite.scala         | 188 ++++++++++++++++++---
 2 files changed, 256 insertions(+), 66 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index ad9292a..8617ae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.binaryfile
 
+import java.sql.Timestamp
+
 import com.google.common.io.{ByteStreams, Closeables}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, GlobFilter, Path}
@@ -28,7 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
+  GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
@@ -55,10 +58,12 @@ import org.apache.spark.util.SerializableConfiguration
  */
 class BinaryFileFormat extends FileFormat with DataSourceRegister {
 
+  import BinaryFileFormat._
+
   override def inferSchema(
       sparkSession: SparkSession,
       options: Map[String, String],
-      files: Seq[FileStatus]): Option[StructType] = Some(BinaryFileFormat.schema)
+      files: Seq[FileStatus]): Option[StructType] = Some(schema)
 
   override def prepareWrite(
       sparkSession: SparkSession,
@@ -84,7 +89,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
       requiredSchema: StructType,
       filters: Seq[Filter],
       options: Map[String, String],
-      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
 
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
@@ -93,46 +98,49 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
 
     val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter
 
-    (file: PartitionedFile) => {
+    val filterFuncs = filters.map(filter => createFilterFunction(filter))
+
+    file: PartitionedFile => {
       val path = file.filePath
       val fsPath = new Path(path)
 
       // TODO: Improve performance here: each file will recompile the glob pattern here.
-      val globFilter = pathGlobPattern.map(new GlobFilter(_))
-      if (!globFilter.isDefined || globFilter.get.accept(fsPath)) {
+      if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) {
         val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value)
         val fileStatus = fs.getFileStatus(fsPath)
-        val length = fileStatus.getLen()
-        val modificationTime = fileStatus.getModificationTime()
-        val stream = fs.open(fsPath)
-
-        val content = try {
-          ByteStreams.toByteArray(stream)
-        } finally {
-          Closeables.close(stream, true)
-        }
-
-        val fullOutput = dataSchema.map { f =>
-          AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
-        }
-        val requiredOutput = fullOutput.filter { a =>
-          requiredSchema.fieldNames.contains(a.name)
-        }
-
-        // TODO: Add column pruning
-        // currently it still read the file content even if content column is not required.
-        val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
-        val internalRow = InternalRow(
-          content,
-          InternalRow(
+        val length = fileStatus.getLen
+        val modificationTime = fileStatus.getModificationTime
+
+        if (filterFuncs.forall(_.apply(fileStatus))) {
+          val stream = fs.open(fsPath)
+          val content = try {
+            ByteStreams.toByteArray(stream)
+          } finally {
+            Closeables.close(stream, true)
+          }
+
+          val fullOutput = dataSchema.map { f =>
+            AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+          }
+          val requiredOutput = fullOutput.filter { a =>
+            requiredSchema.fieldNames.contains(a.name)
+          }
+
+          // TODO: Add column pruning
+          // currently it still read the file content even if content column is not required.
+          val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
+
+          val internalRow = InternalRow(
             UTF8String.fromString(path),
             DateTimeUtils.fromMillis(modificationTime),
-            length
+            length,
+            content
           )
-        )
 
-        Iterator(requiredColumns(internalRow))
+          Iterator(requiredColumns(internalRow))
+        } else {
+          Iterator.empty
+        }
       } else {
         Iterator.empty
       }
@@ -142,26 +150,62 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
 
 object BinaryFileFormat {
 
-  private val fileStatusSchema = StructType(
-    StructField("path", StringType, false) ::
-      StructField("modificationTime", TimestampType, false) ::
-      StructField("length", LongType, false) :: Nil)
+  private[binaryfile] val PATH = "path"
+  private[binaryfile] val MODIFICATION_TIME = "modificationTime"
+  private[binaryfile] val LENGTH = "length"
+  private[binaryfile] val CONTENT = "content"
 
   /**
    * Schema for the binary file data source.
    *
    * Schema:
+   *  - path (StringType): The path of the file.
+   *  - modificationTime (TimestampType): The modification time of the file.
+   *    In some Hadoop FileSystem implementation, this might be unavailable and fallback to some
+   *    default value.
+   *  - length (LongType): The length of the file in bytes.
    *  - content (BinaryType): The content of the file.
-   *  - status (StructType): The status of the file.
-   *    - path (StringType): The path of the file.
-   *    - modificationTime (TimestampType): The modification time of the file.
-   *      In some Hadoop FileSystem implementation, this might be unavailable and fallback to some
-   *      default value.
-   *    - length (LongType): The length of the file in bytes.
    */
   val schema = StructType(
-    StructField("content", BinaryType, true) ::
-      StructField("status", fileStatusSchema, false) :: Nil)
+    StructField(PATH, StringType, false) ::
+    StructField(MODIFICATION_TIME, TimestampType, false) ::
+    StructField(LENGTH, LongType, false) ::
+    StructField(CONTENT, BinaryType, true) :: Nil)
+
+  private[binaryfile] def createFilterFunction(filter: Filter): FileStatus => Boolean = {
+    filter match {
+      case And(left, right) =>
+        s => createFilterFunction(left)(s) && createFilterFunction(right)(s)
+      case Or(left, right) =>
+        s => createFilterFunction(left)(s) || createFilterFunction(right)(s)
+      case Not(child) =>
+        s => !createFilterFunction(child)(s)
+
+      case LessThan(LENGTH, value: Long) =>
+        _.getLen < value
+      case LessThanOrEqual(LENGTH, value: Long) =>
+        _.getLen <= value
+      case GreaterThan(LENGTH, value: Long) =>
+        _.getLen > value
+      case GreaterThanOrEqual(LENGTH, value: Long) =>
+        _.getLen >= value
+      case EqualTo(LENGTH, value: Long) =>
+        _.getLen == value
+
+      case LessThan(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime < value.getTime
+      case LessThanOrEqual(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime <= value.getTime
+      case GreaterThan(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime > value.getTime
+      case GreaterThanOrEqual(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime >= value.getTime
+      case EqualTo(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime == value.getTime
+
+      case _ => (_ => true)
+    }
+  }
 }
 
 class BinaryFileSourceOptions(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 090f417..3254a7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -24,14 +24,19 @@ import java.sql.Timestamp
 import scala.collection.JavaConverters._
 
 import com.google.common.io.{ByteStreams, Closeables}
-import org.apache.hadoop.fs.{FileSystem, GlobFilter, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
+import org.mockito.Mockito.{mock, when}
 
 import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+  import BinaryFileFormat._
 
   private var testDir: String = _
 
@@ -39,6 +44,8 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
 
   private var fs: FileSystem = _
 
+  private var file1Status: FileStatus = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
 
@@ -51,44 +58,64 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
     val year2015Dir = new File(testDir, "year=2015")
     year2015Dir.mkdir()
 
+    val file1 = new File(year2014Dir, "data.txt")
     Files.write(
-      new File(year2014Dir, "data.txt").toPath,
+      file1.toPath,
       Seq("2014-test").asJava,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
+    file1Status = fs.getFileStatus(new Path(file1.getPath))
+
+    val file2 = new File(year2014Dir, "data2.bin")
     Files.write(
-      new File(year2014Dir, "data2.bin").toPath,
+      file2.toPath,
       "2014-test-bin".getBytes,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
 
+    val file3 = new File(year2015Dir, "bool.csv")
     Files.write(
-      new File(year2015Dir, "bool.csv").toPath,
+      file3.toPath,
       Seq("bool", "True", "False", "true").asJava,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
+
+    val file4 = new File(year2015Dir, "data.bin")
     Files.write(
-      new File(year2015Dir, "data.txt").toPath,
+      file4.toPath,
       "2015-test".getBytes,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
   }
 
+  test("BinaryFileFormat methods") {
+    val format = new BinaryFileFormat
+    assert(format.shortName() === "binaryFile")
+    assert(format.isSplitable(spark, Map.empty, new Path("any")) === false)
+    assert(format.inferSchema(spark, Map.empty, Seq.empty) === Some(BinaryFileFormat.schema))
+    assert(BinaryFileFormat.schema === StructType(Seq(
+      StructField("path", StringType, false),
+      StructField("modificationTime", TimestampType, false),
+      StructField("length", LongType, false),
+      StructField("content", BinaryType, true))))
+  }
+
   def testBinaryFileDataSource(pathGlobFilter: String): Unit = {
-    val resultDF = spark.read.format("binaryFile")
-      .option("pathGlobFilter", pathGlobFilter)
-      .load(testDir)
-      .select(
-        col("status.path"),
-        col("status.modificationTime"),
-        col("status.length"),
-        col("content"),
+    val dfReader = spark.read.format("binaryFile")
+    if (pathGlobFilter != null) {
+      dfReader.option("pathGlobFilter", pathGlobFilter)
+    }
+    val resultDF = dfReader.load(testDir).select(
+        col(PATH),
+        col(MODIFICATION_TIME),
+        col(LENGTH),
+        col(CONTENT),
         col("year") // this is a partition column
       )
 
     val expectedRowSet = new collection.mutable.HashSet[Row]()
 
-    val globFilter = new GlobFilter(pathGlobFilter)
+    val globFilter = if (pathGlobFilter == null) null else new GlobFilter(pathGlobFilter)
     for (partitionDirStatus <- fs.listStatus(fsTestDir)) {
       val dirPath = partitionDirStatus.getPath
 
@@ -96,7 +123,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
       val year = partitionName.toInt // partition column "year" value which is `Int` type
 
       for (fileStatus <- fs.listStatus(dirPath)) {
-        if (globFilter.accept(fileStatus.getPath)) {
+        if (globFilter == null || globFilter.accept(fileStatus.getPath)) {
           val fpath = fileStatus.getPath.toString.replace("file:/", "file:///")
           val flen = fileStatus.getLen
           val modificationTime = new Timestamp(fileStatus.getModificationTime)
@@ -121,14 +148,15 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
   }
 
   test("binary file data source test") {
-    testBinaryFileDataSource(pathGlobFilter = "*.*")
-    testBinaryFileDataSource(pathGlobFilter = "*.bin")
-    testBinaryFileDataSource(pathGlobFilter = "*.txt")
-    testBinaryFileDataSource(pathGlobFilter = "*.{txt,csv}")
-    testBinaryFileDataSource(pathGlobFilter = "*.json")
+    testBinaryFileDataSource(null)
+    testBinaryFileDataSource("*.*")
+    testBinaryFileDataSource("*.bin")
+    testBinaryFileDataSource("*.txt")
+    testBinaryFileDataSource("*.{txt,csv}")
+    testBinaryFileDataSource("*.json")
   }
 
-  test ("binary file data source do not support write operation") {
+  test("binary file data source do not support write operation") {
     val df = spark.read.format("binaryFile").load(testDir)
     withTempDir { tmpDir =>
       val thrown = intercept[UnsupportedOperationException] {
@@ -140,4 +168,122 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
     }
   }
 
+  def mockFileStatus(length: Long, modificationTime: Long): FileStatus = {
+    val status = mock(classOf[FileStatus])
+    when(status.getLen).thenReturn(length)
+    when(status.getModificationTime).thenReturn(modificationTime)
+    when(status.toString).thenReturn(
+      s"FileStatus($LENGTH=$length, $MODIFICATION_TIME=$modificationTime)")
+    status
+  }
+
+  def testCreateFilterFunction(
+      filters: Seq[Filter],
+      testCases: Seq[(FileStatus, Boolean)]): Unit = {
+    val funcs = filters.map(BinaryFileFormat.createFilterFunction)
+    testCases.foreach { case (status, expected) =>
+      assert(funcs.forall(f => f(status)) === expected,
+        s"$filters applied to $status should be $expected.")
+    }
+  }
+
+  test("createFilterFunction") {
+    // test filter applied on `length` column
+    val l1 = mockFileStatus(1L, 0L)
+    val l2 = mockFileStatus(2L, 0L)
+    val l3 = mockFileStatus(3L, 0L)
+    testCreateFilterFunction(
+      Seq(LessThan(LENGTH, 2L)),
+      Seq((l1, true), (l2, false), (l3, false)))
+    testCreateFilterFunction(
+      Seq(LessThanOrEqual(LENGTH, 2L)),
+      Seq((l1, true), (l2, true), (l3, false)))
+    testCreateFilterFunction(
+      Seq(GreaterThan(LENGTH, 2L)),
+      Seq((l1, false), (l2, false), (l3, true)))
+    testCreateFilterFunction(
+      Seq(GreaterThanOrEqual(LENGTH, 2L)),
+      Seq((l1, false), (l2, true), (l3, true)))
+    testCreateFilterFunction(
+      Seq(EqualTo(LENGTH, 2L)),
+      Seq((l1, false), (l2, true), (l3, false)))
+    testCreateFilterFunction(
+      Seq(Not(EqualTo(LENGTH, 2L))),
+      Seq((l1, true), (l2, false), (l3, true)))
+    testCreateFilterFunction(
+      Seq(And(GreaterThan(LENGTH, 1L), LessThan(LENGTH, 3L))),
+      Seq((l1, false), (l2, true), (l3, false)))
+    testCreateFilterFunction(
+      Seq(Or(LessThanOrEqual(LENGTH, 1L), GreaterThanOrEqual(LENGTH, 3L))),
+      Seq((l1, true), (l2, false), (l3, true)))
+
+    // test filter applied on `modificationTime` column
+    val t1 = mockFileStatus(0L, 1L)
+    val t2 = mockFileStatus(0L, 2L)
+    val t3 = mockFileStatus(0L, 3L)
+    testCreateFilterFunction(
+      Seq(LessThan(MODIFICATION_TIME, new Timestamp(2L))),
+      Seq((t1, true), (t2, false), (t3, false)))
+    testCreateFilterFunction(
+      Seq(LessThanOrEqual(MODIFICATION_TIME, new Timestamp(2L))),
+      Seq((t1, true), (t2, true), (t3, false)))
+    testCreateFilterFunction(
+      Seq(GreaterThan(MODIFICATION_TIME, new Timestamp(2L))),
+      Seq((t1, false), (t2, false), (t3, true)))
+    testCreateFilterFunction(
+      Seq(GreaterThanOrEqual(MODIFICATION_TIME, new Timestamp(2L))),
+      Seq((t1, false), (t2, true), (t3, true)))
+    testCreateFilterFunction(
+      Seq(EqualTo(MODIFICATION_TIME, new Timestamp(2L))),
+      Seq((t1, false), (t2, true), (t3, false)))
+    testCreateFilterFunction(
+      Seq(Not(EqualTo(MODIFICATION_TIME, new Timestamp(2L)))),
+      Seq((t1, true), (t2, false), (t3, true)))
+    testCreateFilterFunction(
+      Seq(And(GreaterThan(MODIFICATION_TIME, new Timestamp(1L)),
+        LessThan(MODIFICATION_TIME, new Timestamp(3L)))),
+      Seq((t1, false), (t2, true), (t3, false)))
+    testCreateFilterFunction(
+      Seq(Or(LessThanOrEqual(MODIFICATION_TIME, new Timestamp(1L)),
+        GreaterThanOrEqual(MODIFICATION_TIME, new Timestamp(3L)))),
+      Seq((t1, true), (t2, false), (t3, true)))
+
+    // test filters applied on both columns
+    testCreateFilterFunction(
+      Seq(And(GreaterThan(LENGTH, 2L), LessThan(MODIFICATION_TIME, new Timestamp(2L)))),
+      Seq((l1, false), (l2, false), (l3, true), (t1, false), (t2, false), (t3, false)))
+
+    // test nested filters
+    testCreateFilterFunction(
+      // NOT (length > 2 OR modificationTime < 2)
+      Seq(Not(Or(GreaterThan(LENGTH, 2L), LessThan(MODIFICATION_TIME, new Timestamp(2L))))),
+      Seq((l1, false), (l2, false), (l3, false), (t1, false), (t2, true), (t3, true)))
+  }
+
+  test("buildReader") {
+    def testBuildReader(fileStatus: FileStatus, filters: Seq[Filter], expected: Boolean): Unit = {
+      val format = new BinaryFileFormat
+      val reader = format.buildReaderWithPartitionValues(
+        sparkSession = spark,
+        dataSchema = schema,
+        partitionSchema = StructType(Nil),
+        requiredSchema = schema,
+        filters = filters,
+        options = Map.empty,
+        hadoopConf = spark.sessionState.newHadoopConf())
+      val partitionedFile = mock(classOf[PartitionedFile])
+      when(partitionedFile.filePath).thenReturn(fileStatus.getPath.toString)
+      assert(reader(partitionedFile).nonEmpty === expected,
+        s"Filters $filters applied to $fileStatus should be $expected.")
+    }
+    testBuildReader(file1Status, Seq.empty, true)
+    testBuildReader(file1Status, Seq(LessThan(LENGTH, file1Status.getLen)), false)
+    testBuildReader(file1Status, Seq(
+      LessThan(MODIFICATION_TIME, new Timestamp(file1Status.getModificationTime))
+    ), false)
+    testBuildReader(file1Status, Seq(
+      EqualTo(LENGTH, file1Status.getLen),
+      EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
+    ), true)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
