This is an automated email from the ASF dual-hosted git repository.

meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9793d9e  [SPARK-27473][SQL] Support filter push down for status fields in binary file data source
9793d9e is described below

commit 9793d9ec22ff7d9778554e4fa3f03ef4f93d473d
Author: WeichenXu <weichen...@databricks.com>
AuthorDate: Sun Apr 21 12:45:59 2019 -0700

    [SPARK-27473][SQL] Support filter push down for status fields in binary file data source

    ## What changes were proposed in this pull request?

    Support 4 kinds of filters:
    - LessThan
    - LessThanOrEqual
    - GreaterThan
    - GreaterThanOrEqual

    Support filters applied on 2 columns:
    - modificationTime
    - length

    Note: In order to support datasource filter push-down, I flatten the schema to be:
    ```
    val schema = StructType(
      StructField("path", StringType, false) ::
      StructField("modificationTime", TimestampType, false) ::
      StructField("length", LongType, false) ::
      StructField("content", BinaryType, true) :: Nil)
    ```

    ## How was this patch tested?

    To be added.

    Closes #24387 from WeichenXu123/binary_ds_filter.

    Lead-authored-by: WeichenXu <weichen...@databricks.com>
    Co-authored-by: Xiangrui Meng <m...@databricks.com>
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
---
 .../datasources/binaryfile/BinaryFileFormat.scala  | 134 ++++++++++-----
 .../binaryfile/BinaryFileFormatSuite.scala         | 188 ++++++++++++++++++---
 2 files changed, 256 insertions(+), 66 deletions(-)
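As a quick illustration of what this enables, here is a minimal usage sketch (not part of the commit; it assumes a running `spark` session on a build with this patch, and the directory path and size threshold are hypothetical):

```
import java.sql.Timestamp

import org.apache.spark.sql.functions.col

// Each row is (path, modificationTime, length, content) under the flattened schema.
val df = spark.read.format("binaryFile").load("/tmp/binary-data")

// Both predicates reference status columns only, so they can be pushed to the
// source (LessThanOrEqual / GreaterThanOrEqual) and evaluated against each
// file's FileStatus before the file content is ever read.
val filtered = df
  .filter(col("length") <= 1024L * 1024)
  .filter(col("modificationTime") >= new Timestamp(1546300800000L))
```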
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index ad9292a..8617ae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.binaryfile
 
+import java.sql.Timestamp
+
 import com.google.common.io.{ByteStreams, Closeables}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, GlobFilter, Path}
@@ -28,7 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
+  GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
@@ -55,10 +58,12 @@ import org.apache.spark.util.SerializableConfiguration
  */
 class BinaryFileFormat extends FileFormat with DataSourceRegister {
 
+  import BinaryFileFormat._
+
   override def inferSchema(
       sparkSession: SparkSession,
       options: Map[String, String],
-      files: Seq[FileStatus]): Option[StructType] = Some(BinaryFileFormat.schema)
+      files: Seq[FileStatus]): Option[StructType] = Some(schema)
 
   override def prepareWrite(
       sparkSession: SparkSession,
@@ -84,7 +89,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
       requiredSchema: StructType,
       filters: Seq[Filter],
       options: Map[String, String],
-      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
@@ -93,46 +98,49 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
 
     val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter
 
-    (file: PartitionedFile) => {
+    val filterFuncs = filters.map(filter => createFilterFunction(filter))
+
+    file: PartitionedFile => {
       val path = file.filePath
       val fsPath = new Path(path)
 
       // TODO: Improve performance here: each file will recompile the glob pattern here.
-      val globFilter = pathGlobPattern.map(new GlobFilter(_))
-      if (!globFilter.isDefined || globFilter.get.accept(fsPath)) {
+      if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) {
         val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value)
         val fileStatus = fs.getFileStatus(fsPath)
-        val length = fileStatus.getLen()
-        val modificationTime = fileStatus.getModificationTime()
-        val stream = fs.open(fsPath)
-
-        val content = try {
-          ByteStreams.toByteArray(stream)
-        } finally {
-          Closeables.close(stream, true)
-        }
-
-        val fullOutput = dataSchema.map { f =>
-          AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
-        }
-        val requiredOutput = fullOutput.filter { a =>
-          requiredSchema.fieldNames.contains(a.name)
-        }
-
-        // TODO: Add column pruning
-        // currently it still read the file content even if content column is not required.
-        val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
-        val internalRow = InternalRow(
-          content,
-          InternalRow(
+        val length = fileStatus.getLen
+        val modificationTime = fileStatus.getModificationTime
+
+        if (filterFuncs.forall(_.apply(fileStatus))) {
+          val stream = fs.open(fsPath)
+          val content = try {
+            ByteStreams.toByteArray(stream)
+          } finally {
+            Closeables.close(stream, true)
+          }
+
+          val fullOutput = dataSchema.map { f =>
+            AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+          }
+          val requiredOutput = fullOutput.filter { a =>
+            requiredSchema.fieldNames.contains(a.name)
+          }
+
+          // TODO: Add column pruning
+          // currently it still read the file content even if content column is not required.
+          val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
+
+          val internalRow = InternalRow(
            UTF8String.fromString(path),
            DateTimeUtils.fromMillis(modificationTime),
-           length
+           length,
+           content
          )
-        )
 
-        Iterator(requiredColumns(internalRow))
+          Iterator(requiredColumns(internalRow))
+        } else {
+          Iterator.empty
+        }
       } else {
         Iterator.empty
       }
@@ -142,26 +150,62 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
 
 object BinaryFileFormat {
 
-  private val fileStatusSchema = StructType(
-    StructField("path", StringType, false) ::
-    StructField("modificationTime", TimestampType, false) ::
-    StructField("length", LongType, false) :: Nil)
+  private[binaryfile] val PATH = "path"
+  private[binaryfile] val MODIFICATION_TIME = "modificationTime"
+  private[binaryfile] val LENGTH = "length"
+  private[binaryfile] val CONTENT = "content"
 
   /**
    * Schema for the binary file data source.
    *
    * Schema:
+   *  - path (StringType): The path of the file.
+   *  - modificationTime (TimestampType): The modification time of the file.
+   *    In some Hadoop FileSystem implementation, this might be unavailable and fallback to some
+   *    default value.
+   *  - length (LongType): The length of the file in bytes.
    *  - content (BinaryType): The content of the file.
-   *  - status (StructType): The status of the file.
-   *    - path (StringType): The path of the file.
-   *    - modificationTime (TimestampType): The modification time of the file.
-   *      In some Hadoop FileSystem implementation, this might be unavailable and fallback to some
-   *      default value.
-   *    - length (LongType): The length of the file in bytes.
   */
   val schema = StructType(
-    StructField("content", BinaryType, true) ::
-    StructField("status", fileStatusSchema, false) :: Nil)
+    StructField(PATH, StringType, false) ::
+    StructField(MODIFICATION_TIME, TimestampType, false) ::
+    StructField(LENGTH, LongType, false) ::
+    StructField(CONTENT, BinaryType, true) :: Nil)
+
+  private[binaryfile] def createFilterFunction(filter: Filter): FileStatus => Boolean = {
+    filter match {
+      case And(left, right) =>
+        s => createFilterFunction(left)(s) && createFilterFunction(right)(s)
+      case Or(left, right) =>
+        s => createFilterFunction(left)(s) || createFilterFunction(right)(s)
+      case Not(child) =>
+        s => !createFilterFunction(child)(s)
+
+      case LessThan(LENGTH, value: Long) =>
+        _.getLen < value
+      case LessThanOrEqual(LENGTH, value: Long) =>
+        _.getLen <= value
+      case GreaterThan(LENGTH, value: Long) =>
+        _.getLen > value
+      case GreaterThanOrEqual(LENGTH, value: Long) =>
+        _.getLen >= value
+      case EqualTo(LENGTH, value: Long) =>
+        _.getLen == value
+
+      case LessThan(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime < value.getTime
+      case LessThanOrEqual(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime <= value.getTime
+      case GreaterThan(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime > value.getTime
+      case GreaterThanOrEqual(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime >= value.getTime
+      case EqualTo(MODIFICATION_TIME, value: Timestamp) =>
+        _.getModificationTime == value.getTime
+
+      case _ => (_ => true)
+    }
+  }
 }
 
 class BinaryFileSourceOptions(
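For reference, the new `createFilterFunction` reduces each pushed-down filter to a `FileStatus => Boolean` predicate. A small sketch of its semantics (values are hypothetical; the method is `private[binaryfile]`, so code like this only compiles inside that package, as the test suite below does):

```
import java.sql.Timestamp

import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.sql.sources.{And, GreaterThan, LessThan}

// A synthetic status: a 2048-byte file modified at t = 5000 ms.
val status = new FileStatus(2048L, false, 1, 0L, 5000L, new Path("/tmp/a.bin"))

// And/Or/Not compose recursively; length and modificationTime leaves compare
// against FileStatus fields, and any unsupported filter degrades to `_ => true`.
val keep = BinaryFileFormat.createFilterFunction(
  And(GreaterThan("length", 1024L),
    LessThan("modificationTime", new Timestamp(6000L))))

assert(keep(status))  // 2048 > 1024 && 5000 < 6000
```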
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 090f417..3254a7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -24,14 +24,19 @@ import java.sql.Timestamp
 import scala.collection.JavaConverters._
 
 import com.google.common.io.{ByteStreams, Closeables}
-import org.apache.hadoop.fs.{FileSystem, GlobFilter, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
+import org.mockito.Mockito.{mock, when}
 
 import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+  import BinaryFileFormat._
 
   private var testDir: String = _
 
@@ -39,6 +44,8 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
 
   private var fs: FileSystem = _
 
+  private var file1Status: FileStatus = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
 
@@ -51,44 +58,64 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
     val year2015Dir = new File(testDir, "year=2015")
     year2015Dir.mkdir()
 
+    val file1 = new File(year2014Dir, "data.txt")
     Files.write(
-      new File(year2014Dir, "data.txt").toPath,
+      file1.toPath,
       Seq("2014-test").asJava,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
+    file1Status = fs.getFileStatus(new Path(file1.getPath))
+
+    val file2 = new File(year2014Dir, "data2.bin")
     Files.write(
-      new File(year2014Dir, "data2.bin").toPath,
+      file2.toPath,
       "2014-test-bin".getBytes,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
+
+    val file3 = new File(year2015Dir, "bool.csv")
     Files.write(
-      new File(year2015Dir, "bool.csv").toPath,
+      file3.toPath,
       Seq("bool", "True", "False", "true").asJava,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
+
+    val file4 = new File(year2015Dir, "data.bin")
     Files.write(
-      new File(year2015Dir, "data.txt").toPath,
+      file4.toPath,
       "2015-test".getBytes,
       StandardOpenOption.CREATE, StandardOpenOption.WRITE
     )
   }
 
+  test("BinaryFileFormat methods") {
+    val format = new BinaryFileFormat
+    assert(format.shortName() === "binaryFile")
+    assert(format.isSplitable(spark, Map.empty, new Path("any")) === false)
+    assert(format.inferSchema(spark, Map.empty, Seq.empty) === Some(BinaryFileFormat.schema))
+    assert(BinaryFileFormat.schema === StructType(Seq(
+      StructField("path", StringType, false),
+      StructField("modificationTime", TimestampType, false),
+      StructField("length", LongType, false),
+      StructField("content", BinaryType, true))))
+  }
+
   def testBinaryFileDataSource(pathGlobFilter: String): Unit = {
-    val resultDF = spark.read.format("binaryFile")
-      .option("pathGlobFilter", pathGlobFilter)
-      .load(testDir)
-      .select(
-        col("status.path"),
-        col("status.modificationTime"),
-        col("status.length"),
-        col("content"),
+    val dfReader = spark.read.format("binaryFile")
+    if (pathGlobFilter != null) {
+      dfReader.option("pathGlobFilter", pathGlobFilter)
+    }
+    val resultDF = dfReader.load(testDir).select(
+      col(PATH),
+      col(MODIFICATION_TIME),
+      col(LENGTH),
+      col(CONTENT),
       col("year") // this is a partition column
     )
 
     val expectedRowSet = new collection.mutable.HashSet[Row]()
 
-    val globFilter = new GlobFilter(pathGlobFilter)
+    val globFilter = if (pathGlobFilter == null) null else new GlobFilter(pathGlobFilter)
     for (partitionDirStatus <- fs.listStatus(fsTestDir)) {
       val dirPath = partitionDirStatus.getPath
 
@@ -96,7 +123,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
       val year = partitionName.toInt // partition column "year" value which is `Int` type
 
       for (fileStatus <- fs.listStatus(dirPath)) {
-        if (globFilter.accept(fileStatus.getPath)) {
+        if (globFilter == null || globFilter.accept(fileStatus.getPath)) {
           val fpath = fileStatus.getPath.toString.replace("file:/", "file:///")
           val flen = fileStatus.getLen
           val modificationTime = new Timestamp(fileStatus.getModificationTime)
@@ -121,14 +148,15 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
   }
 
   test("binary file data source test") {
-    testBinaryFileDataSource(pathGlobFilter = "*.*")
-    testBinaryFileDataSource(pathGlobFilter = "*.bin")
-    testBinaryFileDataSource(pathGlobFilter = "*.txt")
-    testBinaryFileDataSource(pathGlobFilter = "*.{txt,csv}")
-    testBinaryFileDataSource(pathGlobFilter = "*.json")
+    testBinaryFileDataSource(null)
+    testBinaryFileDataSource("*.*")
+    testBinaryFileDataSource("*.bin")
+    testBinaryFileDataSource("*.txt")
+    testBinaryFileDataSource("*.{txt,csv}")
+    testBinaryFileDataSource("*.json")
   }
 
test("binary file data source do not support write operation") { val df = spark.read.format("binaryFile").load(testDir) withTempDir { tmpDir => val thrown = intercept[UnsupportedOperationException] { @@ -140,4 +168,122 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest } } + def mockFileStatus(length: Long, modificationTime: Long): FileStatus = { + val status = mock(classOf[FileStatus]) + when(status.getLen).thenReturn(length) + when(status.getModificationTime).thenReturn(modificationTime) + when(status.toString).thenReturn( + s"FileStatus($LENGTH=$length, $MODIFICATION_TIME=$modificationTime)") + status + } + + def testCreateFilterFunction( + filters: Seq[Filter], + testCases: Seq[(FileStatus, Boolean)]): Unit = { + val funcs = filters.map(BinaryFileFormat.createFilterFunction) + testCases.foreach { case (status, expected) => + assert(funcs.forall(f => f(status)) === expected, + s"$filters applied to $status should be $expected.") + } + } + + test("createFilterFunction") { + // test filter applied on `length` column + val l1 = mockFileStatus(1L, 0L) + val l2 = mockFileStatus(2L, 0L) + val l3 = mockFileStatus(3L, 0L) + testCreateFilterFunction( + Seq(LessThan(LENGTH, 2L)), + Seq((l1, true), (l2, false), (l3, false))) + testCreateFilterFunction( + Seq(LessThanOrEqual(LENGTH, 2L)), + Seq((l1, true), (l2, true), (l3, false))) + testCreateFilterFunction( + Seq(GreaterThan(LENGTH, 2L)), + Seq((l1, false), (l2, false), (l3, true))) + testCreateFilterFunction( + Seq(GreaterThanOrEqual(LENGTH, 2L)), + Seq((l1, false), (l2, true), (l3, true))) + testCreateFilterFunction( + Seq(EqualTo(LENGTH, 2L)), + Seq((l1, false), (l2, true), (l3, false))) + testCreateFilterFunction( + Seq(Not(EqualTo(LENGTH, 2L))), + Seq((l1, true), (l2, false), (l3, true))) + testCreateFilterFunction( + Seq(And(GreaterThan(LENGTH, 1L), LessThan(LENGTH, 3L))), + Seq((l1, false), (l2, true), (l3, false))) + testCreateFilterFunction( + Seq(Or(LessThanOrEqual(LENGTH, 1L), GreaterThanOrEqual(LENGTH, 3L))), + Seq((l1, true), (l2, false), (l3, true))) + + // test filter applied on `modificationTime` column + val t1 = mockFileStatus(0L, 1L) + val t2 = mockFileStatus(0L, 2L) + val t3 = mockFileStatus(0L, 3L) + testCreateFilterFunction( + Seq(LessThan(MODIFICATION_TIME, new Timestamp(2L))), + Seq((t1, true), (t2, false), (t3, false))) + testCreateFilterFunction( + Seq(LessThanOrEqual(MODIFICATION_TIME, new Timestamp(2L))), + Seq((t1, true), (t2, true), (t3, false))) + testCreateFilterFunction( + Seq(GreaterThan(MODIFICATION_TIME, new Timestamp(2L))), + Seq((t1, false), (t2, false), (t3, true))) + testCreateFilterFunction( + Seq(GreaterThanOrEqual(MODIFICATION_TIME, new Timestamp(2L))), + Seq((t1, false), (t2, true), (t3, true))) + testCreateFilterFunction( + Seq(EqualTo(MODIFICATION_TIME, new Timestamp(2L))), + Seq((t1, false), (t2, true), (t3, false))) + testCreateFilterFunction( + Seq(Not(EqualTo(MODIFICATION_TIME, new Timestamp(2L)))), + Seq((t1, true), (t2, false), (t3, true))) + testCreateFilterFunction( + Seq(And(GreaterThan(MODIFICATION_TIME, new Timestamp(1L)), + LessThan(MODIFICATION_TIME, new Timestamp(3L)))), + Seq((t1, false), (t2, true), (t3, false))) + testCreateFilterFunction( + Seq(Or(LessThanOrEqual(MODIFICATION_TIME, new Timestamp(1L)), + GreaterThanOrEqual(MODIFICATION_TIME, new Timestamp(3L)))), + Seq((t1, true), (t2, false), (t3, true))) + + // test filters applied on both columns + testCreateFilterFunction( + Seq(And(GreaterThan(LENGTH, 2L), LessThan(MODIFICATION_TIME, new 
+    // test filters applied on both columns
+    testCreateFilterFunction(
+      Seq(And(GreaterThan(LENGTH, 2L), LessThan(MODIFICATION_TIME, new Timestamp(2L)))),
+      Seq((l1, false), (l2, false), (l3, true), (t1, false), (t2, false), (t3, false)))
+
+    // test nested filters
+    testCreateFilterFunction(
+      // NOT (length > 2 OR modificationTime < 2)
+      Seq(Not(Or(GreaterThan(LENGTH, 2L), LessThan(MODIFICATION_TIME, new Timestamp(2L))))),
+      Seq((l1, false), (l2, false), (l3, false), (t1, false), (t2, true), (t3, true)))
+  }
+
+  test("buildReader") {
+    def testBuildReader(fileStatus: FileStatus, filters: Seq[Filter], expected: Boolean): Unit = {
+      val format = new BinaryFileFormat
+      val reader = format.buildReaderWithPartitionValues(
+        sparkSession = spark,
+        dataSchema = schema,
+        partitionSchema = StructType(Nil),
+        requiredSchema = schema,
+        filters = filters,
+        options = Map.empty,
+        hadoopConf = spark.sessionState.newHadoopConf())
+      val partitionedFile = mock(classOf[PartitionedFile])
+      when(partitionedFile.filePath).thenReturn(fileStatus.getPath.toString)
+      assert(reader(partitionedFile).nonEmpty === expected,
+        s"Filters $filters applied to $fileStatus should be $expected.")
+    }
+    testBuildReader(file1Status, Seq.empty, true)
+    testBuildReader(file1Status, Seq(LessThan(LENGTH, file1Status.getLen)), false)
+    testBuildReader(file1Status, Seq(
+      LessThan(MODIFICATION_TIME, new Timestamp(file1Status.getModificationTime))
+    ), false)
+    testBuildReader(file1Status, Seq(
+      EqualTo(LENGTH, file1Status.getLen),
+      EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
+    ), true)
+  }
 }
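For readers who want to exercise the pushdown end to end, a hedged sketch (intended for spark-shell on a build containing this commit; the directory, glob, and threshold are assumptions, not part of the patch):

```
import org.apache.spark.sql.functions.col

val df = spark.read.format("binaryFile")
  .option("pathGlobFilter", "*.bin")   // optional glob, as exercised in the suite
  .load("/tmp/binary-data")            // hypothetical directory

// Files whose FileStatus fails the length predicate are skipped before
// their content bytes are ever read from the file system.
df.filter(col("length") < 10L * 1024 * 1024)
  .select("path", "length", "modificationTime")
  .show(truncate = false)
```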