huaxingao commented on a change in pull request #33639:
URL: https://github.com/apache/spark/pull/33639#discussion_r700573972
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
##########
@@ -127,4 +144,207 @@ object ParquetUtils {
file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
+ * createRowBaseReader to read data from Parquet and aggregate at the Spark layer. Instead we
+ * get the partial aggregate (Max/Min/Count) results using the statistics information from the
+ * Parquet footer, and then construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ private[sql] def createAggInternalRowFromFooter(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): InternalRow = {
+ val (primitiveType, values) =
+ getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+ val builder = Types.buildMessage()
+ primitiveType.foreach(t => builder.addField(t))
+ val parquetSchema = builder.named("root")
+
+ val schemaConverter = new ParquetToSparkSchemaConverter
+ val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+ None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+ val primitiveTypeName = primitiveType.map(_.getPrimitiveTypeName)
+ primitiveTypeName.zipWithIndex.foreach {
+ case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+ val v = values(i).asInstanceOf[Boolean]
+ converter.getConverter(i).asPrimitiveConverter().addBoolean(v)
+ case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+ val v = values(i).asInstanceOf[Integer]
+ converter.getConverter(i).asPrimitiveConverter().addInt(v)
+ case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+ val v = values(i).asInstanceOf[Long]
+ converter.getConverter(i).asPrimitiveConverter().addLong(v)
+ case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+ val v = values(i).asInstanceOf[Float]
+ converter.getConverter(i).asPrimitiveConverter().addFloat(v)
+ case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+ val v = values(i).asInstanceOf[Double]
+ converter.getConverter(i).asPrimitiveConverter().addDouble(v)
+ case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter().addBinary(v)
+ case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter().addBinary(v)
+ case (_, i) =>
+ throw new SparkException("Unexpected parquet type name: " + primitiveTypeName(i))
+ }
+ converter.currentRecord
+ }
+
+ /**
+ * When the aggregates (Max/Min/Count) are pushed down to Parquet and
+ * PARQUET_VECTORIZED_READER_ENABLED is set to true, we don't need buildColumnarReader
+ * to read data from Parquet and aggregate at the Spark layer. Instead we get the
+ * aggregate (Max/Min/Count) results using the statistics information from the Parquet
+ * footer, and then construct a ColumnarBatch from these aggregate results.
+ *
+ * @return Aggregate results in the format of ColumnarBatch
+ */
+ private[sql] def createAggColumnarBatchFromFooter(
+ footer: ParquetMetadata,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ columnBatchSize: Int,
+ offHeap: Boolean,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): ColumnarBatch = {
+ val row = createAggInternalRowFromFooter(
+ footer,
+ dataSchema,
+ partitionSchema,
+ aggregation,
+ aggSchema,
+ datetimeRebaseMode,
+ isCaseSensitive)
+ val converter = new RowToColumnConverter(aggSchema)
+ val columnVectors = if (offHeap) {
+ OffHeapColumnVector.allocateColumns(columnBatchSize, aggSchema)
+ } else {
+ OnHeapColumnVector.allocateColumns(columnBatchSize, aggSchema)
Review comment:
Changed
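For readers following along, here is a minimal, hypothetical sketch (not the PR's `getPushedDownAggResult`) of how COUNT/MIN/MAX for an INT64 column can be derived from the footer alone; the helper name and single-column signature are made up for illustration:

```scala
import scala.collection.JavaConverters._
import org.apache.parquet.column.statistics.LongStatistics
import org.apache.parquet.hadoop.metadata.ParquetMetadata

// Aggregate one INT64 column purely from per-row-group footer statistics,
// without reading any data pages. Returns (count, min, max).
def longAggFromFooter(footer: ParquetMetadata, colIdx: Int): (Long, Long, Long) = {
  val blocks = footer.getBlocks.asScala                 // one BlockMetaData per row group
  val rowCount = blocks.map(_.getRowCount).sum          // COUNT(*) needs only the row counts
  val stats = blocks.map(_.getColumns.get(colIdx).getStatistics.asInstanceOf[LongStatistics])
  // MIN/MAX are only safe if every row group recorded non-empty statistics.
  require(stats.nonEmpty && stats.forall(s => s != null && !s.isEmpty))
  (rowCount, stats.map(_.getMin).min, stats.map(_.getMax).max)
}
```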
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
##########
@@ -87,36 +110,73 @@ case class ParquetPartitionReaderFactory(
}
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
- val reader = if (enableVectorizedReader) {
- createVectorizedReader(file)
+ val fileReader = if (aggregation.isEmpty) {
+ val reader = if (enableVectorizedReader) {
+ createVectorizedReader(file)
+ } else {
+ createRowBaseReader(file)
+ }
+
+ new PartitionReader[InternalRow] {
+ override def next(): Boolean = reader.nextKeyValue()
+
+ override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+
+ override def close(): Unit = reader.close()
+ }
} else {
- createRowBaseReader(file)
- }
+ new PartitionReader[InternalRow] {
+ var hasNext = true
Review comment:
added
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
##########
@@ -87,36 +110,73 @@ case class ParquetPartitionReaderFactory(
}
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
- val reader = if (enableVectorizedReader) {
- createVectorizedReader(file)
+ val fileReader = if (aggregation.isEmpty) {
+ val reader = if (enableVectorizedReader) {
+ createVectorizedReader(file)
+ } else {
+ createRowBaseReader(file)
+ }
+
+ new PartitionReader[InternalRow] {
+ override def next(): Boolean = reader.nextKeyValue()
+
+ override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+
+ override def close(): Unit = reader.close()
+ }
} else {
- createRowBaseReader(file)
- }
+ new PartitionReader[InternalRow] {
+ var hasNext = true
- val fileReader = new PartitionReader[InternalRow] {
- override def next(): Boolean = reader.nextKeyValue()
+ override def next(): Boolean = hasNext
- override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+ override def get(): InternalRow = {
+ hasNext = false
+ val footer = getFooter(file)
+ ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema,
+ aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData),
+ isCaseSensitive)
+ }
- override def close(): Unit = reader.close()
+ override def close(): Unit = {}
+ }
}
new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
partitionSchema, file.partitionValues)
}
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
- val vectorizedReader = createVectorizedReader(file)
- vectorizedReader.enableReturningBatches()
+ val fileReader = if (aggregation.isEmpty) {
+ val vectorizedReader = createVectorizedReader(file)
+ vectorizedReader.enableReturningBatches()
- new PartitionReader[ColumnarBatch] {
- override def next(): Boolean = vectorizedReader.nextKeyValue()
+ new PartitionReader[ColumnarBatch] {
+ override def next(): Boolean = vectorizedReader.nextKeyValue()
- override def get(): ColumnarBatch =
- vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
+ override def get(): ColumnarBatch =
+ vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
- override def close(): Unit = vectorizedReader.close()
+ override def close(): Unit = vectorizedReader.close()
+ }
+ } else {
+ new PartitionReader[ColumnarBatch] {
+ var hasNext = true
Review comment:
added
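To make the control flow easier to follow, below is a generic sketch of the "emit exactly one precomputed value" reader pattern used in both `buildReader` and `buildColumnarReader`; the class name and the `compute` parameter are hypothetical:

```scala
import org.apache.spark.sql.connector.read.PartitionReader

// Returns a single precomputed value (e.g. an InternalRow or ColumnarBatch built
// from footer statistics) and then reports that it is exhausted.
class SingleValueReader[T](compute: () => T) extends PartitionReader[T] {
  private var consumed = false

  override def next(): Boolean = !consumed

  override def get(): T = {
    consumed = true        // the following call to next() returns false
    compute()
  }

  override def close(): Unit = {}  // nothing was opened, so nothing to release
}
```

This mirrors the snippet above: `next()` keeps returning true until `get()` flips the flag, so the aggregated value is produced exactly once per partition.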
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
Review comment:
removed
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
+ }
+
+ if (column.fieldNames.length != 1 || isPartitionCol(column)) {
+ return false
+ }
Review comment:
added comment
removed `column.fieldNames.length != 1` because nested columns are already blocked
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
+ }
+
+ if (column.fieldNames.length != 1 || isPartitionCol(column)) {
+ return false
+ }
+ val structField = getStructFieldForCol(column)
+
+ structField.dataType match {
+ // not push down nested type
Review comment:
Fixed
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column(), "max")
+ case min: Min => (min.column(), "min")
+ case _ => throw new IllegalArgumentException(s"Unexpected type of AggregateFunc")
+ }
+
+ if (column.fieldNames.length != 1 || isPartitionCol(column)) {
+ return false
+ }
+ val structField = getStructFieldForCol(column)
+
+ structField.dataType match {
+ // not push down nested type
+ // not push down Timestamp because INT96 sort order is undefined,
+ // Parquet doesn't return statistics for INT96
+ case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType =>
+ false
+ case _ =>
+ finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")"))
+ true
+ }
+ }
+
+ if (!sparkSession.sessionState.conf.parquetAggregatePushDown ||
+ // Parquet footer has max/min/count for columns
+ // e.g. SELECT COUNT(col1) FROM t
+ // but footer doesn't have max/min/count for a column if max/min/count
+ // are combined with filter or group by
+ // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+ // SELECT COUNT(col1) FROM t GROUP BY col2
+ // Todo: 1. add support if groupby column is partition col
+ // 2. add support if filter col is partition col
Review comment:
Done
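As a concrete illustration of the guard above (assuming the flag is on), only the first query below is eligible for footer-based pushdown; the other two need per-filter / per-group values that the footer does not carry:

```scala
// Eligible: plain MIN/MAX/COUNT over the whole table.
spark.sql("SELECT MIN(col1), MAX(col1), COUNT(col1) FROM t")

// Not eligible: the footer has no statistics restricted to col2 = 8.
spark.sql("SELECT COUNT(col1) FROM t WHERE col2 = 8")

// Not eligible: the footer has no per-group statistics.
spark.sql("SELECT COUNT(col1) FROM t GROUP BY col2")
```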
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -858,6 +858,13 @@ object SQLConf {
.checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
.createWithDefault(10)
+ val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+ .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+ " down to Parquet for optimization. ")
Review comment:
fixed
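For reference, a quick way to try the new flag out (the table path here is just a placeholder):

```scala
// Enable footer-based aggregate pushdown for the current session.
spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")

// When the query is eligible, the scan node of the plan should list the pushed
// aggregates, e.g. "PushedAggregation: [MIN(id), MAX(id), COUNT(*)]".
spark.sql("SELECT MIN(id), MAX(id), COUNT(*) FROM parquet.`/tmp/example_table`").explain()
```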
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -80,8 +88,84 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: FieldReference): StructField = {
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ }
+
+ def isPartitionCol(col: FieldReference) = {
+ (readPartitionSchema().fields.map(PartitioningUtils
+ .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis))
+ .toSet.contains(col.fieldNames.head))
+ }
+
+ def checkMinMax(agg: AggregateFunc): Boolean = {
Review comment:
changed to `processMinOrMax`
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
##########
@@ -0,0 +1,521 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.functions.min
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite that tests Max/Min/Count push down.
+ */
+abstract class ParquetAggregatePushDownSuite
+ extends QueryTest
+ with ParquetTest
+ with SharedSparkSession
+ with ExplainSuiteHelper {
+ import testImplicits._
+
+ test("aggregate push down - nested column: Max(top level column) not push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val max = sql("SELECT Max(_1) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Count(top level column) push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val count = sql("SELECT Count(_1) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(_1)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Max(nested column) not push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val max = sql("SELECT Max(_1._2[0]) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Count(nested column) not push
down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val count = sql("SELECT Count(_1._2[0]) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("aggregate push down - Max(partition Col): not push dow") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val max = sql("SELECT Max(p) FROM tmp")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ checkAnswer(max, Seq(Row(2)))
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - Count(partition Col): push down") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+ val count = sql("SELECT COUNT(p) FROM tmp")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(p)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - Filter alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res
> 1")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1), MAX(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(7)))
+ }
+ }
+ }
+
+ test("aggregate push down - alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as
minPlus2 FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-1, 0)))
+ }
+ }
+ }
+
+ test("aggregate push down - aggregate over alias not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val df = spark.table("t")
+ val query = df.select($"_1".as("col1")).agg(min($"col1"))
+ query.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []" // aggregate alias not pushed down
+ checkKeywordsExistsInExplain(query, expected_plan_fragment)
+ }
+ checkAnswer(query, Seq(Row(-2)))
+ }
+ }
+ }
+
+ test("aggregate push down - query with group by not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // aggregate not pushed down if there is group by
+ val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3)))
+ }
+ }
+ }
+
+ test("aggregate push down - query with filter not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // aggregate not pushed down if there is filter
+ val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(2)))
+ }
+ }
+ }
+
+ test("aggregate push down - push down only if all the aggregates can be
pushed down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // not push down since sum can't be pushed down
+ val selectAgg = sql("SELECT min(_1), sum(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2, 41)))
+ }
+ }
+ }
+
+ test("aggregate push down - MIN/MAX/COUNT") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1),
max(_1), max(_1)," +
+ " count(*), count(_1), count(_2), count(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_3), " +
+ "MAX(_3), " +
+ "MIN(_1), " +
+ "MAX(_1), " +
+ "COUNT(*), " +
+ "COUNT(_1), " +
+ "COUNT(_2), " +
+ "COUNT(_3)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+
+ checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6)))
+ }
+ }
+ }
+
+ test("aggregate push down - different data types") {
+ implicit class StringToDate(s: String) {
+ def date: Date = Date.valueOf(s)
+ }
+
+ implicit class StringToTs(s: String) {
+ def ts: Timestamp = Timestamp.valueOf(s)
+ }
+
+ val rows =
+ Seq(
+ Row(
+ "a string",
+ true,
+ 10.toByte,
+ "Spark SQL".getBytes,
+ 12.toShort,
+ 3,
+ Long.MaxValue,
+ 0.15.toFloat,
+ 0.75D,
+ Decimal("12.345678"),
+ ("2021-01-01").date,
+ ("2015-01-01 23:50:59.123").ts),
+ Row(
+ "test string",
+ false,
+ 1.toByte,
+ "Parquet".getBytes,
+ 2.toShort,
+ null,
+ Long.MinValue,
+ 0.25.toFloat,
+ 0.85D,
+ Decimal("1.2345678"),
+ ("2015-01-01").date,
+ ("2021-01-01 23:50:59.123").ts),
+ Row(
+ null,
+ true,
+ 10000.toByte,
+ "Spark ML".getBytes,
+ 222.toShort,
+ 113,
+ 11111111L,
+ 0.25.toFloat,
+ 0.75D,
+ Decimal("12345.678"),
+ ("2004-06-19").date,
+ ("1999-08-26 10:43:59.123").ts)
+ )
+
+ val schema = StructType(List(StructField("StringCol", StringType, true),
+ StructField("BooleanCol", BooleanType, false),
+ StructField("ByteCol", ByteType, false),
+ StructField("BinaryCol", BinaryType, false),
+ StructField("ShortCol", ShortType, false),
+ StructField("IntegerCol", IntegerType, true),
+ StructField("LongCol", LongType, false),
+ StructField("FloatCol", FloatType, false),
+ StructField("DoubleCol", DoubleType, false),
+ StructField("DecimalCol", DecimalType(25, 5), true),
+ StructField("DateCol", DateType, false),
+ StructField("TimestampCol", TimestampType, false)).toArray)
+
+ val rdd = sparkContext.parallelize(rows)
+ withTempPath { file =>
+ spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath)
+ withTempView("test") {
+ spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test")
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+
+ val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test")
+
+ // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+ // so aggregates are not pushed down
+ testMinWithTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+ ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)))
+
+ val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test")
+
+ testMinWithOutTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(StringCol), " +
+ "MIN(BooleanCol), " +
+ "MIN(ByteCol), " +
+ "MIN(BinaryCol), " +
+ "MIN(ShortCol), " +
+ "MIN(IntegerCol), " +
+ "MIN(LongCol), " +
+ "MIN(FloatCol), " +
+ "MIN(DoubleCol), " +
+ "MIN(DecimalCol), " +
+ "MIN(DateCol)]"
+ checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+ ("2004-06-19").date)))
+
+ val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+ "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+ "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test")
+
+ // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+ // so aggregates are not pushed down
+ testMaxWithTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L,
0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)))
+
+ val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+ "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+ "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test")
+
+ testMaxWithoutTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(StringCol), " +
+ "MAX(BooleanCol), " +
+ "MAX(ByteCol), " +
+ "MAX(BinaryCol), " +
+ "MAX(ShortCol), " +
+ "MAX(IntegerCol), " +
+ "MAX(LongCol), " +
+ "MAX(FloatCol), " +
+ "MAX(DoubleCol), " +
+ "MAX(DecimalCol), " +
+ "MAX(DateCol)]"
+ checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date)))
+
+ val testCountStar = sql("SELECT count(*) FROM test")
+
+ val testCount = sql("SELECT count(StringCol), count(BooleanCol)," +
+ " count(ByteCol), count(BinaryCol), count(ShortCol),
count(IntegerCol)," +
+ " count(LongCol), count(FloatCol), count(DoubleCol)," +
+ " count(DecimalCol), count(DateCol), count(TimestampCol) FROM
test")
+
+ testCount.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [" +
+ "COUNT(StringCol), " +
+ "COUNT(BooleanCol), " +
+ "COUNT(ByteCol), " +
+ "COUNT(BinaryCol), " +
+ "COUNT(ShortCol), " +
+ "COUNT(IntegerCol), " +
+ "COUNT(LongCol), " +
+ "COUNT(FloatCol), " +
+ "COUNT(DoubleCol), " +
+ "COUNT(DecimalCol), " +
+ "COUNT(DateCol), " +
+ "COUNT(TimestampCol)]"
+ checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
+ }
+
+ checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)))
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - column name case sensitivity") {
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(id), MIN(id)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(9, 0)))
+ }
+ }
+ }
+ }
+ }
+}
+
+class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super
+ .sparkConf
+ .set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
+}
+
+class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ // TODO: enable Parquet V2 write path after file source V2 writers are workable.
Review comment:
removed
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]