sunchao commented on a change in pull request #34298:
URL: https://github.com/apache/spark/pull/34298#discussion_r737026407



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,106 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we 
don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial 
aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file 
footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Map a top-level data column name to its position in `dataSchema` and return the ORC
+    // statistics stored at that position in `columnsStatistics`.
+    // NOTE(review): `indexOf` is an exact (case-sensitive) match and returns -1 for an unknown
+    // name; presumably callers only pass names already resolved against `dataSchema` - confirm.
+    def getColumnStatistics(columnName: String): ColumnStatistics =
+      columnsStatistics.get(dataSchema.fieldNames.indexOf(columnName)).getStatistics
+
+    // Read the minimum or maximum recorded in one column's ORC footer statistics and wrap it in
+    // the ORC `WritableComparable` that matches `dataType`, the Spark type of the aggregate
+    // result. `isMax` selects Max (true) or Min (false). Throws IllegalArgumentException when the
+    // statistics kind, or its combination with `dataType`, is not supported.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          // Boolean statistics only expose true/false counts: the max is true iff some value is
+          // true, and the min is false iff some value is false.
+          // NOTE(review): with zero true and zero false counts this yields max = false and
+          // min = true; confirm such columns never reach this path.
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          // Integer statistics report bounds as a long; narrow to the aggregate's integral type.
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMinMaxFromColumnStatistics should not take type $dataType " +
+                "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          // Floating-point statistics report bounds as a double; narrow to Float when required.
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMinMaxFromColumnStatistics should not take type $dataType " +
+                "for DoubleColumnStatistics")
+          }
+        case s: DateColumnStatistics =>
+          // Date bounds are read as day-of-epoch counts and narrowed to Int for DateWritable.
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMinMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+        case (min: Min, index) =>
+          val columnName = min.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema.apply(index).dataType
+          getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+        case (count: Count, _) =>
+          val columnName = count.column.fieldNames.head
+          val isPartitionColumn = partitionSchema.fields
+            .map(PartitioningUtils.getColName(_, isCaseSensitive))
+            .contains(columnName)
+          // NOTE: Count(columnName) doesn't include null values.
+          // org.apache.orc.ColumnStatistics.getNumberOfValues() returns 
number of non-null values
+          // for ColumnStatistics of individual column. In addition to this, 
ORC also stores number
+          // of all values (null and non-null) separately.
+          val nonNullRowsCount = if (isPartitionColumn) {
+            columnsStatistics.getStatistics.getNumberOfValues

Review comment:
       I see, thanks!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to