huaxingao commented on a change in pull request #33639:
URL: https://github.com/apache/spark/pull/33639#discussion_r725682334



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
##########
@@ -80,43 +84,121 @@ case class ParquetPartitionReaderFactory(
   private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
   private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead
 
+  private def getFooter(file: PartitionedFile): ParquetMetadata = {
+    val conf = broadcastedConf.value.value
+    val filePath = new Path(new URI(file.filePath))
+
+    if (aggregation.isEmpty) {
+      ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS)
+    } else {
+      // For aggregate push down, we get max/min/count from the footer statistics.
+      // We want to read the footer once for the whole file instead of reading multiple
+      // footers, one for every split of the file. If the start offset of the
+      // PartitionedFile is 0, we read the footer. Otherwise the footer of this file has
+      // already been read (by the split that starts at 0), so we skip reading it again.
+      if (file.start != 0) return null
+      ParquetFooterReader.readFooter(conf, filePath, NO_FILTER)
+    }
+  }
+
+  private def getDatetimeRebaseMode(
+      footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = {
+    DataSourceUtils.datetimeRebaseMode(
+      footerFileMetaData.getKeyValueMetaData.get,
+      datetimeRebaseModeInRead)
+  }
+
   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
       resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
       resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
   }
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
-    val reader = if (enableVectorizedReader) {
-      createVectorizedReader(file)
-    } else {
-      createRowBaseReader(file)
-    }
+    val fileReader = if (aggregation.isEmpty) {
+      val reader = if (enableVectorizedReader) {
+        createVectorizedReader(file)
+      } else {
+        createRowBaseReader(file)
+      }
+
+      new PartitionReader[InternalRow] {
+        override def next(): Boolean = reader.nextKeyValue()
 
-    val fileReader = new PartitionReader[InternalRow] {
-      override def next(): Boolean = reader.nextKeyValue()
+        override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
 
-      override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+        override def close(): Unit = reader.close()
+      }
+    } else {
+      new PartitionReader[InternalRow] {
+        private var hasNext = true
+        private lazy val row: InternalRow = {
+          val footer = getFooter(file)
+          if (footer != null && footer.getBlocks.size > 0) {
+            ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
+              partitionSchema, aggregation.get, readDataSchema,
+              getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+          } else {
+            null
+          }

Review comment:
       Yes, the other tasks still run. Ideally the other tasks wouldn't need to run at all, but I am not sure how to do that.
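
For context, a rough sketch (assumed shape, not the exact PR code) of why those tasks stay cheap: every split still gets a reader, but only the split with `file.start == 0` sees a non-null footer, so only that one emits the single aggregate row and the rest emit nothing.

```scala
// Hypothetical sketch, mirroring the diff above: splits with file.start != 0 get a
// null footer from getFooter, so their readers never produce a row.
new PartitionReader[InternalRow] {
  private var hasNext = true
  private lazy val row: InternalRow = {
    val footer = getFooter(file) // null unless file.start == 0
    if (footer != null && footer.getBlocks.size > 0) {
      ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
        partitionSchema, aggregation.get, readDataSchema,
        getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
    } else {
      null
    }
  }

  override def next(): Boolean = {
    // Emit the aggregate row at most once; splits that skipped the footer emit nothing.
    if (hasNext && row != null) { hasNext = false; true } else false
  }

  override def get(): InternalRow = row

  override def close(): Unit = {}
}
```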

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
##########
@@ -43,10 +44,14 @@ case class ParquetScan(
     readPartitionSchema: StructType,
     pushedFilters: Array[Filter],
     options: CaseInsensitiveStringMap,
+    pushedAggregate: Option[Aggregation] = None,
     partitionFilters: Seq[Expression] = Seq.empty,
     dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
+  override def readSchema(): StructType =

Review comment:
       Done

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -79,8 +86,89 @@ case class ParquetScanBuilder(
   // All filters that can be converted to Parquet are pushed down.
   override def pushedFilters(): Array[Filter] = pushedParquetFilters
 
+  override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+    def getStructFieldForCol(col: NamedReference): StructField = {
+      schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+    }

Review comment:
       Right, changed. 

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -79,8 +86,89 @@ case class ParquetScanBuilder(
   // All filters that can be converted to Parquet are pushed down.
   override def pushedFilters(): Array[Filter] = pushedParquetFilters
 
+  override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+    def getStructFieldForCol(col: NamedReference): StructField = {
+      schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+    }
+
+    def isPartitionCol(col: NamedReference) = {
+      (readPartitionSchema.fields.map(PartitioningUtils

Review comment:
       Changed to use the `partitionSchema` defined in `FileScanBuilder`.
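
For reference, a minimal sketch (hypothetical, not the exact PR code) of what that check can look like once it resolves against the `partitionSchema` defined in `FileScanBuilder`:

```scala
// Hypothetical sketch: a pushed-down aggregate column counts as a partition column
// if its top-level name matches a field of partitionSchema.
private def isPartitionCol(col: NamedReference): Boolean = {
  val partitionColNames =
    partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
  partitionColNames.contains(col.fieldNames.head)
}
```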

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
##########
@@ -127,4 +144,213 @@ object ParquetUtils {
     file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
       file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need
+   * createRowBaseReader to read data from Parquet and aggregate at the Spark layer. Instead we
+   * want to get the partial aggregates (Max/Min/Count) result using the statistics information
+   * from the Parquet footer, and then construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  private[sql] def createAggInternalRowFromFooter(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+      isCaseSensitive: Boolean): InternalRow = {
+    val (primitiveTypes, values) = getPushedDownAggResult(
+      footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+    val builder = Types.buildMessage
+    primitiveTypes.foreach(t => builder.addField(t))
+    val parquetSchema = builder.named("root")
+
+    val schemaConverter = new ParquetToSparkSchemaConverter
+    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+      None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+    val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
+    primitiveTypeNames.zipWithIndex.foreach {
+      case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+        val v = values(i).asInstanceOf[Boolean]
+        converter.getConverter(i).asPrimitiveConverter.addBoolean(v)
+      case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+        val v = values(i).asInstanceOf[Integer]
+        converter.getConverter(i).asPrimitiveConverter.addInt(v)
+      case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+        val v = values(i).asInstanceOf[Long]
+        converter.getConverter(i).asPrimitiveConverter.addLong(v)
+      case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+        val v = values(i).asInstanceOf[Float]
+        converter.getConverter(i).asPrimitiveConverter.addFloat(v)
+      case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+        val v = values(i).asInstanceOf[Double]
+        converter.getConverter(i).asPrimitiveConverter.addDouble(v)
+      case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+        val v = values(i).asInstanceOf[Binary]
+        converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+      case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+        val v = values(i).asInstanceOf[Binary]
+        converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+      case (_, i) =>
+        throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
+    }
+    converter.currentRecord
+  }
+
+  /**
+   * When the aggregates (Max/Min/Count) are pushed down to Parquet and
+   * PARQUET_VECTORIZED_READER_ENABLED is set to true, we don't need buildColumnarReader
+   * to read data from Parquet and aggregate at the Spark layer. Instead we want
+   * to get the aggregates (Max/Min/Count) result using the statistics information
+   * from the Parquet footer, and then construct a ColumnarBatch from these aggregate results.
+   *
+   * @return Aggregate results in the format of ColumnarBatch
+   */
+  private[sql] def createAggColumnarBatchFromFooter(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      offHeap: Boolean,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+      isCaseSensitive: Boolean): ColumnarBatch = {
+    val row = createAggInternalRowFromFooter(
+      footer,
+      filePath,
+      dataSchema,
+      partitionSchema,
+      aggregation,
+      aggSchema,
+      datetimeRebaseMode,
+      isCaseSensitive)
+    val converter = new RowToColumnConverter(aggSchema)
+    val columnVectors = if (offHeap) {
+      OffHeapColumnVector.allocateColumns(1, aggSchema)
+    } else {
+      OnHeapColumnVector.allocateColumns(1, aggSchema)
+    }
+    converter.convert(row, columnVectors.toArray)
+    new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+  }
+
+  /**
+   * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics
+   * information from the Parquet footer.
+   *
+   * @return A tuple of `Array[PrimitiveType]` and `Array[Any]`.
+   *         The first element is the Parquet PrimitiveType of the aggregate column,
+   *         and the second element is the aggregated value.
+   */
+  private[sql] def getPushedDownAggResult(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      isCaseSensitive: Boolean): (Array[PrimitiveType], Array[Any]) = {
+    val footerFileMetaData = footer.getFileMetaData
+    val fields = footerFileMetaData.getSchema.getFields
+    val blocks = footer.getBlocks
+    val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
+    val valuesBuilder = mutable.ArrayBuilder.make[Any]
+
+    aggregation.aggregateExpressions.foreach { agg =>
+      var value: Any = None
+      var rowCount = 0L
+      var isCount = false
+      var index = 0
+      var schemaName = ""
+      blocks.forEach { block =>
+        val blockMetaData = block.getColumns
+        agg match {
+          case max: Max =>
+            val colName = max.column.fieldNames.head
+            index = dataSchema.fieldNames.toList.indexOf(colName)
+            schemaName = "max(" + colName + ")"
+            val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true)
+            if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) {
+              value = currentMax
+            }
+          case min: Min =>
+            val colName = min.column.fieldNames.head
+            index = dataSchema.fieldNames.toList.indexOf(colName)
+            schemaName = "min(" + colName + ")"
+            val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false)
+            if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) {
+              value = currentMin
+            }
+          case count: Count =>
+            schemaName = "count(" + count.column.fieldNames.head + ")"
+            rowCount += block.getRowCount
+            var isPartitionCol = false
+            if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))

Review comment:
       For count(column), I get `getRowCount` from `BlockMetaData` and `getNumNulls` from the footer, then subtract the number of nulls from the row count. If the column is a partition column, I don't think it can be null, so I simply use `getRowCount`.
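
A small sketch of that calculation (hypothetical names, not the exact PR code), assuming `blocks` is `footer.getBlocks` and `index` is the column's position in the Parquet schema:

```scala
import scala.collection.JavaConverters._
import org.apache.parquet.hadoop.metadata.BlockMetaData

// Hypothetical sketch: count(col) derived purely from footer metadata.
def countFromFooter(
    blocks: java.util.List[BlockMetaData],
    index: Int,
    isPartitionCol: Boolean): Long = {
  blocks.asScala.map { block =>
    if (isPartitionCol) {
      // Partition columns are never null, so the block row count is the count.
      block.getRowCount
    } else {
      // Otherwise subtract the nulls recorded in the column chunk statistics.
      block.getRowCount - block.getColumns.get(index).getStatistics.getNumNulls
    }
  }.sum
}
```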



