This is an automated email from the ASF dual-hosted git repository.
meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 20a3ef7 [SPARK-27534][SQL] Do not load `content` column in binary data source if it is not selected
20a3ef7 is described below
commit 20a3ef7259490e0c9f6348f13db1e99da5f0df83
Author: Xiangrui Meng <[email protected]>
AuthorDate: Sun Apr 28 07:57:03 2019 -0700
[SPARK-27534][SQL] Do not load `content` column in binary data source if it is not selected
## What changes were proposed in this pull request?
A follow-up task from SPARK-25348. To save I/O cost, Spark shouldn't read the file content if the `content` column is not selected. For example:
```
spark.read.format("binaryFile").load(path).filter($"length" <
1000000).count()
```
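For contrast, a minimal sketch (assuming `spark.implicits._` is in scope and `path` points to a directory of binary files): a metadata-only projection no longer opens the files, while a query that selects `content` still reads each file in full:
```
// Only metadata columns are required, so no file is opened:
spark.read.format("binaryFile").load(path)
  .select($"path", $"length", $"modificationTime")
  .show()

// Selecting `content` still opens and reads each file:
spark.read.format("binaryFile").load(path)
  .select($"path", $"content")
  .show()
```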
## How was this patch tested?
Unit test added.
Please review http://spark.apache.org/contributing.html before opening a
pull request.
Closes #24473 from WeichenXu123/SPARK-27534.
Lead-authored-by: Xiangrui Meng <[email protected]>
Co-authored-by: WeichenXu <[email protected]>
Signed-off-by: Xiangrui Meng <[email protected]>
---
.../datasources/binaryfile/BinaryFileFormat.scala | 74 +++++++++-------------
.../binaryfile/BinaryFileFormatSuite.scala | 63 ++++++++++++++++--
2 files changed, 89 insertions(+), 48 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index 8617ae3..db93268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -26,12 +26,10 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
- GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
+import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SerializableConfiguration
@@ -80,7 +78,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
false
}
- override def shortName(): String = "binaryFile"
+ override def shortName(): String = BINARY_FILE
override protected def buildReader(
sparkSession: SparkSession,
@@ -90,54 +88,43 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+ require(dataSchema.sameType(schema),
+ s"""
+ |Binary file data source expects dataSchema: $schema,
+ |but got: $dataSchema.
+ """.stripMargin)
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
-
val binaryFileSourceOptions = new BinaryFileSourceOptions(options)
-
val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter
-
val filterFuncs = filters.map(filter => createFilterFunction(filter))
file: PartitionedFile => {
- val path = file.filePath
- val fsPath = new Path(path)
-
+ val path = new Path(file.filePath)
// TODO: Improve performance here: each file will recompile the glob pattern here.
- if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) {
- val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value)
- val fileStatus = fs.getFileStatus(fsPath)
- val length = fileStatus.getLen
- val modificationTime = fileStatus.getModificationTime
-
- if (filterFuncs.forall(_.apply(fileStatus))) {
- val stream = fs.open(fsPath)
- val content = try {
- ByteStreams.toByteArray(stream)
- } finally {
- Closeables.close(stream, true)
- }
-
- val fullOutput = dataSchema.map { f =>
- AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
- }
- val requiredOutput = fullOutput.filter { a =>
- requiredSchema.fieldNames.contains(a.name)
+ if (pathGlobPattern.forall(new GlobFilter(_).accept(path))) {
+ val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
+ val status = fs.getFileStatus(path)
+ if (filterFuncs.forall(_.apply(status))) {
+ val writer = new UnsafeRowWriter(requiredSchema.length)
+ writer.resetRowWriter()
+ requiredSchema.fieldNames.zipWithIndex.foreach {
+ case (PATH, i) => writer.write(i, UTF8String.fromString(status.getPath.toString))
+ case (LENGTH, i) => writer.write(i, status.getLen)
+ case (MODIFICATION_TIME, i) =>
+ writer.write(i, DateTimeUtils.fromMillis(status.getModificationTime))
+ case (CONTENT, i) =>
+ val stream = fs.open(status.getPath)
+ try {
+ writer.write(i, ByteStreams.toByteArray(stream))
+ } finally {
+ Closeables.close(stream, true)
+ }
+ case (other, _) =>
+ throw new RuntimeException(s"Unsupported field name: ${other}")
}
-
- // TODO: Add column pruning
- // currently it still read the file content even if content column is not required.
- val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
- val internalRow = InternalRow(
- UTF8String.fromString(path),
- DateTimeUtils.fromMillis(modificationTime),
- length,
- content
- )
-
- Iterator(requiredColumns(internalRow))
+ Iterator.single(writer.getRow)
} else {
Iterator.empty
}
@@ -154,6 +141,7 @@ object BinaryFileFormat {
private[binaryfile] val MODIFICATION_TIME = "modificationTime"
private[binaryfile] val LENGTH = "length"
private[binaryfile] val CONTENT = "content"
+ private[binaryfile] val BINARY_FILE = "binaryFile"
/**
* Schema for the binary file data source.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 3254a7f..fb83c3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.binaryfile
-import java.io.File
+import java.io.{File, IOException}
import java.nio.file.{Files, StandardOpenOption}
import java.sql.Timestamp
@@ -28,6 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}
import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.sources._
@@ -101,7 +102,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
}
def testBinaryFileDataSource(pathGlobFilter: String): Unit = {
- val dfReader = spark.read.format("binaryFile")
+ val dfReader = spark.read.format(BINARY_FILE)
if (pathGlobFilter != null) {
dfReader.option("pathGlobFilter", pathGlobFilter)
}
@@ -124,7 +125,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
for (fileStatus <- fs.listStatus(dirPath)) {
if (globFilter == null || globFilter.accept(fileStatus.getPath)) {
- val fpath = fileStatus.getPath.toString.replace("file:/", "file:///")
+ val fpath = fileStatus.getPath.toString
val flen = fileStatus.getLen
val modificationTime = new Timestamp(fileStatus.getModificationTime)
@@ -157,11 +158,11 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
}
test("binary file data source do not support write operation") {
- val df = spark.read.format("binaryFile").load(testDir)
+ val df = spark.read.format(BINARY_FILE).load(testDir)
withTempDir { tmpDir =>
val thrown = intercept[UnsupportedOperationException] {
df.write
- .format("binaryFile")
+ .format(BINARY_FILE)
.save(tmpDir + "/test_save")
}
assert(thrown.getMessage.contains("Write is not supported for binary file data source"))
@@ -286,4 +287,56 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
), true)
}
+
+ test("column pruning") {
+ def getRequiredSchema(fieldNames: String*): StructType = {
+ StructType(fieldNames.map {
+ case f if schema.fieldNames.contains(f) => schema(f)
+ case other => StructField(other, NullType)
+ })
+ }
+ def read(file: File, requiredSchema: StructType): Row = {
+ val format = new BinaryFileFormat
+ val reader = format.buildReaderWithPartitionValues(
+ sparkSession = spark,
+ dataSchema = schema,
+ partitionSchema = StructType(Nil),
+ requiredSchema = requiredSchema,
+ filters = Seq.empty,
+ options = Map.empty,
+ hadoopConf = spark.sessionState.newHadoopConf()
+ )
+ val partitionedFile = mock(classOf[PartitionedFile])
+ when(partitionedFile.filePath).thenReturn(file.getPath)
+ val encoder = RowEncoder(requiredSchema).resolveAndBind()
+ encoder.fromRow(reader(partitionedFile).next())
+ }
+ val file = new File(Utils.createTempDir(), "data")
+ val content = "123".getBytes
+ Files.write(file.toPath, content, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
+
+ read(file, getRequiredSchema(MODIFICATION_TIME, CONTENT, LENGTH, PATH)) match {
+ case Row(t, c, len, p) =>
+ assert(t === new Timestamp(file.lastModified()))
+ assert(c === content)
+ assert(len === content.length)
+ assert(p.asInstanceOf[String].endsWith(file.getAbsolutePath))
+ }
+ file.setReadable(false)
+ withClue("cannot read content") {
+ intercept[IOException] {
+ read(file, getRequiredSchema(CONTENT))
+ }
+ }
+ assert(read(file, getRequiredSchema(LENGTH)) === Row(content.length),
+ "Get length should not read content.")
+ intercept[RuntimeException] {
+ read(file, getRequiredSchema(LENGTH, "other"))
+ }
+
+ val df = spark.read.format(BINARY_FILE).load(file.getPath)
+ assert(df.count() === 1, "Count should not read content.")
+ assert(df.select("LENGTH").first().getLong(0) === content.length,
+ "column pruning should be case insensitive")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]