This is an automated email from the ASF dual-hosted git repository.

meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 20a3ef7  [SPARK-27534][SQL] Do not load `content` column in binary data source if it is not selected
20a3ef7 is described below

commit 20a3ef7259490e0c9f6348f13db1e99da5f0df83
Author: Xiangrui Meng <m...@databricks.com>
AuthorDate: Sun Apr 28 07:57:03 2019 -0700

    [SPARK-27534][SQL] Do not load `content` column in binary data source if it is not selected
    
    ## What changes were proposed in this pull request?
    
    A follow-up task from SPARK-25348. To save I/O cost, Spark shouldn't attempt to read the file if users didn't request the `content` column. For example:
    ```
    spark.read.format("binaryFile").load(path).filter($"length" < 1000000).count()
    ```
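    
    Likewise, a query that selects only the metadata columns should skip reading file bytes entirely (a usage sketch, assuming `path` points to a directory of binary files):
    ```
    // Only file metadata is materialized; without `content` in the selected
    // columns, the reader never opens the underlying files.
    spark.read.format("binaryFile").load(path)
      .select("path", "length", "modificationTime")
      .show()
    ```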
    
    ## How was this patch tested?
    
    Unit test added.
    
    Please review http://spark.apache.org/contributing.html before opening a pull request.
    
    Closes #24473 from WeichenXu123/SPARK-27534.
    
    Lead-authored-by: Xiangrui Meng <m...@databricks.com>
    Co-authored-by: WeichenXu <weichen...@databricks.com>
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
---
 .../datasources/binaryfile/BinaryFileFormat.scala  | 74 +++++++++-------------
 .../binaryfile/BinaryFileFormatSuite.scala         | 63 ++++++++++++++++--
 2 files changed, 89 insertions(+), 48 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index 8617ae3..db93268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -26,12 +26,10 @@ import org.apache.hadoop.mapreduce.Job
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
-  GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
+import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
@@ -80,7 +78,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
     false
   }
 
-  override def shortName(): String = "binaryFile"
+  override def shortName(): String = BINARY_FILE
 
   override protected def buildReader(
       sparkSession: SparkSession,
@@ -90,54 +88,43 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+    require(dataSchema.sameType(schema),
+      s"""
+         |Binary file data source expects dataSchema: $schema,
+         |but got: $dataSchema.
+        """.stripMargin)
 
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
-
     val binaryFileSourceOptions = new BinaryFileSourceOptions(options)
-
     val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter
-
     val filterFuncs = filters.map(filter => createFilterFunction(filter))
 
     file: PartitionedFile => {
-      val path = file.filePath
-      val fsPath = new Path(path)
-
+      val path = new Path(file.filePath)
       // TODO: Improve performance here: each file will recompile the glob pattern here.
-      if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) {
-        val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value)
-        val fileStatus = fs.getFileStatus(fsPath)
-        val length = fileStatus.getLen
-        val modificationTime = fileStatus.getModificationTime
-
-        if (filterFuncs.forall(_.apply(fileStatus))) {
-          val stream = fs.open(fsPath)
-          val content = try {
-            ByteStreams.toByteArray(stream)
-          } finally {
-            Closeables.close(stream, true)
-          }
-
-          val fullOutput = dataSchema.map { f =>
-            AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
-          }
-          val requiredOutput = fullOutput.filter { a =>
-            requiredSchema.fieldNames.contains(a.name)
+      if (pathGlobPattern.forall(new GlobFilter(_).accept(path))) {
+        val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
+        val status = fs.getFileStatus(path)
+        if (filterFuncs.forall(_.apply(status))) {
+          val writer = new UnsafeRowWriter(requiredSchema.length)
+          writer.resetRowWriter()
+          requiredSchema.fieldNames.zipWithIndex.foreach {
+            case (PATH, i) => writer.write(i, UTF8String.fromString(status.getPath.toString))
+            case (LENGTH, i) => writer.write(i, status.getLen)
+            case (MODIFICATION_TIME, i) =>
+              writer.write(i, DateTimeUtils.fromMillis(status.getModificationTime))
+            case (CONTENT, i) =>
+              val stream = fs.open(status.getPath)
+              try {
+                writer.write(i, ByteStreams.toByteArray(stream))
+              } finally {
+                Closeables.close(stream, true)
+              }
+            case (other, _) =>
+              throw new RuntimeException(s"Unsupported field name: ${other}")
           }
-
-          // TODO: Add column pruning
-          // currently it still read the file content even if content column is not required.
-          val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
-          val internalRow = InternalRow(
-            UTF8String.fromString(path),
-            DateTimeUtils.fromMillis(modificationTime),
-            length,
-            content
-          )
-
-          Iterator(requiredColumns(internalRow))
+          Iterator.single(writer.getRow)
         } else {
           Iterator.empty
         }
@@ -154,6 +141,7 @@ object BinaryFileFormat {
   private[binaryfile] val MODIFICATION_TIME = "modificationTime"
   private[binaryfile] val LENGTH = "length"
   private[binaryfile] val CONTENT = "content"
+  private[binaryfile] val BINARY_FILE = "binaryFile"
 
   /**
    * Schema for the binary file data source.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 3254a7f..fb83c3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.binaryfile
 
-import java.io.File
+import java.io.{File, IOException}
 import java.nio.file.{Files, StandardOpenOption}
 import java.sql.Timestamp
 
@@ -28,6 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
 import org.mockito.Mockito.{mock, when}
 
 import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.sources._
@@ -101,7 +102,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
   }
 
   def testBinaryFileDataSource(pathGlobFilter: String): Unit = {
-    val dfReader = spark.read.format("binaryFile")
+    val dfReader = spark.read.format(BINARY_FILE)
     if (pathGlobFilter != null) {
       dfReader.option("pathGlobFilter", pathGlobFilter)
     }
@@ -124,7 +125,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
 
       for (fileStatus <- fs.listStatus(dirPath)) {
         if (globFilter == null || globFilter.accept(fileStatus.getPath)) {
-          val fpath = fileStatus.getPath.toString.replace("file:/", "file:///")
+          val fpath = fileStatus.getPath.toString
           val flen = fileStatus.getLen
           val modificationTime = new Timestamp(fileStatus.getModificationTime)
 
@@ -157,11 +158,11 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
   }
 
   test("binary file data source do not support write operation") {
-    val df = spark.read.format("binaryFile").load(testDir)
+    val df = spark.read.format(BINARY_FILE).load(testDir)
     withTempDir { tmpDir =>
       val thrown = intercept[UnsupportedOperationException] {
         df.write
-          .format("binaryFile")
+          .format(BINARY_FILE)
           .save(tmpDir + "/test_save")
       }
       assert(thrown.getMessage.contains("Write is not supported for binary file data source"))
@@ -286,4 +287,56 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
       EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
     ), true)
   }
+
+  test("column pruning") {
+    def getRequiredSchema(fieldNames: String*): StructType = {
+      StructType(fieldNames.map {
+        case f if schema.fieldNames.contains(f) => schema(f)
+        case other => StructField(other, NullType)
+      })
+    }
+    def read(file: File, requiredSchema: StructType): Row = {
+      val format = new BinaryFileFormat
+      val reader = format.buildReaderWithPartitionValues(
+        sparkSession = spark,
+        dataSchema = schema,
+        partitionSchema = StructType(Nil),
+        requiredSchema = requiredSchema,
+        filters = Seq.empty,
+        options = Map.empty,
+        hadoopConf = spark.sessionState.newHadoopConf()
+      )
+      val partitionedFile = mock(classOf[PartitionedFile])
+      when(partitionedFile.filePath).thenReturn(file.getPath)
+      val encoder = RowEncoder(requiredSchema).resolveAndBind()
+      encoder.fromRow(reader(partitionedFile).next())
+    }
+    val file = new File(Utils.createTempDir(), "data")
+    val content = "123".getBytes
+    Files.write(file.toPath, content, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
+
+    read(file, getRequiredSchema(MODIFICATION_TIME, CONTENT, LENGTH, PATH)) match {
+      case Row(t, c, len, p) =>
+        assert(t === new Timestamp(file.lastModified()))
+        assert(c === content)
+        assert(len === content.length)
+        assert(p.asInstanceOf[String].endsWith(file.getAbsolutePath))
+    }
+    file.setReadable(false)
+    withClue("cannot read content") {
+      intercept[IOException] {
+        read(file, getRequiredSchema(CONTENT))
+      }
+    }
+    assert(read(file, getRequiredSchema(LENGTH)) === Row(content.length),
+      "Get length should not read content.")
+    intercept[RuntimeException] {
+      read(file, getRequiredSchema(LENGTH, "other"))
+    }
+
+    val df = spark.read.format(BINARY_FILE).load(file.getPath)
+    assert(df.count() === 1, "Count should not read content.")
+    assert(df.select("LENGTH").first().getLong(0) === content.length,
+      "column pruning should be case insensitive")
+  }
 }

