This is an automated email from the ASF dual-hosted git repository.

meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 20a3ef7  [SPARK-27534][SQL] Do not load `content` column in binary data source if it is not selected
20a3ef7 is described below

commit 20a3ef7259490e0c9f6348f13db1e99da5f0df83
Author: Xiangrui Meng <m...@databricks.com>
AuthorDate: Sun Apr 28 07:57:03 2019 -0700

    [SPARK-27534][SQL] Do not load `content` column in binary data source if it is not selected

    ## What changes were proposed in this pull request?

    A follow-up task from SPARK-25348. To save I/O cost, Spark shouldn't attempt to read the file if users didn't request the `content` column. For example:

    ```
    spark.read.format("binaryFile").load(path).filter($"length" < 1000000).count()
    ```

    ## How was this patch tested?

    Unit test added.

    Please review http://spark.apache.org/contributing.html before opening a pull request.

    Closes #24473 from WeichenXu123/SPARK-27534.

    Lead-authored-by: Xiangrui Meng <m...@databricks.com>
    Co-authored-by: WeichenXu <weichen...@databricks.com>
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
---
 .../datasources/binaryfile/BinaryFileFormat.scala  | 74 +++++++++-------------
 .../binaryfile/BinaryFileFormatSuite.scala         | 63 ++++++++++++++++--
 2 files changed, 89 insertions(+), 48 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index 8617ae3..db93268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -26,12 +26,10 @@ import org.apache.hadoop.mapreduce.Job
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
-  GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
+import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
@@ -80,7 +78,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
     false
   }
 
-  override def shortName(): String = "binaryFile"
+  override def shortName(): String = BINARY_FILE
 
   override protected def buildReader(
       sparkSession: SparkSession,
@@ -90,54 +88,43 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+    require(dataSchema.sameType(schema),
+      s"""
+         |Binary file data source expects dataSchema: $schema,
+         |but got: $dataSchema.
+ """.stripMargin) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val binaryFileSourceOptions = new BinaryFileSourceOptions(options) - val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter - val filterFuncs = filters.map(filter => createFilterFunction(filter)) file: PartitionedFile => { - val path = file.filePath - val fsPath = new Path(path) - + val path = new Path(file.filePath) // TODO: Improve performance here: each file will recompile the glob pattern here. - if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) { - val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value) - val fileStatus = fs.getFileStatus(fsPath) - val length = fileStatus.getLen - val modificationTime = fileStatus.getModificationTime - - if (filterFuncs.forall(_.apply(fileStatus))) { - val stream = fs.open(fsPath) - val content = try { - ByteStreams.toByteArray(stream) - } finally { - Closeables.close(stream, true) - } - - val fullOutput = dataSchema.map { f => - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() - } - val requiredOutput = fullOutput.filter { a => - requiredSchema.fieldNames.contains(a.name) + if (pathGlobPattern.forall(new GlobFilter(_).accept(path))) { + val fs = path.getFileSystem(broadcastedHadoopConf.value.value) + val status = fs.getFileStatus(path) + if (filterFuncs.forall(_.apply(status))) { + val writer = new UnsafeRowWriter(requiredSchema.length) + writer.resetRowWriter() + requiredSchema.fieldNames.zipWithIndex.foreach { + case (PATH, i) => writer.write(i, UTF8String.fromString(status.getPath.toString)) + case (LENGTH, i) => writer.write(i, status.getLen) + case (MODIFICATION_TIME, i) => + writer.write(i, DateTimeUtils.fromMillis(status.getModificationTime)) + case (CONTENT, i) => + val stream = fs.open(status.getPath) + try { + writer.write(i, ByteStreams.toByteArray(stream)) + } finally { + Closeables.close(stream, true) + } + case (other, _) => + throw new RuntimeException(s"Unsupported field name: ${other}") } - - // TODO: Add column pruning - // currently it still read the file content even if content column is not required. - val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) - - val internalRow = InternalRow( - UTF8String.fromString(path), - DateTimeUtils.fromMillis(modificationTime), - length, - content - ) - - Iterator(requiredColumns(internalRow)) + Iterator.single(writer.getRow) } else { Iterator.empty } @@ -154,6 +141,7 @@ object BinaryFileFormat { private[binaryfile] val MODIFICATION_TIME = "modificationTime" private[binaryfile] val LENGTH = "length" private[binaryfile] val CONTENT = "content" + private[binaryfile] val BINARY_FILE = "binaryFile" /** * Schema for the binary file data source. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 3254a7f..fb83c3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.binaryfile
 
-import java.io.File
+import java.io.{File, IOException}
 import java.nio.file.{Files, StandardOpenOption}
 import java.sql.Timestamp
 
@@ -28,6 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
 import org.mockito.Mockito.{mock, when}
 
 import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.sources._
@@ -101,7 +102,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
   }
 
   def testBinaryFileDataSource(pathGlobFilter: String): Unit = {
-    val dfReader = spark.read.format("binaryFile")
+    val dfReader = spark.read.format(BINARY_FILE)
     if (pathGlobFilter != null) {
       dfReader.option("pathGlobFilter", pathGlobFilter)
     }
@@ -124,7 +125,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
 
     for (fileStatus <- fs.listStatus(dirPath)) {
       if (globFilter == null || globFilter.accept(fileStatus.getPath)) {
-        val fpath = fileStatus.getPath.toString.replace("file:/", "file:///")
+        val fpath = fileStatus.getPath.toString
         val flen = fileStatus.getLen
         val modificationTime = new Timestamp(fileStatus.getModificationTime)
 
@@ -157,11 +158,11 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
   }
 
   test("binary file data source do not support write operation") {
-    val df = spark.read.format("binaryFile").load(testDir)
+    val df = spark.read.format(BINARY_FILE).load(testDir)
     withTempDir { tmpDir =>
       val thrown = intercept[UnsupportedOperationException] {
         df.write
-          .format("binaryFile")
+          .format(BINARY_FILE)
           .save(tmpDir + "/test_save")
       }
       assert(thrown.getMessage.contains("Write is not supported for binary file data source"))
@@ -286,4 +287,56 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
       EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
     ), true)
   }
+
+  test("column pruning") {
+    def getRequiredSchema(fieldNames: String*): StructType = {
+      StructType(fieldNames.map {
+        case f if schema.fieldNames.contains(f) => schema(f)
+        case other => StructField(other, NullType)
+      })
+    }
+    def read(file: File, requiredSchema: StructType): Row = {
+      val format = new BinaryFileFormat
+      val reader = format.buildReaderWithPartitionValues(
+        sparkSession = spark,
+        dataSchema = schema,
+        partitionSchema = StructType(Nil),
+        requiredSchema = requiredSchema,
+        filters = Seq.empty,
+        options = Map.empty,
+        hadoopConf = spark.sessionState.newHadoopConf()
+      )
+      val partitionedFile = mock(classOf[PartitionedFile])
+      when(partitionedFile.filePath).thenReturn(file.getPath)
+      val encoder = RowEncoder(requiredSchema).resolveAndBind()
+      encoder.fromRow(reader(partitionedFile).next())
+    }
+    val file = new File(Utils.createTempDir(), "data")
+    val content = "123".getBytes
+    Files.write(file.toPath, content, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
+
+    read(file, getRequiredSchema(MODIFICATION_TIME, CONTENT, LENGTH, PATH)) match {
+      case Row(t, c, len, p) =>
+        assert(t === new Timestamp(file.lastModified()))
+        assert(c === content)
+        assert(len === content.length)
+        assert(p.asInstanceOf[String].endsWith(file.getAbsolutePath))
+    }
+    file.setReadable(false)
+    withClue("cannot read content") {
+      intercept[IOException] {
+        read(file, getRequiredSchema(CONTENT))
+      }
+    }
+    assert(read(file, getRequiredSchema(LENGTH)) === Row(content.length),
+      "Get length should not read content.")
+    intercept[RuntimeException] {
+      read(file, getRequiredSchema(LENGTH, "other"))
+    }
+
+    val df = spark.read.format(BINARY_FILE).load(file.getPath)
+    assert(df.count() === 1, "Count should not read content.")
+    assert(df.select("LENGTH").first().getLong(0) === content.length,
+      "column pruning should be case insensitive")
+  }
 }
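From the query side, the effect of the patch is that any query which never selects `content` now materializes only file-status metadata. A hedged usage sketch (the path is a placeholder; assumes a `SparkSession` named `spark`, as in spark-shell):

```
import spark.implicits._

// "/tmp/binary-files" is a placeholder path.
val df = spark.read.format("binaryFile").load("/tmp/binary-files")
df.filter($"length" < 1000000).count()                  // no file bytes read
df.select("path", "length", "modificationTime").show()  // no file bytes read
df.select("content").first()                            // still reads the file
```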