c21 commented on a change in pull request #34298:
URL: https://github.com/apache/spark/pull/34298#discussion_r734787209
########## File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java ########## @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; + +import java.util.ArrayList; +import java.util.List; + +/** + * Columns statistics interface wrapping ORC {@link ColumnStatistics}s. + * + * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer, Review comment: Yes, added comment. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. 
+ def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for DoubleColumnStatistics") + } + case s: DecimalColumnStatistics => + new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) + case s: StringColumnStatistics => + new Text(if (isMax) s.getMaximum else s.getMinimum) + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + value + case (min: Min, index) => + val columnName = min.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + value + case (count: Count, _) => + val columnName = count.column.fieldNames.head + val isPartitionColumn = partitionSchema.fields + .map(PartitioningUtils.getColName(_, isCaseSensitive)) + .contains(columnName) + // NOTE: Count(columnName) doesn't include null values. + // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values + // for ColumnStatistics of individual column. In addition to this, ORC also returns + // number of non-null and null values for its top-level + // ColumnStatistics.getNumberOfValues(). + val nonNullRowsCount = if (isPartitionColumn) { + val topLevelStatistics = columnsStatistics.getStatistics + if (topLevelStatistics.hasNull) { + throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + + s"values: $topLevelStatistics. Aggregate expression: $count") + } + topLevelStatistics.getNumberOfValues + } else { + getColumnStatistics(columnName).getNumberOfValues + } + new LongWritable(nonNullRowsCount) + case (_: CountStar, _) => + // Count(*) includes both null and non-null values. 
+ val topLevelStatistics = columnsStatistics.getStatistics + if (topLevelStatistics.hasNull) { + throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + Review comment: @sunchao - yes same as above, this error message is quite confusing. Removed. ########## File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java ########## @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; + +import java.util.ArrayList; +import java.util.List; + +/** + * Columns statistics interface wrapping ORC {@link ColumnStatistics}s. + * + * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer, + * this class is used to covert ORC {@link ColumnStatistics}s from array to nested tree structure, + * according to data types. This is used for aggregate push down in ORC. + */ +public class OrcColumnsStatistics { Review comment: @sunchao - updated with new name. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala ########## @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +/** + * Utility class for aggregate push down to Parquet and ORC. 
+ */ +object AggregatePushDownUtils { + + /** + * Get the data schema for aggregate to be pushed down. + */ + def getSchemaForPushedAggregation( + aggregation: Aggregation, + schema: StructType, + partitionNameSet: Set[String], + dataFilters: Seq[Expression], + isAllowedTypeForMinMaxAggregate: DataType => Boolean, + sparkSession: SparkSession): Option[StructType] = { Review comment: @sunchao - sorry, removed. ########## File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.spark.sql.types.*; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.Queue; + +/** + * `OrcFooterReader` is a util class which encapsulates the helper + * methods of reading ORC file footer. + */ +public class OrcFooterReader { + + /** + * Read the columns statistics from ORC file footer. + * + * @param orcReader the reader to read ORC file footer. + * @return Statistics for all columns in the file. + */ + public static OrcColumnsStatistics readStatistics(Reader orcReader) { + TypeDescription orcSchema = orcReader.getSchema(); + ColumnStatistics[] orcStatistics = orcReader.getStatistics(); + StructType dataType = OrcUtils.toCatalystSchema(orcSchema); Review comment: @sunchao - updated. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. 
+ def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. + def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for DoubleColumnStatistics") + } + case s: DecimalColumnStatistics => + new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) + case s: StringColumnStatistics => Review comment: @sunchao - you are right. We should not push down string for ORC, and I just checked the [string type MIN/MAX are truncated by default if exceeding 1024 characters](https://github.com/apache/orc/blob/main/java/core/src/java/org/apache/orc/impl/ColumnStatisticsImpl.java#L719-L723). Removed, thanks. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala ########## @@ -87,84 +86,45 @@ case class ParquetScanBuilder( override def pushedFilters(): Array[Filter] = pushedParquetFilters override def pushAggregation(aggregation: Aggregation): Boolean = { - - def getStructFieldForCol(col: NamedReference): StructField = { - schema.nameToField(col.fieldNames.head) - } - - def isPartitionCol(col: NamedReference) = { - partitionNameSet.contains(col.fieldNames.head) + if (!sparkSession.sessionState.conf.parquetAggregatePushDown) { + return false } - def processMinOrMax(agg: AggregateFunc): Boolean = { - val (column, aggType) = agg match { - case max: Max => (max.column, "max") - case min: Min => (min.column, "min") - case _ => - throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") - } - - if (isPartitionCol(column)) { - // don't push down partition column, footer doesn't have max/min for partition column - return false - } - val structField = getStructFieldForCol(column) - - structField.dataType match { - // not push down complex type - // not push down Timestamp because INT96 sort order is undefined, - // Parquet doesn't return statistics for INT96 - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => + def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = { + dataType match { + // Not push down complex type. + // Not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96. + // Not push down Binary type as Parquet can truncate the statistics. 
+ case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType => Review comment: @sunchao - updated. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala ########## @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +/** + * Utility class for aggregate push down to Parquet and ORC. + */ +object AggregatePushDownUtils { + + /** + * Get the data schema for aggregate to be pushed down. + */ + def getSchemaForPushedAggregation( + aggregation: Aggregation, + schema: StructType, + partitionNameSet: Set[String], Review comment: @sunchao - updated. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. 
+ def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + Review comment: @sunchao - sorry, added. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. 
+ def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for DoubleColumnStatistics") + } + case s: DecimalColumnStatistics => + new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) + case s: StringColumnStatistics => + new Text(if (isMax) s.getMaximum else s.getMinimum) + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + value Review comment: @sunchao - removed. ########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala ########## @@ -123,36 +126,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Count(partition Col): push down") { Review comment: @huaxingao - added back. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, Review comment: @sunchao - sorry, removed. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. 
Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. + def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for DoubleColumnStatistics") + } + case s: DecimalColumnStatistics => + new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) + case s: StringColumnStatistics => + new Text(if (isMax) s.getMaximum else s.getMinimum) + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + value + case (min: Min, index) => + val columnName = min.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + value + case (count: Count, _) => + val columnName = count.column.fieldNames.head + val isPartitionColumn = partitionSchema.fields + .map(PartitioningUtils.getColName(_, isCaseSensitive)) + .contains(columnName) + // NOTE: Count(columnName) doesn't include null values. 
+ // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values Review comment: @sunchao - rephrased a bit. ORC stores the number of non-null values in each column's statistics. In addition, ORC also stores the total number of values (null and non-null) at the top level, separately from any individual column's statistics. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala ########## @@ -37,35 +38,65 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { - override def isSplitable(path: Path): Boolean = true + override def isSplitable(path: Path): Boolean = { + // If aggregate is pushed down, only the file footer will be read once, + // so file should be not split across multiple tasks. + pushedAggregate.isEmpty Review comment: @sunchao - agreed, that's why I diverged from the Parquet code path for this. We should make sure the file is processed by only one task; splitting the file across multiple tasks is pointless here. I can make a change on the Parquet side later after this PR is merged. > Also maybe we should change how we measure file weight when combining tasks for aggregate pushdown, since we can combine multiple large files into a single task as computing stats is much cheaper. Yes, I thought about this as well. It's not trivial though, as we need to come up with another heuristic to decide how we combine files when aggregates are pushed down. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala ########## @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format.
+ def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for DoubleColumnStatistics") + } + case s: DecimalColumnStatistics => + new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) + case s: StringColumnStatistics => + new Text(if (isMax) s.getMaximum else s.getMinimum) + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + value + case (min: Min, index) => + val columnName = min.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + value + case (count: Count, _) => + val columnName = count.column.fieldNames.head + val isPartitionColumn = partitionSchema.fields + .map(PartitioningUtils.getColName(_, isCaseSensitive)) + .contains(columnName) + // NOTE: Count(columnName) doesn't include null values. + // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values + // for ColumnStatistics of individual column. In addition to this, ORC also returns + // number of non-null and null values for its top-level + // ColumnStatistics.getNumberOfValues(). + val nonNullRowsCount = if (isPartitionColumn) { + val topLevelStatistics = columnsStatistics.getStatistics + if (topLevelStatistics.hasNull) { + throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + Review comment: @sunchao - here it means the ORC file is invalid. Actually we don't need this check, as ORC guarantees this and this error message is quite confusing. Removed. 
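As context for the approach discussed above (answering MIN/MAX/COUNT from the file footer instead of scanning rows), here is a minimal standalone sketch against the ORC reader API. It is illustrative only, not the PR's code: the file path and the column name `id` are hypothetical, and it indexes the flattened footer statistics array directly rather than going through the `OrcFooterReader`/`OrcColumnsStatistics` wrappers this PR introduces.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.{IntegerColumnStatistics, OrcFile}

object OrcFooterStatsExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical path: any ORC file with a top-level LONG column named "id".
    val reader = OrcFile.createReader(
      new Path("/tmp/example.orc"), OrcFile.readerOptions(new Configuration()))

    // Footer statistics come back as a flattened pre-order array; index 0 is the
    // root struct, so the i-th top-level column's statistics live at index i + 1.
    val stats = reader.getStatistics
    val idIndex = reader.getSchema.getFieldNames.indexOf("id") + 1

    // COUNT(*): total number of rows, nulls included.
    println(s"count(*)  = ${reader.getNumberOfRows}")
    // COUNT(id): getNumberOfValues() counts only non-null values.
    println(s"count(id) = ${stats(idIndex).getNumberOfValues}")
    // MIN/MAX(id): read from footer statistics; no row data is scanned.
    stats(idIndex) match {
      case s: IntegerColumnStatistics =>
        println(s"min(id) = ${s.getMinimum}, max(id) = ${s.getMaximum}")
      case other =>
        println(s"no usable min/max statistics: ${other.getClass.getName}")
    }
  }
}
```

This sketch also shows the COUNT distinction discussed above: an individual column's getNumberOfValues() excludes nulls (matching Count(col)), while the reader's total row count covers null and non-null values (matching CountStar).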
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala ########## @@ -58,4 +72,35 @@ case class OrcScanBuilder( Array.empty[Filter] } } + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.orcAggregatePushDown) { + return false + } + + def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = { + dataType match { + // Not push down complex and Timestamp type. + // Not push down Binary type as ORC does not write min/max statistics for it. + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType => Review comment: @sunchao - yes, updated. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
