This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new a5dc3ef83b [spark] support to push down min/max aggregation (#5270)
a5dc3ef83b is described below

commit a5dc3ef83b01f6276360f18e842dd9c0d2749804
Author: Yann Byron <[email protected]>
AuthorDate: Fri Mar 21 13:23:41 2025 +0800

    [spark] support to push down min/max aggregation (#5270)
---
 .../apache/paimon/stats/SimpleStatsEvolution.java  |  24 ++++
 .../org/apache/paimon/table/source/DataSplit.java  |  44 ++++++
 .../org/apache/paimon/table/source/SplitTest.java  | 110 ++++++++++++++-
 .../apache/paimon/spark/PaimonScanBuilder.scala    |  40 +++---
 .../paimon/spark/aggregate/AggFuncEvaluator.scala  |  96 +++++++++++++
 .../spark/aggregate/AggregatePushDownUtils.scala   | 124 +++++++++++++++++
 .../paimon/spark/aggregate/LocalAggregator.scala   |  93 +++++--------
 .../paimon/spark/sql/PushDownAggregatesTest.scala  | 150 +++++++++++++++++----
 8 files changed, 572 insertions(+), 109 deletions(-)
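
For readers who want to try the new capability, a minimal, illustrative Spark SQL session is
sketched below. The table name, columns and values are examples only (not taken from the patch),
and it assumes value statistics are enabled, which is the default metadata.stats-mode:

    // hypothetical example exercising the min/max pushdown added by this commit
    spark.sql("CREATE TABLE t (c1 INT, c2 BIGINT) PARTITIONED BY (day STRING)")
    spark.sql("INSERT INTO t VALUES (1, 10L, '2025-01-01'), (5, 50L, '2025-01-02')")

    // MIN/MAX over the supported types (boolean, integral, float, double, date) can now be
    // answered from file-level statistics, so these queries should be served by a local scan
    // instead of reading the data files.
    spark.sql("SELECT COUNT(*), MIN(c1), MAX(c1) FROM t").show()
    spark.sql("SELECT day, MIN(c2), MAX(c2) FROM t GROUP BY day").show()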

diff --git a/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java b/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java
index fb029eccdb..b1c7cfebee 100644
--- a/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java
+++ b/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java
@@ -64,6 +64,30 @@ public class SimpleStatsEvolution {
         this.emptyNullCounts = new GenericArray(new Object[fieldNames.size()]);
     }
 
+    public InternalRow evolution(InternalRow row, @Nullable List<String> denseFields) {
+        InternalRow result = row;
+
+        if (denseFields != null && denseFields.isEmpty()) {
+            result = emptyValues;
+        } else if (denseFields != null) {
+            int[] denseIndexMapping =
+                    indexMappings.computeIfAbsent(
+                            denseFields,
+                            k -> fieldNames.stream().mapToInt(denseFields::indexOf).toArray());
+            result = ProjectedRow.from(denseIndexMapping).replaceRow(result);
+        }
+
+        if (indexMapping != null) {
+            result = ProjectedRow.from(indexMapping).replaceRow(result);
+        }
+
+        if (castFieldGetters != null) {
+            result = CastedRow.from(castFieldGetters).replaceRow(result);
+        }
+
+        return result;
+    }
+
     public Result evolution(
             SimpleStats stats, @Nullable Long rowCount, @Nullable List<String> denseFields) {
         InternalRow minValues = stats.minValues();
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
index 39f9269f41..5e39d3a71b 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
@@ -19,6 +19,7 @@
 package org.apache.paimon.table.source;
 
 import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.io.DataFileMeta;
 import org.apache.paimon.io.DataFileMeta08Serializer;
 import org.apache.paimon.io.DataFileMeta09Serializer;
@@ -28,7 +29,12 @@ import org.apache.paimon.io.DataInputView;
 import org.apache.paimon.io.DataInputViewStreamWrapper;
 import org.apache.paimon.io.DataOutputView;
 import org.apache.paimon.io.DataOutputViewStreamWrapper;
+import org.apache.paimon.predicate.CompareUtils;
+import org.apache.paimon.stats.SimpleStatsEvolution;
+import org.apache.paimon.stats.SimpleStatsEvolutions;
+import org.apache.paimon.types.DataField;
 import org.apache.paimon.utils.FunctionWithIOException;
+import org.apache.paimon.utils.InternalRowUtils;
 import org.apache.paimon.utils.SerializationUtils;
 
 import javax.annotation.Nullable;
@@ -141,6 +147,44 @@ public class DataSplit implements Split {
         return partialMergedRowCount();
     }
 
+    public Object minValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
+        Object minValue = null;
+        for (DataFileMeta dataFile : dataFiles) {
+            SimpleStatsEvolution evolution = evolutions.getOrCreate(dataFile.schemaId());
+            InternalRow minValues =
+                    evolution.evolution(
+                            dataFile.valueStats().minValues(), dataFile.valueStatsCols());
+            Object other = InternalRowUtils.get(minValues, fieldIndex, dataField.type());
+            if (minValue == null) {
+                minValue = other;
+            } else if (other != null) {
+                if (CompareUtils.compareLiteral(dataField.type(), minValue, other) > 0) {
+                    minValue = other;
+                }
+            }
+        }
+        return minValue;
+    }
+
+    public Object maxValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
+        Object maxValue = null;
+        for (DataFileMeta dataFile : dataFiles) {
+            SimpleStatsEvolution evolution = evolutions.getOrCreate(dataFile.schemaId());
+            InternalRow maxValues =
+                    evolution.evolution(
+                            dataFile.valueStats().maxValues(), dataFile.valueStatsCols());
+            Object other = InternalRowUtils.get(maxValues, fieldIndex, dataField.type());
+            if (maxValue == null) {
+                maxValue = other;
+            } else if (other != null) {
+                if (CompareUtils.compareLiteral(dataField.type(), maxValue, other) < 0) {
+                    maxValue = other;
+                }
+            }
+        }
+        return maxValue;
+    }
+
     /**
      * Obtain merged row count as much as possible. There are two scenarios 
where accurate row count
      * can be calculated:
diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
index a088f40dab..a87a645711 100644
--- a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
@@ -28,18 +28,30 @@ import org.apache.paimon.io.DataInputDeserializer;
 import org.apache.paimon.io.DataOutputViewStreamWrapper;
 import org.apache.paimon.manifest.FileSource;
 import org.apache.paimon.stats.SimpleStats;
+import org.apache.paimon.stats.SimpleStatsEvolutions;
+import org.apache.paimon.types.BigIntType;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DoubleType;
+import org.apache.paimon.types.FloatType;
+import org.apache.paimon.types.IntType;
+import org.apache.paimon.types.SmallIntType;
+import org.apache.paimon.types.TimestampType;
 import org.apache.paimon.utils.IOUtils;
 import org.apache.paimon.utils.InstantiationUtil;
 
 import org.junit.jupiter.api.Test;
 
+import javax.annotation.Nullable;
+
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.time.LocalDateTime;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.apache.paimon.data.BinaryArray.fromLongArray;
@@ -84,6 +96,70 @@ public class SplitTest {
         assertThat(split.mergedRowCount()).isEqualTo(5700L);
     }
 
+    @Test
+    public void testSplitMinMaxValue() {
+        Map<Long, List<DataField>> schemas = new HashMap<>();
+
+        Timestamp minTs = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-01T00:00:00"));
+        Timestamp maxTs1 = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-03-01T00:00:00"));
+        Timestamp maxTs2 = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-03-12T00:00:00"));
+        BinaryRow min1 = newBinaryRow(new Object[] {10, 123L, 888.0D, minTs});
+        BinaryRow max1 = newBinaryRow(new Object[] {99, 456L, 999.0D, maxTs1});
+        SimpleStats valueStats1 = new SimpleStats(min1, max1, fromLongArray(new Long[] {0L}));
+
+        BinaryRow min2 = newBinaryRow(new Object[] {5, 0L, 777.0D, minTs});
+        BinaryRow max2 = newBinaryRow(new Object[] {90, 789L, 899.0D, maxTs2});
+        SimpleStats valueStats2 = new SimpleStats(min2, max2, fromLongArray(new Long[] {0L}));
+
+        // test the common case.
+        DataFileMeta d1 = newDataFile(100, valueStats1, null);
+        DataFileMeta d2 = newDataFile(100, valueStats2, null);
+        DataSplit split1 = newDataSplit(true, Arrays.asList(d1, d2), null);
+
+        DataField intField = new DataField(0, "c_int", new IntType());
+        DataField longField = new DataField(1, "c_long", new BigIntType());
+        DataField doubleField = new DataField(2, "c_double", new DoubleType());
+        DataField tsField = new DataField(3, "c_ts", new TimestampType());
+        schemas.put(1L, Arrays.asList(intField, longField, doubleField, tsField));
+
+        SimpleStatsEvolutions evolutions = new SimpleStatsEvolutions(schemas::get, 1);
+        assertThat(split1.minValue(0, intField, evolutions)).isEqualTo(5);
+        assertThat(split1.maxValue(0, intField, evolutions)).isEqualTo(99);
+        assertThat(split1.minValue(1, longField, evolutions)).isEqualTo(0L);
+        assertThat(split1.maxValue(1, longField, evolutions)).isEqualTo(789L);
+        assertThat(split1.minValue(2, doubleField, evolutions)).isEqualTo(777D);
+        assertThat(split1.maxValue(2, doubleField, evolutions)).isEqualTo(999D);
+        assertThat(split1.minValue(3, tsField, evolutions)).isEqualTo(minTs);
+        assertThat(split1.maxValue(3, tsField, evolutions)).isEqualTo(maxTs2);
+
+        // test the case where valueStatsCols is non-null and the file schema
+        // differs from the table schema.
+        BinaryRow min3 = newBinaryRow(new Object[] {10, 123L, minTs});
+        BinaryRow max3 = newBinaryRow(new Object[] {99, 456L, maxTs1});
+        SimpleStats valueStats3 = new SimpleStats(min3, max3, fromLongArray(new Long[] {0L}));
+        BinaryRow min4 = newBinaryRow(new Object[] {5, 0L, minTs});
+        BinaryRow max4 = newBinaryRow(new Object[] {90, 789L, maxTs2});
+        SimpleStats valueStats4 = new SimpleStats(min4, max4, fromLongArray(new Long[] {0L}));
+        List<String> valueStatsCols2 = Arrays.asList("c_int", "c_long", "c_ts");
+        DataFileMeta d3 = newDataFile(100, valueStats3, valueStatsCols2);
+        DataFileMeta d4 = newDataFile(100, valueStats4, valueStatsCols2);
+        DataSplit split2 = newDataSplit(true, Arrays.asList(d3, d4), null);
+
+        DataField smallField = new DataField(4, "c_small", new SmallIntType());
+        DataField floatField = new DataField(5, "c_float", new FloatType());
+        schemas.put(2L, Arrays.asList(intField, smallField, tsField, floatField));
+
+        evolutions = new SimpleStatsEvolutions(schemas::get, 2);
+        assertThat(split2.minValue(0, intField, evolutions)).isEqualTo(5);
+        assertThat(split2.maxValue(0, intField, evolutions)).isEqualTo(99);
+        assertThat(split2.minValue(1, smallField, evolutions)).isEqualTo(null);
+        assertThat(split2.maxValue(1, smallField, evolutions)).isEqualTo(null);
+        assertThat(split2.minValue(2, tsField, evolutions)).isEqualTo(minTs);
+        assertThat(split2.maxValue(2, tsField, evolutions)).isEqualTo(maxTs2);
+        assertThat(split2.minValue(3, floatField, evolutions)).isEqualTo(null);
+        assertThat(split2.maxValue(3, floatField, evolutions)).isEqualTo(null);
+    }
+
     @Test
     public void testSerializer() throws IOException {
         DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build();
@@ -436,18 +512,23 @@ public class SplitTest {
     }
 
     private DataFileMeta newDataFile(long rowCount) {
+        return newDataFile(rowCount, null, null);
+    }
+
+    private DataFileMeta newDataFile(
+            long rowCount, SimpleStats rowStats, @Nullable List<String> valueStatsCols) {
         return DataFileMeta.forAppend(
                 "my_data_file.parquet",
                 1024 * 1024,
                 rowCount,
-                null,
+                rowStats,
                 0L,
-                rowCount,
+                rowCount - 1,
                 1,
                 Collections.emptyList(),
                 null,
                 null,
-                null,
+                valueStatsCols,
                 null);
     }
 
@@ -467,4 +548,27 @@ public class SplitTest {
         }
         return builder.build();
     }
+
+    private BinaryRow newBinaryRow(Object[] objs) {
+        BinaryRow row = new BinaryRow(objs.length);
+        BinaryRowWriter writer = new BinaryRowWriter(row);
+        writer.reset();
+        for (int i = 0; i < objs.length; i++) {
+            if (objs[i] instanceof Integer) {
+                writer.writeInt(i, (Integer) objs[i]);
+            } else if (objs[i] instanceof Long) {
+                writer.writeLong(i, (Long) objs[i]);
+            } else if (objs[i] instanceof Float) {
+                writer.writeFloat(i, (Float) objs[i]);
+            } else if (objs[i] instanceof Double) {
+                writer.writeDouble(i, (Double) objs[i]);
+            } else if (objs[i] instanceof Timestamp) {
+                writer.writeTimestamp(i, (Timestamp) objs[i], 5);
+            } else {
+                throw new UnsupportedOperationException("It's not supported.");
+            }
+        }
+        writer.complete();
+        return row;
+    }
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index c9afa07021..5fe1737c0d 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -19,8 +19,8 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, PredicateBuilder}
-import org.apache.paimon.spark.aggregate.LocalAggregator
-import org.apache.paimon.table.Table
+import org.apache.paimon.spark.aggregate.{AggregatePushDownUtils, LocalAggregator}
+import org.apache.paimon.table.{FileStoreTable, Table}
 import org.apache.paimon.table.source.DataSplit
 
 import org.apache.spark.sql.PaimonUtils
@@ -101,13 +101,12 @@ class PaimonScanBuilder(table: Table)
       return true
     }
 
-    // Only support when there is no post scan predicates.
-    if (hasPostScanPredicates) {
+    if (!table.isInstanceOf[FileStoreTable]) {
       return false
     }
 
-    val aggregator = new LocalAggregator(table)
-    if (!aggregator.pushAggregation(aggregation)) {
+    // Only support when there is no post scan predicates.
+    if (hasPostScanPredicates) {
       return false
     }
 
@@ -116,19 +115,26 @@ class PaimonScanBuilder(table: Table)
       val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava)
       readBuilder.withFilter(pushedPartitionPredicate)
     }
-    val dataSplits =
+    val dataSplits = if (AggregatePushDownUtils.hasMinMaxAggregation(aggregation)) {
+      readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
+    } else {
       readBuilder.dropStats().newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
-    if (!dataSplits.forall(_.mergedRowCountAvailable())) {
-      return false
     }
-    dataSplits.foreach(aggregator.update)
-    localScan = Some(
-      PaimonLocalScan(
-        aggregator.result(),
-        aggregator.resultSchema(),
-        table,
-        pushedPaimonPredicates))
-    true
+    if (AggregatePushDownUtils.canPushdownAggregation(table, aggregation, dataSplits.toSeq)) {
+      val aggregator = new LocalAggregator(table.asInstanceOf[FileStoreTable])
+      aggregator.initialize(aggregation)
+      dataSplits.foreach(aggregator.update)
+      localScan = Some(
+        PaimonLocalScan(
+          aggregator.result(),
+          aggregator.resultSchema(),
+          table,
+          pushedPaimonPredicates)
+      )
+      true
+    } else {
+      false
+    }
   }
 
   override def build(): Scan = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggFuncEvaluator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggFuncEvaluator.scala
new file mode 100644
index 0000000000..fcb64e3064
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggFuncEvaluator.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.aggregate
+
+import org.apache.paimon.data.BinaryString
+import org.apache.paimon.predicate.CompareUtils
+import org.apache.paimon.spark.SparkTypeUtils
+import org.apache.paimon.stats.SimpleStatsEvolutions
+import org.apache.paimon.table.source.DataSplit
+import org.apache.paimon.types.DataField
+
+import org.apache.spark.sql.types.{DataType, LongType}
+import org.apache.spark.unsafe.types.UTF8String
+
+trait AggFuncEvaluator[T] {
+  def update(dataSplit: DataSplit): Unit
+
+  def result(): T
+
+  def resultType: DataType
+
+  def prettyName: String
+}
+
+class CountStarEvaluator extends AggFuncEvaluator[Long] {
+  private var _result: Long = 0L
+
+  override def update(dataSplit: DataSplit): Unit = {
+    _result += dataSplit.mergedRowCount()
+  }
+
+  override def result(): Long = _result
+
+  override def resultType: DataType = LongType
+
+  override def prettyName: String = "count_star"
+}
+
+case class MinEvaluator(idx: Int, dataField: DataField, evolutions: SimpleStatsEvolutions)
+  extends AggFuncEvaluator[Any] {
+  private var _result: Any = _
+
+  override def update(dataSplit: DataSplit): Unit = {
+    val other = dataSplit.minValue(idx, dataField, evolutions)
+    if (_result == null || CompareUtils.compareLiteral(dataField.`type`(), _result, other) > 0) {
+      _result = other
+    }
+  }
+
+  override def result(): Any = _result match {
+    case s: BinaryString => UTF8String.fromString(s.toString)
+    case a => a
+  }
+
+  override def resultType: DataType = SparkTypeUtils.fromPaimonType(dataField.`type`())
+
+  override def prettyName: String = "min"
+}
+
+case class MaxEvaluator(idx: Int, dataField: DataField, evolutions: SimpleStatsEvolutions)
+  extends AggFuncEvaluator[Any] {
+  private var _result: Any = _
+
+  override def update(dataSplit: DataSplit): Unit = {
+    val other = dataSplit.maxValue(idx, dataField, evolutions)
+    if (_result == null || CompareUtils.compareLiteral(dataField.`type`(), _result, other) < 0) {
+      _result = other
+    }
+  }
+
+  override def result(): Any = _result match {
+    case s: BinaryString => UTF8String.fromString(s.toString)
+    case a => a
+  }
+
+  override def resultType: DataType = SparkTypeUtils.fromPaimonType(dataField.`type`())
+
+  override def prettyName: String = "max"
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
new file mode 100644
index 0000000000..c6abec1acd
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.aggregate
+
+import org.apache.paimon.table.Table
+import org.apache.paimon.table.source.DataSplit
+import org.apache.paimon.types._
+
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar, Max, Min}
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+object AggregatePushDownUtils {
+
+  def canPushdownAggregation(
+      table: Table,
+      aggregation: Aggregation,
+      dataSplits: Seq[DataSplit]): Boolean = {
+
+    var hasMinMax = false
+    val minmaxColumns = mutable.HashSet.empty[String]
+    var hasCount = false
+
+    def getDataFieldForCol(colName: String): DataField = {
+      table.rowType.getField(colName)
+    }
+
+    def isPartitionCol(colName: String) = {
+      table.partitionKeys.contains(colName)
+    }
+
+    def processMinOrMax(agg: AggregateFunc): Boolean = {
+      val columnName = agg match {
+        case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+          V2ColumnUtils.extractV2Column(max.column).get
+        case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+          V2ColumnUtils.extractV2Column(min.column).get
+        case _ => return false
+      }
+
+      val dataField = getDataFieldForCol(columnName)
+
+      dataField.`type`() match {
+        // not push down complex type
+        // not push down Timestamp because INT96 sort order is undefined,
+        // Parquet doesn't return statistics for INT96
+        // not push down Parquet Binary because min/max could be truncated
+        // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary
+        // could be Spark StringType, BinaryType or DecimalType.
+        // not push down for ORC with same reason.
+        case _: BooleanType | _: TinyIntType | _: SmallIntType | _: IntType | _: BigIntType |
+            _: FloatType | _: DoubleType | _: DateType =>
+          minmaxColumns.add(columnName)
+          hasMinMax = true
+          true
+        case _ =>
+          false
+      }
+    }
+
+    aggregation.groupByExpressions.map(V2ColumnUtils.extractV2Column).foreach {
+      colName =>
+        // don't push down if the group by columns are not the same as the partition columns
+        // (order doesn't matter because reordering can be done at the data source layer)
+        if (colName.isEmpty || !isPartitionCol(colName.get)) return false
+    }
+
+    aggregation.aggregateExpressions.foreach {
+      case max: Max =>
+        if (!processMinOrMax(max)) return false
+      case min: Min =>
+        if (!processMinOrMax(min)) return false
+      case _: CountStar =>
+        hasCount = true
+      case _ =>
+        return false
+    }
+
+    if (hasMinMax) {
+      dataSplits.forall {
+        dataSplit =>
+          dataSplit.dataFiles().asScala.forall {
+            dataFile =>
+              // A null valueStatsCols means statistics are available for all columns
+              dataFile.valueStatsCols() == null ||
+              minmaxColumns.forall(dataFile.valueStatsCols().contains)
+          }
+      }
+    } else if (hasCount) {
+      dataSplits.forall(_.mergedRowCountAvailable())
+    } else {
+      true
+    }
+  }
+
+  def hasMinMaxAggregation(aggregation: Aggregation): Boolean = {
+    var hasMinMax = false
+    aggregation.aggregateExpressions().foreach {
+      case _: Min | _: Max =>
+        hasMinMax = true
+      case _ =>
+    }
+    hasMinMax
+  }
+
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
index 8988e7218d..bb88aa669e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
@@ -18,34 +18,59 @@
 
 package org.apache.paimon.spark.aggregate
 
+import org.apache.paimon.CoreOptions
 import org.apache.paimon.data.BinaryRow
+import org.apache.paimon.schema.SchemaManager
 import org.apache.paimon.spark.SparkTypeUtils
 import org.apache.paimon.spark.data.SparkInternalRow
-import org.apache.paimon.table.{DataTable, Table}
+import org.apache.paimon.stats.SimpleStatsEvolutions
+import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.source.DataSplit
 import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.JoinedRow
-import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar}
-import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, CountStar, Max, Min}
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 import scala.collection.mutable
 
-class LocalAggregator(table: Table) {
+class LocalAggregator(table: FileStoreTable) {
+  private val rowType = table.rowType()
   private val partitionType = SparkTypeUtils.toPartitionType(table)
   private val groupByEvaluatorMap = new mutable.HashMap[InternalRow, Seq[AggFuncEvaluator[_]]]()
   private var requiredGroupByType: Seq[DataType] = _
   private var requiredGroupByIndexMapping: Seq[Int] = _
   private var aggFuncEvaluatorGetter: () => Seq[AggFuncEvaluator[_]] = _
   private var isInitialized = false
+  private lazy val simpleStatsEvolutions = {
+    val schemaManager = new SchemaManager(
+      table.fileIO(),
+      table.location(),
+      CoreOptions.branch(table.schema().options()))
+    new SimpleStatsEvolutions(sid => schemaManager.schema(sid).fields(), table.schema().id())
+  }
 
-  private def initialize(aggregation: Aggregation): Unit = {
+  def initialize(aggregation: Aggregation): Unit = {
     aggFuncEvaluatorGetter = () =>
       aggregation.aggregateExpressions().map {
         case _: CountStar => new CountStarEvaluator()
-        case _ => throw new UnsupportedOperationException()
+        case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+          val fieldName = V2ColumnUtils.extractV2Column(min.column).get
+          MinEvaluator(
+            rowType.getFieldIndex(fieldName),
+            rowType.getField(fieldName),
+            simpleStatsEvolutions)
+        case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+          val fieldName = V2ColumnUtils.extractV2Column(max.column).get
+          MaxEvaluator(
+            rowType.getFieldIndex(fieldName),
+            rowType.getField(fieldName),
+            simpleStatsEvolutions)
+        case _ =>
+          throw new UnsupportedOperationException()
       }
 
     requiredGroupByType = aggregation.groupByExpressions().map {
@@ -61,39 +86,6 @@ class LocalAggregator(table: Table) {
     isInitialized = true
   }
 
-  private def supportAggregateFunction(func: AggregateFunc): Boolean = {
-    func match {
-      case _: CountStar => true
-      case _ => false
-    }
-  }
-
-  private def supportGroupByExpressions(exprs: Array[Expression]): Boolean = {
-    // Support empty group by keys or group by partition column
-    exprs.forall {
-      case r: NamedReference =>
-        r.fieldNames.length == 1 && table.partitionKeys().contains(r.fieldNames().head)
-      case _ => false
-    }
-  }
-
-  def pushAggregation(aggregation: Aggregation): Boolean = {
-    if (!table.isInstanceOf[DataTable]) {
-      return false
-    }
-
-    if (
-      !supportGroupByExpressions(aggregation.groupByExpressions()) ||
-      aggregation.aggregateExpressions().isEmpty ||
-      aggregation.aggregateExpressions().exists(!supportAggregateFunction(_))
-    ) {
-      return false
-    }
-
-    initialize(aggregation)
-    true
-  }
-
   private def requiredGroupByRow(partitionRow: BinaryRow): InternalRow = {
     val projectedRow =
       ProjectedRow.from(requiredGroupByIndexMapping.toArray).replaceRow(partitionRow)
@@ -139,24 +131,3 @@ class LocalAggregator(table: Table) {
     StructType.apply(groupByFields ++ aggResultFields)
   }
 }
-
-trait AggFuncEvaluator[T] {
-  def update(dataSplit: DataSplit): Unit
-  def result(): T
-  def resultType: DataType
-  def prettyName: String
-}
-
-class CountStarEvaluator extends AggFuncEvaluator[Long] {
-  private var _result: Long = 0L
-
-  override def update(dataSplit: DataSplit): Unit = {
-    _result += dataSplit.mergedRowCount()
-  }
-
-  override def result(): Long = _result
-
-  override def resultType: DataType = LongType
-
-  override def prettyName: String = "count_star"
-}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
index 78c02644a7..26c19ecc27 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.execution.LocalTableScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 
+import java.sql.Date
+
 class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanHelper {
 
   private def runAndCheckAggregate(
@@ -49,70 +51,162 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
     }
   }
 
-  test("Push down aggregate - append table") {
+  test("Push down aggregate - append table without partitions") {
     withTable("T") {
-      spark.sql("CREATE TABLE T (c1 INT, c2 STRING) PARTITIONED BY(day 
STRING)")
+      spark.sql("CREATE TABLE T (c1 INT, c2 STRING, c3 DOUBLE, c4 DATE)")
 
       runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 0)
+      runAndCheckAggregate(
+        "SELECT COUNT(*), MIN(c1), MIN(c3), MIN(c4) FROM T",
+        Row(0, null, null, null) :: Nil,
+        0)
+      runAndCheckAggregate(
+        "SELECT COUNT(*), MAX(c1), MAX(c3), MAX(c4) FROM T",
+        Row(0, null, null, null) :: Nil,
+        0)
+      // count(c1) and min/max for string are not supported.
+      runAndCheckAggregate("SELECT COUNT(c1) FROM T", Row(0) :: Nil, 2)
+      runAndCheckAggregate("SELECT MIN(c2) FROM T", Row(null) :: Nil, 2)
+      runAndCheckAggregate("SELECT MAX(c2) FROM T", Row(null) :: Nil, 2)
+
       // This query does not contain aggregate due to AQE optimize it to empty relation.
       runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY c1", Nil, 0)
-      runAndCheckAggregate("SELECT COUNT(c1) FROM T", Row(0) :: Nil, 2)
       runAndCheckAggregate("SELECT COUNT(*), COUNT(c1) FROM T", Row(0, 0) :: 
Nil, 2)
-      runAndCheckAggregate("SELECT COUNT(*), COUNT(*) + 1 FROM T", Row(0, 1) 
:: Nil, 0)
-      runAndCheckAggregate("SELECT COUNT(*) as c FROM T WHERE day='a'", Row(0) 
:: Nil, 0)
-      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1=1", Row(0) :: Nil, 
2)
-      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE day='a' and c1=1", 
Row(0) :: Nil, 2)
+      runAndCheckAggregate(
+        "SELECT COUNT(*) + 1, MIN(c1) * 10, MAX(c3) + 1.0 FROM T",
+        Row(1, null, null) :: Nil,
+        0)
+      runAndCheckAggregate(
+        "SELECT COUNT(*) as cnt, MIN(c4) as min_c4 FROM T",
+        Row(0, null) :: Nil,
+        0)
+      // Cases with non-partition data filters are not supported.
+      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1 = 1", Row(0) :: 
Nil, 2)
 
       spark.sql(
-        "INSERT INTO T VALUES(1, 'x', 'a'), (2, 'x', 'a'), (3, 'x', 'b'), (3, 
'x', 'c'), (null, 'x', 'a')")
+        s"""
+           |INSERT INTO T VALUES (1, 'xyz', 11.1, TO_DATE('2025-01-01', 'yyyy-MM-dd')),
+           |(2, null, null, TO_DATE('2025-01-01', 'yyyy-MM-dd')), (3, 'abc', 33.3, null),
+           |(3, 'abc', null, TO_DATE('2025-03-01', 'yyyy-MM-dd')), (null, 'abc', 44.4, TO_DATE('2025-03-01', 'yyyy-MM-dd'))
+           |""".stripMargin)
 
+      val date1 = Date.valueOf("2025-01-01")
+      val date2 = Date.valueOf("2025-03-01")
       runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(5) :: Nil, 0)
       runAndCheckAggregate(
         "SELECT COUNT(*) FROM T GROUP BY c1",
         Row(1) :: Row(1) :: Row(1) :: Row(2) :: Nil,
         2)
       runAndCheckAggregate("SELECT COUNT(c1) FROM T", Row(4) :: Nil, 2)
+
+      runAndCheckAggregate("SELECT COUNT(*), MIN(c1), MAX(c1) FROM T", Row(5, 
1, 3) :: Nil, 0)
+      runAndCheckAggregate(
+        "SELECT COUNT(*), MIN(c2), MAX(c2) FROM T",
+        Row(5, "abc", "xyz") :: Nil,
+        2)
+      runAndCheckAggregate("SELECT COUNT(*), MIN(c3), MAX(c3) FROM T", Row(5, 
11.1, 44.4) :: Nil, 0)
+      runAndCheckAggregate(
+        "SELECT COUNT(*), MIN(c4), MAX(c4) FROM T",
+        Row(5, date1, date2) :: Nil,
+        0)
       runAndCheckAggregate("SELECT COUNT(*), COUNT(c1) FROM T", Row(5, 4) :: 
Nil, 2)
       runAndCheckAggregate("SELECT COUNT(*), COUNT(*) + 1 FROM T", Row(5, 6) 
:: Nil, 0)
-      runAndCheckAggregate("SELECT COUNT(*) as c FROM T WHERE day='a'", Row(3) 
:: Nil, 0)
-      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1=1", Row(1) :: Nil, 
2)
-      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE day='a' and c1=1", 
Row(1) :: Nil, 2)
+      runAndCheckAggregate(
+        "SELECT COUNT(*) + 1, MIN(c1) * 10, MAX(c3) + 1.0 FROM T",
+        Row(6, 10, 45.4) :: Nil,
+        0)
+      runAndCheckAggregate(
+        "SELECT MIN(c3) as min, MAX(c4) as max FROM T",
+        Row(11.1, date2) :: Nil,
+        0)
+      runAndCheckAggregate("SELECT COUNT(*), MIN(c3) FROM T WHERE c1 = 3", 
Row(2, 33.3) :: Nil, 2)
     }
   }
 
-  test("Push down aggregate - group by partition column") {
+  test("Push down aggregate - append table with partitions") {
     withTable("T") {
-      spark.sql("CREATE TABLE T (c1 INT) PARTITIONED BY(day STRING, hour INT)")
+      spark.sql("CREATE TABLE T (c1 INT, c2 LONG) PARTITIONED BY(day STRING, 
hour INT)")
 
       runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Nil, 0)
-      runAndCheckAggregate("SELECT day, COUNT(*) as c FROM T GROUP BY day, 
hour", Nil, 0)
-      runAndCheckAggregate("SELECT day, COUNT(*), hour FROM T GROUP BY day, 
hour", Nil, 0)
       runAndCheckAggregate(
-        "SELECT day, COUNT(*), hour FROM T WHERE day='x' GROUP BY day, hour",
+        "SELECT day, hour, COUNT(*), MIN(c1), MIN(c1) FROM T GROUP BY day, 
hour",
+        Nil,
+        0)
+      runAndCheckAggregate(
+        "SELECT day, hour, COUNT(*), MIN(c2), MIN(c2) FROM T GROUP BY day, 
hour",
+        Nil,
+        0)
+      runAndCheckAggregate(
+        "SELECT day, COUNT(*), hour FROM T WHERE day= '2025-01-01' GROUP BY 
day, hour",
         Nil,
         0)
       // This query does not contain aggregate due to AQE optimize it to empty relation.
       runAndCheckAggregate("SELECT day, COUNT(*) FROM T GROUP BY c1, day", Nil, 0)
 
       spark.sql(
-        "INSERT INTO T VALUES(1, 'x', 1), (2, 'x', 1), (3, 'x', 2), (3, 'x', 
3), (null, 'y', null)")
+        """
+          |INSERT INTO T VALUES(1, 100L, '2025-01-01', 1), (2, null, 
'2025-01-01', 1),
+          |(3, 300L, '2025-03-01', 3), (3, 330L, '2025-03-01', 3), (null, 
400L, '2025-03-01', null)
+          |""".stripMargin)
 
-      runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Row(1) :: 
Row(4) :: Nil, 0)
+      runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Row(2) :: 
Row(3) :: Nil, 0)
       runAndCheckAggregate(
-        "SELECT day, COUNT(*) as c FROM T GROUP BY day, hour",
-        Row("x", 1) :: Row("x", 1) :: Row("x", 2) :: Row("y", 1) :: Nil,
+        "SELECT day, hour, COUNT(*) as c FROM T GROUP BY day, hour",
+        Row("2025-01-01", 1, 2) :: Row("2025-03-01", 3, 2) :: 
Row("2025-03-01", null, 1) :: Nil,
         0)
       runAndCheckAggregate(
-        "SELECT day, COUNT(*), hour FROM T GROUP BY day, hour",
-        Row("x", 1, 2) :: Row("y", 1, null) :: Row("x", 2, 1) :: Row("x", 1, 
3) :: Nil,
-        0)
+        "SELECT day, COUNT(*), hour, MIN(c1), MAX(c1) FROM T GROUP BY day, 
hour",
+        Row("2025-01-01", 2, 1, 1, 2) :: Row("2025-03-01", 2, 3, 3, 3) :: Row(
+          "2025-03-01",
+          1,
+          null,
+          null,
+          null) :: Nil,
+        0
+      )
       runAndCheckAggregate(
-        "SELECT day, COUNT(*), hour FROM T WHERE day='x' GROUP BY day, hour",
-        Row("x", 1, 2) :: Row("x", 1, 3) :: Row("x", 2, 1) :: Nil,
-        0)
+        "SELECT hour, COUNT(*), MIN(c2) as min, MAX(c2) as max FROM T WHERE 
day='2025-03-01' GROUP BY day, hour",
+        Row(3, 2, 300L, 330L) :: Row(null, 1, 400L, 400L) :: Nil,
+        0
+      )
+      runAndCheckAggregate(
+        "SELECT c1, day, COUNT(*) FROM T GROUP BY c1, day ORDER BY c1, day",
+        Row(null, "2025-03-01", 1) :: Row(1, "2025-01-01", 1) :: Row(2, 
"2025-01-01", 1) :: Row(
+          3,
+          "2025-03-01",
+          2) :: Nil,
+        2
+      )
+    }
+  }
+
+  test("Push down aggregate - append table with dense statistics") {
+    withTable("T") {
+      spark.sql("""
+                  |CREATE TABLE T (c1 INT, c2 STRING, c3 DOUBLE, c4 DATE)
+                  |TBLPROPERTIES('metadata.stats-mode' = 'none')
+                  |""".stripMargin)
+      spark.sql(
+        s"""
+           |INSERT INTO T VALUES (1, 'xyz', 11.1, TO_DATE('2025-01-01', 'yyyy-MM-dd')),
+           |(2, null, null, TO_DATE('2025-01-01', 'yyyy-MM-dd')), (3, 'abc', 33.3, null),
+           |(3, 'abc', null, TO_DATE('2025-03-01', 'yyyy-MM-dd')), (null, 'abc', 44.4, TO_DATE('2025-03-01', 'yyyy-MM-dd'))
+           |""".stripMargin)
+
+      val date1 = Date.valueOf("2025-01-01")
+      val date2 = Date.valueOf("2025-03-01")
+      runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(5) :: Nil, 0)
+
+      // With metadata.stats-mode = none, no statistics are available.
+      runAndCheckAggregate("SELECT COUNT(*), MIN(c1), MAX(c1) FROM T", Row(5, 1, 3) :: Nil, 2)
+      runAndCheckAggregate(
+        "SELECT COUNT(*), MIN(c2), MAX(c2) FROM T",
+        Row(5, "abc", "xyz") :: Nil,
+        2)
+      runAndCheckAggregate("SELECT COUNT(*), MIN(c3), MAX(c3) FROM T", Row(5, 
11.1, 44.4) :: Nil, 2)
       runAndCheckAggregate(
-        "SELECT day, COUNT(*) FROM T GROUP BY c1, day",
-        Row("x", 1) :: Row("x", 1) :: Row("x", 2) :: Row("y", 1) :: Nil,
+        "SELECT COUNT(*), MIN(c4), MAX(c4) FROM T",
+        Row(5, date1, date2) :: Nil,
         2)
     }
   }

