This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 128168d  [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
128168d is described below

commit 128168d8c4019a1e10a9f1be734868524f6a09f0
Author: Huaxin Gao <huaxin_...@apple.com>
AuthorDate: Sun Oct 10 22:20:09 2021 -0700

    [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet

    ### What changes were proposed in this pull request?
    Push down Min/Max/Count to Parquet with the following restrictions:
    - Nested types such as Array, Map or Struct are not pushed down.
    - Timestamp is not pushed down, because the INT96 sort order is undefined and Parquet doesn't return statistics for INT96.
    - If the aggregate column is a partition column, only Count is pushed down; Min/Max are not, because Parquet doesn't keep max/min statistics for partition columns.
    - If a file happens to have no statistics for an aggregate column, Spark throws an exception.
    - Currently, if a filter or GROUP BY is involved, Min/Max/Count are not pushed down; this restriction will be lifted when the filter or GROUP BY is on a partition column (https://issues.apache.org/jira/browse/SPARK-36646 and https://issues.apache.org/jira/browse/SPARK-36647).

    ### Why are the changes needed?
    Parquet already stores min/max/count statistics, so we want to take advantage of this information and push Min/Max/Count down to the Parquet layer for better performance.

    ### Does this PR introduce _any_ user-facing change?
    Yes, `SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` was added. If set to true, Min/Max/Count are pushed down to Parquet.

    ### How was this patch tested?
    New test suites.

Closes #33639 from huaxingao/parquet_agg.

Authored-by: Huaxin Gao <huaxin_...@apple.com>
Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 +
 .../org/apache/spark/sql/types/StructType.scala    |   2 +-
 .../datasources/parquet/ParquetUtils.scala         | 227 +++++++++
 .../execution/datasources/v2/FileScanBuilder.scala |   2 +-
 .../v2/parquet/ParquetPartitionReaderFactory.scala | 123 ++++-
 .../datasources/v2/parquet/ParquetScan.scala       |  37 +-
 .../v2/parquet/ParquetScanBuilder.scala            |  96 +++-
 .../scala/org/apache/spark/sql/FileScanSuite.scala |   2 +-
 .../parquet/ParquetAggregatePushDownSuite.scala    | 518 +++++++++++++++++++++
 9 files changed, 984 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6443dfd..98aad1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -853,6 +853,14 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") .createWithDefault(10) + val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") + .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" + " can't be pushed down") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("If true, data will be written in a way of Spark 1.4 and earlier.
For example, decimal " + "values will be written in Apache Parquet's fixed-length byte array format, which other " + @@ -3660,6 +3668,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownInFilterThreshold: Int = getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c9862cb..50b197f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def names: Array[String] = fieldNames private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap override def equals(that: Any): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b91d75c..1093f9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,11 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.util + +import scala.collection.mutable +import scala.language.existentials + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { def inferSchema( @@ -127,4 +144,214 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to + * createRowBaseReader to read data from Parquet and aggregate at Spark layer. 
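For readers new to the footer API this relies on, a minimal standalone sketch (assuming only parquet-hadoop on the classpath; the file path is hypothetical) that folds per-row-group maxima into a file-level max, the same shape getPushedDownAggResult uses below:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER
    import org.apache.parquet.hadoop.ParquetFileReader

    val footer = ParquetFileReader.readFooter(
      new Configuration(), new Path("/tmp/data.parquet"), NO_FILTER)
    var max: Any = null
    footer.getBlocks.forEach { block =>
      // Statistics of the first column chunk in this row group.
      val stats = block.getColumns.get(0).getStatistics
      if (stats.hasNonNullValue) {
        val m = stats.genericGetMax
        if (max == null || m.asInstanceOf[Comparable[Any]].compareTo(max) > 0) max = m
      }
    }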
Instead we want + * to get the partial aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + private[sql] def createAggInternalRowFromFooter( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + isCaseSensitive: Boolean): InternalRow = { + val (primitiveTypes, values) = getPushedDownAggResult( + footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive) + + val builder = Types.buildMessage + primitiveTypes.foreach(t => builder.addField(t)) + val parquetSchema = builder.named("root") + + val schemaConverter = new ParquetToSparkSchemaConverter + val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, + None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) + val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName) + primitiveTypeNames.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + val v = values(i).asInstanceOf[Boolean] + converter.getConverter(i).asPrimitiveConverter.addBoolean(v) + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + val v = values(i).asInstanceOf[Integer] + converter.getConverter(i).asPrimitiveConverter.addInt(v) + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + val v = values(i).asInstanceOf[Long] + converter.getConverter(i).asPrimitiveConverter.addLong(v) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + val v = values(i).asInstanceOf[Float] + converter.getConverter(i).asPrimitiveConverter.addFloat(v) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + val v = values(i).asInstanceOf[Double] + converter.getConverter(i).asPrimitiveConverter.addDouble(v) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter.addBinary(v) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter.addBinary(v) + case (_, i) => + throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) + } + converter.currentRecord + } + + /** + * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of + * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader + * to read data from Parquet and aggregate at Spark layer. Instead we want + * to get the aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results. 
+ * + * @return Aggregate results in the format of ColumnarBatch + */ + private[sql] def createAggColumnarBatchFromFooter( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + offHeap: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + isCaseSensitive: Boolean): ColumnarBatch = { + val row = createAggInternalRowFromFooter( + footer, + filePath, + dataSchema, + partitionSchema, + aggregation, + aggSchema, + datetimeRebaseMode, + isCaseSensitive) + val converter = new RowToColumnConverter(aggSchema) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(1, aggSchema) + } else { + OnHeapColumnVector.allocateColumns(1, aggSchema) + } + converter.convert(row, columnVectors.toArray) + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } + + /** + * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics + * information from Parquet footer file. + * + * @return A tuple of `Array[PrimitiveType]` and Array[Any]. + * The first element is the Parquet PrimitiveType of the aggregate column, + * and the second element is the aggregated value. + */ + private[sql] def getPushedDownAggResult( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + isCaseSensitive: Boolean) + : (Array[PrimitiveType], Array[Any]) = { + val footerFileMetaData = footer.getFileMetaData + val fields = footerFileMetaData.getSchema.getFields + val blocks = footer.getBlocks + val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] + val valuesBuilder = mutable.ArrayBuilder.make[Any] + + assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down") + aggregation.aggregateExpressions.foreach { agg => + var value: Any = None + var rowCount = 0L + var isCount = false + var index = 0 + var schemaName = "" + blocks.forEach { block => + val blockMetaData = block.getColumns + agg match { + case max: Max => + val colName = max.column.fieldNames.head + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "max(" + colName + ")" + val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true) + if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { + value = currentMax + } + case min: Min => + val colName = min.column.fieldNames.head + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "min(" + colName + ")" + val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false) + if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { + value = currentMin + } + case count: Count => + schemaName = "count(" + count.column.fieldNames.head + ")" + rowCount += block.getRowCount + var isPartitionCol = false + if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) + .toSet.contains(count.column.fieldNames.head)) { + isPartitionCol = true + } + isCount = true + if (!isPartitionCol) { + index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) + // Count(*) includes the null values, but Count(colName) doesn't. 
+ rowCount -= getNumNulls(filePath, blockMetaData, index) + } + case _: CountStar => + schemaName = "count(*)" + rowCount += block.getRowCount + isCount = true + case _ => + } + } + if (isCount) { + valuesBuilder += rowCount + primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName); + } else { + valuesBuilder += value + val field = fields.get(index) + primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName) + .as(field.getLogicalTypeAnnotation) + .length(field.asPrimitiveType.getTypeLength) + .named(schemaName) + } + } + (primitiveTypeBuilder.result, valuesBuilder.result) + } + + /** + * Get the Max or Min value for ith column in the current block + * + * @return the Max or Min value + */ + private def getCurrentBlockMaxOrMin( + filePath: String, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int, + isMax: Boolean): Any = { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.hasNonNullValue) { + throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " + + s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") + } else { + if (isMax) statistics.genericGetMax else statistics.genericGetMin + } + } + + private def getNumNulls( + filePath: String, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int): Long = { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.isNumNullsSet) { + throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" + + s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" + + s" again") + } + statistics.getNumNulls; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 309f045..2dc4137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -96,6 +96,6 @@ abstract class FileScanBuilder( private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet - private val partitionNameSet: Set[String] = + val partitionNameSet: Set[String] = partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 058669b..111018b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -25,14 +25,16 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import 
org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ @@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. * @param filters Filters to be pushed down in the batch scan. + * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read. */ case class ParquetPartitionReaderFactory( @@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, filters: Array[Filter], + aggregation: Option[Aggregation], parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) @@ -80,6 +84,30 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private def getFooter(file: PartitionedFile): ParquetMetadata = { + val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + if (aggregation.isEmpty) { + ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) + } else { + // For aggregate push down, we will get max/min/count from footer statistics. + // We want to read the footer for the whole file instead of reading multiple + // footers for every split of the file. Basically if the start (the beginning of) + // the offset in PartitionedFile is 0, we will read the footer. Otherwise, it means + // that we have already read footer for that file, so we will skip reading again. 
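    // Illustration: a 256 MB file split into two PartitionedFiles (hypothetical sizes):
    //   (start = 0,      length = 128 MB) -> reads the footer, emits the aggregate row
    //   (start = 128 MB, length = 128 MB) -> getFooter returns null, emits nothing
    // so each Parquet file contributes exactly one aggregate result row per query.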
+ if (file.start != 0) return null + ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + } + } + + private def getDatetimeRebaseMode( + footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = { + DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -87,18 +115,44 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val reader = if (enableVectorizedReader) { - createVectorizedReader(file) - } else { - createRowBaseReader(file) - } + val fileReader = if (aggregation.isEmpty) { + val reader = if (enableVectorizedReader) { + createVectorizedReader(file) + } else { + createRowBaseReader(file) + } + + new PartitionReader[InternalRow] { + override def next(): Boolean = reader.nextKeyValue() - val fileReader = new PartitionReader[InternalRow] { - override def next(): Boolean = reader.nextKeyValue() + override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] - override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + override def close(): Unit = reader.close() + } + } else { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + } else { + null + } + } + override def next(): Boolean = { + hasNext && row != null + } - override def close(): Unit = reader.close() + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } } new PartitionReaderWithPartitionValues(fileReader, readDataSchema, @@ -106,17 +160,45 @@ case class ParquetPartitionReaderFactory( } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val vectorizedReader = createVectorizedReader(file) - vectorizedReader.enableReturningBatches() + val fileReader = if (aggregation.isEmpty) { + val vectorizedReader = createVectorizedReader(file) + vectorizedReader.enableReturningBatches() + + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() - new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def close(): Unit = vectorizedReader.close() + } + } else { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private val row: ColumnarBatch = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + } else { + null + } + } + + override def next(): Boolean = { + hasNext && row != null + } + + override def get(): ColumnarBatch = { + hasNext = false + row + 
} - override def close(): Unit = vectorizedReader.close() + override def close(): Unit = {} + } } + fileReader } private def buildReaderBase[T]( @@ -131,11 +213,8 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) - lazy val footerFileMetaData = - ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData - val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) + lazy val footerFileMetaData = getFooter(file).getFileMetaData + val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData) // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index e277e33..42dc287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,6 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} @@ -43,10 +44,17 @@ case class ParquetScan( readPartitionSchema: StructType, pushedFilters: Array[Filter], options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema() + } + override def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) @@ -86,23 +94,46 @@ case class ParquetScan( readDataSchema, readPartitionSchema, pushedFilters, + pushedAggregate, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { + equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && p.pushedAggregate.isEmpty + } super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + 
seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) + } + + private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 9a0e4b4..c579867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -35,7 +37,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates{ lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -70,6 +73,10 @@ case class ParquetScanBuilder( } } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters @@ -79,8 +86,87 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. 
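    // For reference, with the description()/getMetaData() changes above, the EXPLAIN
    // fragments the new test suite matches look like:
    //   "PushedAggregation: [MIN(_1), MAX(_1)]"  -- aggregates pushed to Parquet
    //   "PushedAggregation: []"                  -- pushdown declined
    //   "PushedGroupBy: []"                      -- group-by pushdown not supported yet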
override def pushedFilters(): Array[Filter] = pushedParquetFilters + override def pushAggregation(aggregation: Aggregation): Boolean = { + + def getStructFieldForCol(col: NamedReference): StructField = { + schema.nameToField(col.fieldNames.head) + } + + def isPartitionCol(col: NamedReference) = { + partitionNameSet.contains(col.fieldNames.head) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (column, aggType) = agg match { + case max: Max => (max.column, "max") + case min: Min => (min.column, "min") + case _ => + throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + } + + if (isPartitionCol(column)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(column) + + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => + false + case _ => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + } + } + + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) { + // Parquet footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // Todo: 1. add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + // 2. add support if filter col is partition col + // (https://issues.apache.org/jira/browse/SPARK-36647) + return false + } + + aggregation.groupByColumns.foreach { col => + if (col.fieldNames.length != 1) return false + finalSchema = finalSchema.add(getStructFieldForCol(col)) + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return false + case min: Min => + if (!processMinOrMax(min)) return false + case count: Count => + if (count.column.fieldNames.length != 1 || count.isDistinct) return false + finalSchema = + finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return false + } + this.pushedAggregations = Some(aggregation) + true + } + override def build(): Scan = { - ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. 
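    // Illustration (hypothetical query): for SELECT min(_1), count(_2), count(*)
    // over a table with _1 INT and _2 STRING, pushAggregation above builds:
    //   finalSchema = StructType(
    //     StructField("min(_1)", IntegerType),   // keeps the source column's type
    //     StructField("count(_2)", LongType),    // counts are always LongType
    //     StructField("count(*)", LongType))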
+ if (pushedAggregations.isEmpty) finalSchema = readDataSchema() + ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index d0877db..604a892 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -354,7 +354,7 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala new file mode 100644 index 0000000..c795bd9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.sql.{Date, Timestamp} + +import org.apache.spark.SparkConf +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +/** + * A test suite that tests Max/Min/Count push down. 
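An end-to-end spark-shell sketch of what these tests exercise (the output path is hypothetical, and Parquet must be read through the DSv2 path, as the V1/V2 suites at the bottom of this file make explicit):

    spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")
    spark.conf.set("spark.sql.sources.useV1SourceList", "")  // read Parquet through DSv2
    spark.range(100).selectExpr("id", "id % 3 AS p").write.parquet("/tmp/agg_demo")
    spark.read.parquet("/tmp/agg_demo").createOrReplaceTempView("t")
    // No filter and no GROUP BY, so MIN/MAX/COUNT are answered from footer statistics;
    // EXPLAIN should report something like "PushedAggregation: [MIN(id), MAX(id), COUNT(*)]".
    spark.sql("SELECT min(id), max(id), count(*) FROM t").explain()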
+ */ +abstract class ParquetAggregatePushDownSuite + extends QueryTest + with ParquetTest + with SharedSparkSession + with ExplainSuiteHelper { + import testImplicits._ + + test("aggregate push down - nested column: Max(top level column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val max = sql("SELECT Max(_1) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("aggregate push down - nested column: Count(top level column) push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val count = sql("SELECT Count(_1) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(_1)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("aggregate push down - nested column: Max(nested column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val max = sql("SELECT Max(_1._2[0]) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("aggregate push down - nested column: Count(nested column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val count = sql("SELECT Count(_1._2[0]) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("aggregate push down - Max(partition Col): not push dow") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val max = sql("SELECT Max(p) FROM tmp") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + checkAnswer(max, Seq(Row(2))) + } + } + } + } + + test("aggregate push down - Count(partition Col): push down") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + val count = sql("SELECT 
COUNT(p) FROM tmp") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(p)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + } + } + + test("aggregate push down - Filter alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1), MAX(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(7))) + } + } + } + + test("aggregate push down - alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-1, 0))) + } + } + } + + test("aggregate push down - aggregate over alias not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val df = spark.table("t") + val query = df.select($"_1".as("col1")).agg(min($"col1")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" // aggregate alias not pushed down + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(-2))) + } + } + } + + test("aggregate push down - query with group by not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is group by + val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3))) + } + } + } + + test("aggregate push down - query with filter not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is filter + val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + 
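    // Why the filter blocks pushdown: footer statistics describe every row in a row
    // group, so they cannot answer an aggregate over a filtered subset. Hypothetically,
    // with rows (_1, _3) = (1, 10) and (-1, 5):
    //   file-level MIN(_3) from statistics = 5
    //   SELECT min(_3) FROM t WHERE _1 > 0 = 10
    // The two disagree, so Spark must read the data whenever a filter is present.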
checkAnswer(selectAgg, Seq(Row(2))) + } + } + } + + test("aggregate push down - push down only if all the aggregates can be pushed down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // not push down since sum can't be pushed down + val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2, 41))) + } + } + } + + test("aggregate push down - MIN/MAX/COUNT") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + " count(*), count(_1), count(_2), count(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_3), " + + "MAX(_3), " + + "MIN(_1), " + + "MAX(_1), " + + "COUNT(*), " + + "COUNT(_1), " + + "COUNT(_2), " + + "COUNT(_3)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + } + } + } + + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + val schema = StructType(List(StructField("StringCol", StringType, true), + StructField("BooleanCol", BooleanType, false), + StructField("ByteCol", ByteType, false), + StructField("BinaryCol", BinaryType, false), + StructField("ShortCol", ShortType, false), + StructField("IntegerCol", IntegerType, true), + StructField("LongCol", LongType, false), + StructField("FloatCol", FloatType, false), + StructField("DoubleCol", DoubleType, false), + StructField("DecimalCol", DecimalType(25, 5), true), + StructField("DateCol", DateType, false), + StructField("TimestampCol", TimestampType, false)).toArray) + + val rdd = sparkContext.parallelize(rows) + withTempPath { file => + spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + withTempView("test") { + spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + + val testMinWithTS 
= sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMinWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment) + } + + checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + + val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") + + testMinWithOutTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(StringCol), " + + "MIN(BooleanCol), " + + "MIN(ByteCol), " + + "MIN(BinaryCol), " + + "MIN(ShortCol), " + + "MIN(IntegerCol), " + + "MIN(LongCol), " + + "MIN(FloatCol), " + + "MIN(DoubleCol), " + + "MIN(DecimalCol), " + + "MIN(DateCol)]" + checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) + } + + checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date))) + + val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMaxWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + + val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") + + testMaxWithoutTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(StringCol), " + + "MAX(BooleanCol), " + + "MAX(ByteCol), " + + "MAX(BinaryCol), " + + "MAX(ShortCol), " + + "MAX(IntegerCol), " + + "MAX(LongCol), " + + "MAX(FloatCol), " + + "MAX(DoubleCol), " + + "MAX(DecimalCol), " + + "MAX(DateCol)]" + checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date))) + + val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + + " count(ByteCol), count(BinaryCol), count(ShortCol), 
count(IntegerCol)," + " count(LongCol), count(FloatCol), count(DoubleCol)," + " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") + + testCount.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [" + + "COUNT(StringCol), " + + "COUNT(BooleanCol), " + + "COUNT(ByteCol), " + + "COUNT(BinaryCol), " + + "COUNT(ShortCol), " + + "COUNT(IntegerCol), " + + "COUNT(LongCol), " + + "COUNT(FloatCol), " + + "COUNT(DoubleCol), " + + "COUNT(DecimalCol), " + + "COUNT(DateCol), " + + "COUNT(TimestampCol)]" + checkKeywordsExistsInExplain(testCount, expected_plan_fragment) + } + + checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + } + } + } + } + } + + test("aggregate push down - column name case sensitivity") { + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(id), MIN(id)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(9, 0))) + } + } + } + } + } +} + +class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") +} + +class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") +}