sunchao commented on a change in pull request #34298:
URL: https://github.com/apache/spark/pull/34298#discussion_r736000999
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,106 @@ object OrcUtils extends Logging {
case _ => false
}
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we
don't need to read data
+ * from ORC and aggregate at Spark layer. Instead we want to get the partial
aggregates
+ * (Max/Min/Count) result using the statistics information from ORC file
footer, and then
+ * construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ def createAggInternalRowFromFooter(
+ reader: Reader,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ isCaseSensitive: Boolean): InternalRow = {
+ require(aggregation.groupByColumns.length == 0,
+ s"aggregate $aggregation with group-by column shouldn't be pushed down")
+ val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+ // Get column statistics with column name.
+ def getColumnStatistics(columnName: String): ColumnStatistics = {
+ val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+ columnsStatistics.get(columnIndex).getStatistics
+ }
+
+ // Get Min/Max statistics and store as ORC `WritableComparable` format.
+ def getMinMaxFromColumnStatistics(
+ statistics: ColumnStatistics,
+ dataType: DataType,
+ isMax: Boolean): WritableComparable[_] = {
+ statistics match {
+ case s: BooleanColumnStatistics =>
+ val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+ new BooleanWritable(value)
+ case s: IntegerColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
+ dataType match {
+ case ByteType => new ByteWritable(value.toByte)
+ case ShortType => new ShortWritable(value.toShort)
+ case IntegerType => new IntWritable(value.toInt)
+ case LongType => new LongWritable(value)
+ case _ => throw new IllegalArgumentException(
+ s"getMaxFromColumnStatistics should not take type $dataType " +
+ "for IntegerColumnStatistics")
+ }
+ case s: DoubleColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
+ dataType match {
+ case FloatType => new FloatWritable(value.toFloat)
+ case DoubleType => new DoubleWritable(value)
+ case _ => throw new IllegalArgumentException(
+ s"getMaxFromColumnStatistics should not take type $dataType" +
Review comment:
nit: add a space at the end
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
##########
@@ -175,24 +175,26 @@ case class ParquetPartitionReaderFactory(
} else {
new PartitionReader[ColumnarBatch] {
private var hasNext = true
- private val row: ColumnarBatch = {
+ private val batch: ColumnarBatch = {
val footer = getFooter(file)
if (footer != null && footer.getBlocks.size > 0) {
- ParquetUtils.createAggColumnarBatchFromFooter(footer,
file.filePath, dataSchema,
- partitionSchema, aggregation.get, readDataSchema,
enableOffHeapColumnVector,
+ val row = ParquetUtils.createAggInternalRowFromFooter(footer,
file.filePath,
+ dataSchema, partitionSchema, aggregation.get, readDataSchema,
getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+ AggregatePushDownUtils.convertAggregatesRowToBatch(
+ row, readDataSchema, enableOffHeapColumnVector)
Review comment:
in case we are using off-heap memory, we might want to check
`taskContext.isDefined` since otherwise the task completion listener may not be
triggered to free up the memory?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -960,6 +960,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ORC_AGGREGATE_PUSHDOWN_ENABLED =
buildConf("spark.sql.orc.aggregatePushdown")
+ .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+ " down to ORC for optimization. MAX/MIN for complex types can't be
pushed down")
Review comment:
nit: does it mean `COUNT` for complex types can be pushed down? maybe
make it more explicit.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,106 @@ object OrcUtils extends Logging {
case _ => false
}
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we
don't need to read data
+ * from ORC and aggregate at Spark layer. Instead we want to get the partial
aggregates
+ * (Max/Min/Count) result using the statistics information from ORC file
footer, and then
+ * construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ def createAggInternalRowFromFooter(
+ reader: Reader,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ isCaseSensitive: Boolean): InternalRow = {
+ require(aggregation.groupByColumns.length == 0,
+ s"aggregate $aggregation with group-by column shouldn't be pushed down")
+ val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+ // Get column statistics with column name.
+ def getColumnStatistics(columnName: String): ColumnStatistics = {
+ val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+ columnsStatistics.get(columnIndex).getStatistics
+ }
+
+ // Get Min/Max statistics and store as ORC `WritableComparable` format.
+ def getMinMaxFromColumnStatistics(
+ statistics: ColumnStatistics,
+ dataType: DataType,
+ isMax: Boolean): WritableComparable[_] = {
+ statistics match {
+ case s: BooleanColumnStatistics =>
+ val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+ new BooleanWritable(value)
+ case s: IntegerColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
Review comment:
what if the column has 0 values? will min/max still be defined?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,106 @@ object OrcUtils extends Logging {
case _ => false
}
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we
don't need to read data
+ * from ORC and aggregate at Spark layer. Instead we want to get the partial
aggregates
+ * (Max/Min/Count) result using the statistics information from ORC file
footer, and then
+ * construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ def createAggInternalRowFromFooter(
+ reader: Reader,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ isCaseSensitive: Boolean): InternalRow = {
+ require(aggregation.groupByColumns.length == 0,
+ s"aggregate $aggregation with group-by column shouldn't be pushed down")
+ val columnsStatistics = OrcFooterReader.readStatistics(reader)
Review comment:
hmm, does an ORC file always have stats?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,106 @@ object OrcUtils extends Logging {
case _ => false
}
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we
don't need to read data
+ * from ORC and aggregate at Spark layer. Instead we want to get the partial
aggregates
+ * (Max/Min/Count) result using the statistics information from ORC file
footer, and then
+ * construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ def createAggInternalRowFromFooter(
+ reader: Reader,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ isCaseSensitive: Boolean): InternalRow = {
+ require(aggregation.groupByColumns.length == 0,
+ s"aggregate $aggregation with group-by column shouldn't be pushed down")
+ val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+ // Get column statistics with column name.
+ def getColumnStatistics(columnName: String): ColumnStatistics = {
+ val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+ columnsStatistics.get(columnIndex).getStatistics
+ }
+
+ // Get Min/Max statistics and store as ORC `WritableComparable` format.
+ def getMinMaxFromColumnStatistics(
+ statistics: ColumnStatistics,
+ dataType: DataType,
+ isMax: Boolean): WritableComparable[_] = {
+ statistics match {
+ case s: BooleanColumnStatistics =>
+ val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+ new BooleanWritable(value)
+ case s: IntegerColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
+ dataType match {
+ case ByteType => new ByteWritable(value.toByte)
+ case ShortType => new ShortWritable(value.toShort)
+ case IntegerType => new IntWritable(value.toInt)
+ case LongType => new LongWritable(value)
+ case _ => throw new IllegalArgumentException(
+ s"getMaxFromColumnStatistics should not take type $dataType " +
+ "for IntegerColumnStatistics")
+ }
+ case s: DoubleColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
+ dataType match {
+ case FloatType => new FloatWritable(value.toFloat)
+ case DoubleType => new DoubleWritable(value)
+ case _ => throw new IllegalArgumentException(
+ s"getMaxFromColumnStatistics should not take type $dataType" +
+ "for DoubleColumnStatistics")
+ }
+ case s: DateColumnStatistics =>
+ new DateWritable(
+ if (isMax) s.getMaximumDayOfEpoch.toInt else
s.getMinimumDayOfEpoch.toInt)
+ case _ => throw new IllegalArgumentException(
+ s"getMaxFromColumnStatistics should not take
${statistics.getClass.getName}: " +
+ s"$statistics as the ORC column statistics")
+ }
+ }
+
+ val aggORCValues: Seq[WritableComparable[_]] =
+ aggregation.aggregateExpressions.zipWithIndex.map {
+ case (max: Max, index) =>
+ val columnName = max.column.fieldNames.head
+ val statistics = getColumnStatistics(columnName)
+ val dataType = aggSchema(index).dataType
+ getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+ case (min: Min, index) =>
+ val columnName = min.column.fieldNames.head
+ val statistics = getColumnStatistics(columnName)
+ val dataType = aggSchema.apply(index).dataType
+ getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+ case (count: Count, _) =>
+ val columnName = count.column.fieldNames.head
+ val isPartitionColumn = partitionSchema.fields
+ .map(PartitioningUtils.getColName(_, isCaseSensitive))
+ .contains(columnName)
+ // NOTE: Count(columnName) doesn't include null values.
+ // org.apache.orc.ColumnStatistics.getNumberOfValues() returns
number of non-null values
+ // for ColumnStatistics of individual column. In addition to this,
ORC also stores number
+ // of all values (null and non-null) separately.
+ val nonNullRowsCount = if (isPartitionColumn) {
+ columnsStatistics.getStatistics.getNumberOfValues
Review comment:
hm, why can we include both null and non-null values when the column is a
partition column?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]