viirya commented on a change in pull request #33639:
URL: https://github.com/apache/spark/pull/33639#discussion_r699878387
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
##########
@@ -585,8 +585,8 @@ private[sql] object ParquetSchemaConverter {
Types.buildMessage().named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME)
def checkFieldName(name: String): Unit = {
- // ,;{}()\n\t= and space are special characters in Parquet schema
- if (name.matches(".*[ ,;{}()\n\t=].*")) {
+ // ,;{}\n\t= and space are special characters in Parquet schema
+ if (name.matches(".*[ ,;{}\n\t=].*")) {
Review comment:
Not sure if `()` will create some issues on the Parquet side?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
##########
@@ -127,4 +144,207 @@ object ParquetUtils {
file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
+ * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want
+ * to get the partial aggregates (Max/Min/Count) result using the statistics information
+ * from Parquet footer file, and then construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ private[sql] def createAggInternalRowFromFooter(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): InternalRow = {
+ val (primitiveType, values) =
+ getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+ val builder = Types.buildMessage()
+ primitiveType.foreach(t => builder.addField(t))
+ val parquetSchema = builder.named("root")
+
+ val schemaConverter = new ParquetToSparkSchemaConverter
+ val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+ None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+ val primitiveTypeName = primitiveType.map(_.getPrimitiveTypeName)
+ primitiveTypeName.zipWithIndex.foreach {
+ case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+ val v = values(i).asInstanceOf[Boolean]
+ converter.getConverter(i).asPrimitiveConverter().addBoolean(v)
+ case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+ val v = values(i).asInstanceOf[Integer]
+ converter.getConverter(i).asPrimitiveConverter().addInt(v)
+ case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+ val v = values(i).asInstanceOf[Long]
+ converter.getConverter(i).asPrimitiveConverter().addLong(v)
+ case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+ val v = values(i).asInstanceOf[Float]
+ converter.getConverter(i).asPrimitiveConverter().addFloat(v)
+ case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+ val v = values(i).asInstanceOf[Double]
+ converter.getConverter(i).asPrimitiveConverter().addDouble(v)
+ case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter().addBinary(v)
+ case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter().addBinary(v)
+ case (_, i) =>
+ throw new SparkException("Unexpected parquet type name: " + primitiveTypeName(i))
+ }
+ converter.currentRecord
+ }
+
+ /**
+ * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of
+ * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader
+ * to read data from Parquet and aggregate at Spark layer. Instead we want
+ * to get the aggregates (Max/Min/Count) result using the statistics information
+ * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results.
+ *
+ * @return Aggregate results in the format of ColumnarBatch
+ */
+ private[sql] def createAggColumnarBatchFromFooter(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ columnBatchSize: Int,
+ offHeap: Boolean,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): ColumnarBatch = {
+ val row = createAggInternalRowFromFooter(
+ footer,
+ dataSchema,
+ partitionSchema,
+ aggregation,
+ aggSchema,
+ datetimeRebaseMode,
+ isCaseSensitive)
+ val converter = new RowToColumnConverter(aggSchema)
+ val columnVectors = if (offHeap) {
+ OffHeapColumnVector.allocateColumns(columnBatchSize, aggSchema)
+ } else {
+ OnHeapColumnVector.allocateColumns(columnBatchSize, aggSchema)
Review comment:
If we only need one row for aggregation results, maybe replace
`columnBatchSize` with 1?
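For example, just a sketch of the suggestion (assuming the aggregate result is always a single row):

```scala
// Only one row of aggregate results is produced, so a capacity of 1 should be enough.
val columnVectors = if (offHeap) {
  OffHeapColumnVector.allocateColumns(1, aggSchema)
} else {
  OnHeapColumnVector.allocateColumns(1, aggSchema)
}
```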
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
##########
@@ -127,4 +144,255 @@ object ParquetUtils {
file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+
+ /**
+ * When the partial Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to
+ * createRowBaseReader to read data from parquet and aggregate at spark layer. Instead we want
+ * to get the partial Aggregates (Max/Min/Count) result using the statistics information
+ * from parquet footer file, and then construct an InternalRow from these Aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ private[sql] def createInternalRowFromAggResult(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ datetimeRebaseModeInRead: String,
+ isCaseSensitive: Boolean): InternalRow = {
+ val (parquetTypes, values) =
+ getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+ val mutableRow = new SpecificInternalRow(aggSchema.fields.map(x => x.dataType))
+ val footerFileMetaData = footer.getFileMetaData
+ val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
+
+ parquetTypes.zipWithIndex.foreach {
+ case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+ aggSchema.fields(i).dataType match {
+ case ByteType =>
+ mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte)
+ case ShortType =>
+ mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort)
+ case IntegerType =>
+ mutableRow.setInt(i, values(i).asInstanceOf[Integer])
+ case DateType =>
+ val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead(
+ datetimeRebaseMode, "Parquet")
+ mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer]))
+ case d: DecimalType =>
+ val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale)
+ mutableRow.setDecimal(i, decimal, d.precision)
+ case _ => throw new SparkException("Unexpected type for INT32")
+ }
+ case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+ aggSchema.fields(i).dataType match {
+ case LongType =>
+ mutableRow.setLong(i, values(i).asInstanceOf[Long])
+ case d: DecimalType =>
+ val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale)
+ mutableRow.setDecimal(i, decimal, d.precision)
+ case _ => throw new SparkException("Unexpected type for INT64")
+ }
+ case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+ mutableRow.setFloat(i, values(i).asInstanceOf[Float])
+ case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+ mutableRow.setDouble(i, values(i).asInstanceOf[Double])
+ case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+ mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean])
+ case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+ val bytes = values(i).asInstanceOf[Binary].getBytes
+ aggSchema.fields(i).dataType match {
+ case StringType =>
+ mutableRow.update(i, UTF8String.fromBytes(bytes))
+ case BinaryType =>
+ mutableRow.update(i, bytes)
+ case d: DecimalType =>
+ val decimal =
+ Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale)
+ mutableRow.setDecimal(i, decimal, d.precision)
+ case _ => throw new SparkException("Unexpected type for Binary")
+ }
+ case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+ val bytes = values(i).asInstanceOf[Binary].getBytes
+ aggSchema.fields(i).dataType match {
+ case d: DecimalType =>
+ val decimal =
+ Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale)
+ mutableRow.setDecimal(i, decimal, d.precision)
+ case _ => throw new SparkException("Unexpected type for FIXED_LEN_BYTE_ARRAY")
+ }
+ case _ =>
+ throw new SparkException("Unexpected parquet type name")
+ }
+ mutableRow
+ }
+
+ /**
+ * When the Aggregates (Max/Min/Count) are pushed down to parquet, in the case of
+ * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader
+ * to read data from parquet and aggregate at spark layer. Instead we want
+ * to get the Aggregates (Max/Min/Count) result using the statistics information
+ * from parquet footer file, and then construct a ColumnarBatch from these Aggregate results.
+ *
+ * @return Aggregate results in the format of ColumnarBatch
+ */
+ private[sql] def createColumnarBatchFromAggResult(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ offHeap: Boolean,
+ datetimeRebaseModeInRead: String,
+ isCaseSensitive: Boolean): ColumnarBatch = {
+ val (parquetTypes, values) =
+ getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+ val capacity = 4 * 1024
+ val footerFileMetaData = footer.getFileMetaData
+ val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
+ val columnVectors = if (offHeap) {
+ OffHeapColumnVector.allocateColumns(capacity, aggSchema)
+ } else {
+ OnHeapColumnVector.allocateColumns(capacity, aggSchema)
+ }
+
+ parquetTypes.zipWithIndex.foreach {
+ case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+ aggSchema.fields(i).dataType match {
+ case ByteType =>
+ columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte)
+ case ShortType =>
+ columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort)
+ case IntegerType =>
+ columnVectors(i).appendInt(values(i).asInstanceOf[Integer])
+ case DateType =>
+ val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead(
+ datetimeRebaseMode, "Parquet")
+ columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer]))
+ case _ => throw new SparkException("Unexpected type for INT32")
+ }
+ case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+ columnVectors(i).appendLong(values(i).asInstanceOf[Long])
+ case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+ columnVectors(i).appendFloat(values(i).asInstanceOf[Float])
+ case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+ columnVectors(i).appendDouble(values(i).asInstanceOf[Double])
+ case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+ val bytes = values(i).asInstanceOf[Binary].getBytes
+ columnVectors(i).putByteArray(0, bytes, 0, bytes.length)
+ case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+ val bytes = values(i).asInstanceOf[Binary].getBytes
+ columnVectors(i).putByteArray(0, bytes, 0, bytes.length)
+ case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+ columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean])
+ case _ =>
+ throw new SparkException("Unexpected parquet type name")
+ }
+ new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+ }
+
+ /**
+ * Calculate the pushed down Aggregates (Max/Min/Count) result using the statistics
+ * information from parquet footer file.
+ *
+ * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any].
+ * The first element is the PrimitiveTypeName of the Aggregate column,
+ * and the second element is the aggregated value.
+ */
+ private[sql] def getPushedDownAggResult(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ isCaseSensitive: Boolean)
+ : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = {
+ val footerFileMetaData = footer.getFileMetaData
+ val fields = footerFileMetaData.getSchema.getFields
+ val blocks = footer.getBlocks()
+ val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName]
+ val valuesBuilder = ArrayBuilder.make[Any]
+
+ aggregation.aggregateExpressions().foreach { agg =>
+ var value: Any = None
+ var rowCount = 0L
+ var isCount = false
+ var index = 0
+ blocks.forEach { block =>
+ val blockMetaData = block.getColumns()
+ agg match {
+ case max: Max =>
+ index = dataSchema.fieldNames.toList.indexOf(max.column.fieldNames.head)
+ val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true)
+ if (currentMax != None &&
+ (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) {
+ value = currentMax
+ }
+ case min: Min =>
+ index = dataSchema.fieldNames.toList.indexOf(min.column.fieldNames.head)
+ val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false)
+ if (currentMin != None &&
+ (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) {
+ value = currentMin
+ }
+ case count: Count =>
+ rowCount += block.getRowCount
+ var isPartitionCol = false;
+ if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
+ .toSet.contains(count.column().fieldNames.head)) {
+ isPartitionCol = true
+ }
+ isCount = true
+ if(!isPartitionCol) {
+ index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head)
+ // Count(*) includes the null values, but Count (colName) doesn't.
+ rowCount -= getNumNulls(blockMetaData, index)
+ }
+ case _: CountStar =>
+ rowCount += block.getRowCount
+ isCount = true
+ case _ =>
+ }
+ }
+ if (isCount) {
+ valuesBuilder += rowCount
+ typesBuilder += PrimitiveType.PrimitiveTypeName.INT64
+ } else {
+ valuesBuilder += value
+ typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName
+ }
+ }
+ (typesBuilder.result(), valuesBuilder.result())
+ }
+
+ /**
+ * get the Max or Min value for ith column in the current block
+ *
+ * @return the Max or Min value
+ */
+ private def getCurrentBlockMaxOrMin(
+ columnChunkMetaData: util.List[ColumnChunkMetaData],
+ i: Int,
+ isMax: Boolean): Any = {
+ val statistics = columnChunkMetaData.get(i).getStatistics()
+ if (!statistics.hasNonNullValue) {
+ throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" +
+ " PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again")
+ } else {
+ if (isMax) statistics.genericGetMax() else statistics.genericGetMin()
+ }
+ }
+
+ private def getNumNulls(
+ columnChunkMetaData: util.List[ColumnChunkMetaData],
+ i: Int): Long = {
+ val statistics = columnChunkMetaData.get(i).getStatistics()
+ if (!statistics.isNumNullsSet()) {
Review comment:
You only run `getNumNulls` on non-partition columns, so do we still need this check?
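If the check can never fail for data columns, something like this would be enough (just a sketch, assuming the statistics of a non-partition column always carry a null count):

```scala
private def getNumNulls(
    columnChunkMetaData: util.List[ColumnChunkMetaData],
    i: Int): Long = {
  // Assumes the caller only passes data (non-partition) columns,
  // whose statistics always have the null count set.
  columnChunkMetaData.get(i).getStatistics().getNumNulls
}
```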
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
##########
@@ -87,36 +110,73 @@ case class ParquetPartitionReaderFactory(
}
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
- val reader = if (enableVectorizedReader) {
- createVectorizedReader(file)
+ val fileReader = if (aggregation.isEmpty) {
+ val reader = if (enableVectorizedReader) {
+ createVectorizedReader(file)
+ } else {
+ createRowBaseReader(file)
+ }
+
+ new PartitionReader[InternalRow] {
+ override def next(): Boolean = reader.nextKeyValue()
+
override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+
+ override def close(): Unit = reader.close()
+ }
} else {
- createRowBaseReader(file)
- }
+ new PartitionReader[InternalRow] {
+ var hasNext = true
Review comment:
`private`?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
##########
@@ -87,36 +110,73 @@ case class ParquetPartitionReaderFactory(
}
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
- val reader = if (enableVectorizedReader) {
- createVectorizedReader(file)
+ val fileReader = if (aggregation.isEmpty) {
+ val reader = if (enableVectorizedReader) {
+ createVectorizedReader(file)
+ } else {
+ createRowBaseReader(file)
+ }
+
+ new PartitionReader[InternalRow] {
+ override def next(): Boolean = reader.nextKeyValue()
+
+ override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+
+ override def close(): Unit = reader.close()
+ }
} else {
- createRowBaseReader(file)
- }
+ new PartitionReader[InternalRow] {
+ var hasNext = true
- val fileReader = new PartitionReader[InternalRow] {
- override def next(): Boolean = reader.nextKeyValue()
+ override def next(): Boolean = hasNext
- override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+ override def get(): InternalRow = {
+ hasNext = false
+ val footer = getFooter(file)
+ ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema,
+ aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData),
+ isCaseSensitive)
+ }
- override def close(): Unit = reader.close()
+ override def close(): Unit = {}
+ }
}
new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
partitionSchema, file.partitionValues)
}
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
- val vectorizedReader = createVectorizedReader(file)
- vectorizedReader.enableReturningBatches()
+ val fileReader = if (aggregation.isEmpty) {
+ val vectorizedReader = createVectorizedReader(file)
+ vectorizedReader.enableReturningBatches()
- new PartitionReader[ColumnarBatch] {
- override def next(): Boolean = vectorizedReader.nextKeyValue()
+ new PartitionReader[ColumnarBatch] {
+ override def next(): Boolean = vectorizedReader.nextKeyValue()
- override def get(): ColumnarBatch =
- vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
+ override def get(): ColumnarBatch =
+ vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
- override def close(): Unit = vectorizedReader.close()
+ override def close(): Unit = vectorizedReader.close()
+ }
+ } else {
+ new PartitionReader[ColumnarBatch] {
+ var hasNext = true
Review comment:
`private`?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
Review comment:
nit: no need for the s"" interpolator here.
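i.e.:

```scala
case _ => throw new IllegalArgumentException("Unexpected type of AggregateFunc")
```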
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
+ }
+
+ if (column.fieldNames.length != 1 || isPartitionCol(column)) {
+ return false
+ }
Review comment:
Add a comment for this case?
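For example, something along these lines (wording is only a suggestion):

```scala
if (column.fieldNames.length != 1 || isPartitionCol(column)) {
  // Only Max/Min on a single top-level data column can be pushed down: nested fields are
  // not supported, and partition columns have no statistics in the Parquet footer.
  return false
}
```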
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
+ }
+
+ if (column.fieldNames.length != 1 || isPartitionCol(column)) {
+ return false
+ }
+ val structField = getStructFieldForCol(column)
+
+ structField.dataType match {
+ // not push down nested type
Review comment:
nested type? I think you mean complex type?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
+ }
+
+ if (column.fieldNames.length != 1 || isPartitionCol(column)) {
+ return false
+ }
+ val structField = getStructFieldForCol(column)
+
+ structField.dataType match {
+ // not push down nested type
+ // not push down Timestamp because INT96 sort order is undefined,
+ // Parquet doesn't return statistics for INT96
+ case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType =>
+ false
+ case _ =>
+ finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")"))
+ true
+ }
+ }
+
+ if (!sparkSession.sessionState.conf.parquetAggregatePushDown ||
+ // Parquet footer has max/min/count for columns
+ // e.g. SELECT COUNT(col1) FROM t
+ // but footer doesn't have max/min/count for a column if max/min/count
+ // are combined with filter or group by
+ // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+ // SELECT COUNT(col1) FROM t GROUP BY col2
+ // Todo: 1. add support if groupby column is partition col
+ // 2. add support if filter col is partition col
Review comment:
It is better to create JIRAs and put the JIRA numbers here.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -858,6 +858,13 @@ object SQLConf {
.checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
.createWithDefault(10)
+ val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+ .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+ " down to Parquet for optimization. ")
Review comment:
Are there other limitations? E.g., complex types are not supported?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -858,6 +858,13 @@ object SQLConf {
.checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
.createWithDefault(10)
+ val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+ .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+ " down to Parquet for optimization. ")
Review comment:
Are there other limitations? E.g., complex types and timestamp type are not supported?
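For example, the doc string could spell the limitations out, roughly like this (a sketch only; the exact wording, version and default are not shown in this diff):

```scala
val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
  .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed down to Parquet " +
    "for optimization. MAX/MIN/COUNT for complex types and MAX/MIN for timestamp (INT96) " +
    "columns are not pushed down.")
  .booleanConf
  .createWithDefault(false)
```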
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +87,82 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ if (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head)) {
+ true
+ } else {
+ false
+ }
+ }
+
+ if (!sparkSession.sessionState.conf.parquetAggregatePushDown ||
+ // parquet footer has max/min/count for columns
+ // e.g. SELECT COUNT(col1) FROM t
+ // but footer doesn't have max/min/count for a column if max/min/count
+ // are combined with filter or group by
+ // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+ // SELECT COUNT(col1) FROM t GROUP BY col2
+ // Todo: 1. add support if groupby column is partition col
+ // 2. add support if filter col is partition col
+ aggregation.groupByColumns.nonEmpty || filters.length > 0) {
+ return false
+ }
+
+ aggregation.groupByColumns.foreach { col =>
Review comment:
It is weird to put it here as you don't use it for now.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
Review comment:
`checkMinMax` is not accurate. Could you rename it? E.g.,
`isAllowedMinMaxAgg`?
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
##########
@@ -40,7 +42,8 @@ import org.apache.spark.util.Utils
/**
* A test suite that tests various Parquet queries.
*/
-abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession {
+abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession
+ with ExplainSuiteHelper{
Review comment:
This change can be reverted now?
##########
File path:
sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala
##########
@@ -206,7 +206,9 @@ class HiveParquetSourceSuite extends ParquetPartitioningTest {
}
}
- test("Aggregation attribute names can't contain special chars \"
,;{}()\\n\\t=\"") {
+ // After pushing down aggregate to parquet, we can have something like
MAX(C) in column name
+ // ignore this test for now
Review comment:
Hmm, still not sure if this will be an issue or not.
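One way to check would be a small round trip with a parenthesized column name, e.g. (a hypothetical sketch, not from the PR):

```scala
import org.apache.spark.sql.functions.max

withTempPath { dir =>
  // Would fail at write time if "(" / ")" were still rejected, or if Parquet itself chokes on them.
  spark.range(3).agg(max("id").as("MAX(id)")).write.parquet(dir.getCanonicalPath)
  checkAnswer(spark.read.parquet(dir.getCanonicalPath), Row(2L))
}
```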
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
##########
@@ -0,0 +1,521 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.functions.min
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite that tests Max/Min/Count push down.
+ */
+abstract class ParquetAggregatePushDownSuite
+ extends QueryTest
+ with ParquetTest
+ with SharedSparkSession
+ with ExplainSuiteHelper {
+ import testImplicits._
+
+ test("aggregate push down - nested column: Max(top level column) not push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val max = sql("SELECT Max(_1) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Count(top level column) push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val count = sql("SELECT Count(_1) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(_1)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Max(nested column) not push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val max = sql("SELECT Max(_1._2[0]) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Count(nested column) not push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val count = sql("SELECT Count(_1._2[0]) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("aggregate push down - Max(partition Col): not push dow") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val max = sql("SELECT Max(p) FROM tmp")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ checkAnswer(max, Seq(Row(2)))
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - Count(partition Col): push down") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+ val count = sql("SELECT COUNT(p) FROM tmp")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(p)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - Filter alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res
> 1")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1), MAX(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(7)))
+ }
+ }
+ }
+
+ test("aggregate push down - alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as
minPlus2 FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-1, 0)))
+ }
+ }
+ }
+
+ test("aggregate push down - aggregate over alias not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val df = spark.table("t")
+ val query = df.select($"_1".as("col1")).agg(min($"col1"))
+ query.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []" // aggregate alias not pushed down
+ checkKeywordsExistsInExplain(query, expected_plan_fragment)
+ }
+ checkAnswer(query, Seq(Row(-2)))
+ }
+ }
+ }
+
+ test("aggregate push down - query with group by not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // aggregate not pushed down if there is group by
+ val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3)))
+ }
+ }
+ }
+
+ test("aggregate push down - query with filter not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // aggregate not pushed down if there is filter
+ val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(2)))
+ }
+ }
+ }
+
+ test("aggregate push down - push down only if all the aggregates can be
pushed down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // not push down since sum can't be pushed down
+ val selectAgg = sql("SELECT min(_1), sum(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2, 41)))
+ }
+ }
+ }
+
+ test("aggregate push down - MIN/MAX/COUNT") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1),
max(_1), max(_1)," +
+ " count(*), count(_1), count(_2), count(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_3), " +
+ "MAX(_3), " +
+ "MIN(_1), " +
+ "MAX(_1), " +
+ "COUNT(*), " +
+ "COUNT(_1), " +
+ "COUNT(_2), " +
+ "COUNT(_3)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+
+ checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6)))
+ }
+ }
+ }
+
+ test("aggregate push down - different data types") {
+ implicit class StringToDate(s: String) {
+ def date: Date = Date.valueOf(s)
+ }
+
+ implicit class StringToTs(s: String) {
+ def ts: Timestamp = Timestamp.valueOf(s)
+ }
+
+ val rows =
+ Seq(
+ Row(
+ "a string",
+ true,
+ 10.toByte,
+ "Spark SQL".getBytes,
+ 12.toShort,
+ 3,
+ Long.MaxValue,
+ 0.15.toFloat,
+ 0.75D,
+ Decimal("12.345678"),
+ ("2021-01-01").date,
+ ("2015-01-01 23:50:59.123").ts),
+ Row(
+ "test string",
+ false,
+ 1.toByte,
+ "Parquet".getBytes,
+ 2.toShort,
+ null,
+ Long.MinValue,
+ 0.25.toFloat,
+ 0.85D,
+ Decimal("1.2345678"),
+ ("2015-01-01").date,
+ ("2021-01-01 23:50:59.123").ts),
+ Row(
+ null,
+ true,
+ 10000.toByte,
+ "Spark ML".getBytes,
+ 222.toShort,
+ 113,
+ 11111111L,
+ 0.25.toFloat,
+ 0.75D,
+ Decimal("12345.678"),
+ ("2004-06-19").date,
+ ("1999-08-26 10:43:59.123").ts)
+ )
+
+ val schema = StructType(List(StructField("StringCol", StringType, true),
+ StructField("BooleanCol", BooleanType, false),
+ StructField("ByteCol", ByteType, false),
+ StructField("BinaryCol", BinaryType, false),
+ StructField("ShortCol", ShortType, false),
+ StructField("IntegerCol", IntegerType, true),
+ StructField("LongCol", LongType, false),
+ StructField("FloatCol", FloatType, false),
+ StructField("DoubleCol", DoubleType, false),
+ StructField("DecimalCol", DecimalType(25, 5), true),
+ StructField("DateCol", DateType, false),
+ StructField("TimestampCol", TimestampType, false)).toArray)
+
+ val rdd = sparkContext.parallelize(rows)
+ withTempPath { file =>
+ spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath)
+ withTempView("test") {
+ spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test")
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+
+ val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol),
min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol),
min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol),
min(TimestampCol) FROM test")
+
+ // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+ // so aggregates are not pushed down
+ testMinWithTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte,
"Parquet".getBytes,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D,
1.23457,
+ ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)))
+
+ val testMinWithOutTS = sql("SELECT min(StringCol),
min(BooleanCol), min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol),
min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test")
+
+ testMinWithOutTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(StringCol), " +
+ "MIN(BooleanCol), " +
+ "MIN(ByteCol), " +
+ "MIN(BinaryCol), " +
+ "MIN(ShortCol), " +
+ "MIN(IntegerCol), " +
+ "MIN(LongCol), " +
+ "MIN(FloatCol), " +
+ "MIN(DoubleCol), " +
+ "MIN(DecimalCol), " +
+ "MIN(DateCol)]"
+ checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte,
"Parquet".getBytes,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D,
1.23457,
+ ("2004-06-19").date)))
+
+ val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol),
max(ByteCol), " +
+ "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol),
max(FloatCol), " +
+ "max(DoubleCol), max(DecimalCol), max(DateCol),
max(TimestampCol) FROM test")
+
+ // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+ // so aggregates are not pushed down
+ testMaxWithTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L,
0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)))
+
+ val testMaxWithoutTS = sql("SELECT max(StringCol),
max(BooleanCol), max(ByteCol), " +
+ "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol),
max(FloatCol), " +
+ "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test")
+
+ testMaxWithoutTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(StringCol), " +
+ "MAX(BooleanCol), " +
+ "MAX(ByteCol), " +
+ "MAX(BinaryCol), " +
+ "MAX(ShortCol), " +
+ "MAX(IntegerCol), " +
+ "MAX(LongCol), " +
+ "MAX(FloatCol), " +
+ "MAX(DoubleCol), " +
+ "MAX(DecimalCol), " +
+ "MAX(DateCol)]"
+ checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithoutTS, Seq(Row("test string", true,
16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L,
0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date)))
+
+ val testCountStar = sql("SELECT count(*) FROM test")
+
+ val testCount = sql("SELECT count(StringCol), count(BooleanCol)," +
+ " count(ByteCol), count(BinaryCol), count(ShortCol),
count(IntegerCol)," +
+ " count(LongCol), count(FloatCol), count(DoubleCol)," +
+ " count(DecimalCol), count(DateCol), count(TimestampCol) FROM
test")
+
+ testCount.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [" +
+ "COUNT(StringCol), " +
+ "COUNT(BooleanCol), " +
+ "COUNT(ByteCol), " +
+ "COUNT(BinaryCol), " +
+ "COUNT(ShortCol), " +
+ "COUNT(IntegerCol), " +
+ "COUNT(LongCol), " +
+ "COUNT(FloatCol), " +
+ "COUNT(DoubleCol), " +
+ "COUNT(DecimalCol), " +
+ "COUNT(DateCol), " +
+ "COUNT(TimestampCol)]"
+ checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
+ }
+
+ checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)))
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - column name case sensitivity") {
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(id), MIN(id)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(9, 0)))
+ }
+ }
+ }
+ }
+ }
+}
+
+class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super
+ .sparkConf
+ .set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
+}
+
+class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ // TODO: enable Parquet V2 write path after file source V2 writers are workable.
Review comment:
??