This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git

commit 931b4e775b403bc455edf6362f8ca039ada05106
Author: JingsongLi <jingsongl...@gmail.com>
AuthorDate: Wed Aug 27 15:18:36 2025 +0800

    [core] Refactor topN and minmax pushdown and disable them for pk tables
---
 .../paimon/operation/AbstractFileStoreScan.java    |   6 +
 .../org/apache/paimon/operation/FileStoreScan.java |   2 +
 .../org/apache/paimon/table/source/DataSplit.java  |  18 --
 .../paimon/table/source/DataTableBatchScan.java    |  96 +++++++----
 .../source/PushDownUtils.java}                     |  31 +++-
 .../table/source/TopNDataSplitEvaluator.java       | 186 +++++++--------------
 .../table/source/snapshot/SnapshotReader.java      |   2 +
 .../table/source/snapshot/SnapshotReaderImpl.java  |   6 +
 .../apache/paimon/table/system/AuditLogTable.java  |   6 +
 .../apache/paimon/table/source/TableScanTest.java  |   4 +-
 .../paimon/spark/ColumnPruningAndPushDown.scala    |   8 +-
 .../apache/paimon/spark/PaimonScanBuilder.scala    |  32 ++--
 .../spark/aggregate/AggregatePushDownUtils.scala   | 131 ++++++++-------
 .../paimon/spark/sql/PushDownAggregatesTest.scala  |   2 +
 14 files changed, 261 insertions(+), 269 deletions(-)

diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java b/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java
index f8cd260f52..3ce8d46e40 100644
--- a/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java
+++ b/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java
@@ -224,6 +224,12 @@ public abstract class AbstractFileStoreScan implements FileStoreScan {
         return this;
     }
 
+    @Override
+    public FileStoreScan keepStats() {
+        this.dropStats = false;
+        return this;
+    }
+
     @Nullable
     @Override
     public Integer parallelism() {
diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java b/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java
index 739e8b916f..5fa3127237 100644
--- a/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java
+++ b/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java
@@ -85,6 +85,8 @@ public interface FileStoreScan {
 
     FileStoreScan dropStats();
 
+    FileStoreScan keepStats();
+
     @Nullable
     Integer parallelism();
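
The dropStats()/keepStats() pair turns stats retention into an explicit,
reversible toggle on the scan. A minimal sketch of the intended call pattern,
assuming a FileStoreScan obtained from a store's newScan() (the
needStatsForTopN flag is hypothetical):

    FileStoreScan scan = store.newScan();
    // Planners drop stats by default to keep plans small ...
    scan.dropStats();
    // ... but a stats-dependent pushdown can opt back in; keepStats()
    // resets the same dropStats flag that AbstractFileStoreScan tracks.
    if (needStatsForTopN) {
        scan.keepStats();
    }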
 
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
index 276cf5fe50..8bdace010d 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
@@ -46,16 +46,13 @@ import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.OptionalLong;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.apache.paimon.io.DataFilePathFactory.INDEX_PATH_SUFFIX;
-import static org.apache.paimon.utils.ListUtils.isNullOrEmpty;
 import static org.apache.paimon.utils.Preconditions.checkArgument;
 import static org.apache.paimon.utils.Preconditions.checkState;
 
@@ -158,21 +155,6 @@ public class DataSplit implements Split {
         return partialMergedRowCount();
     }
 
-    public boolean statsAvailable(Set<String> columns) {
-        if (isNullOrEmpty(columns)) {
-            return false;
-        }
-
-        return dataFiles.stream()
-                .map(DataFileMeta::valueStatsCols)
-                .allMatch(
-                        valueStatsCols ->
-                                // It means there are all column statistics when valueStatsCols ==
-                                // null
-                                valueStatsCols == null
-                                        || new HashSet<>(valueStatsCols).containsAll(columns));
-    }
-
     public Object minValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
         Object minValue = null;
         for (DataFileMeta dataFile : dataFiles) {
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java
index ec8eb8ee0b..98ac38725b 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java
@@ -21,6 +21,7 @@ package org.apache.paimon.table.source;
 import org.apache.paimon.CoreOptions;
 import org.apache.paimon.manifest.PartitionEntry;
 import org.apache.paimon.predicate.Predicate;
+import org.apache.paimon.predicate.SortValue;
 import org.apache.paimon.predicate.TopN;
 import org.apache.paimon.schema.SchemaManager;
 import org.apache.paimon.schema.TableSchema;
@@ -28,9 +29,13 @@ import org.apache.paimon.table.BucketMode;
 import org.apache.paimon.table.source.snapshot.SnapshotReader;
 import org.apache.paimon.table.source.snapshot.StartingScanner;
 import org.apache.paimon.table.source.snapshot.StartingScanner.ScannedResult;
+import org.apache.paimon.types.DataType;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Optional;
+
+import static org.apache.paimon.table.source.PushDownUtils.minmaxAvailable;
 
 /** {@link TableScan} implementation for batch planning. */
 public class DataTableBatchScan extends AbstractDataTableScan {
@@ -93,10 +98,15 @@ public class DataTableBatchScan extends AbstractDataTableScan {
 
         if (hasNext) {
             hasNext = false;
-            StartingScanner.Result result = startingScanner.scan(snapshotReader);
-            result = applyPushDownLimit(result);
-            result = applyPushDownTopN(result);
-            return DataFilePlan.fromResult(result);
+            Optional<StartingScanner.Result> pushed = applyPushDownLimit();
+            if (pushed.isPresent()) {
+                return DataFilePlan.fromResult(pushed.get());
+            }
+            pushed = applyPushDownTopN();
+            if (pushed.isPresent()) {
+                return DataFilePlan.fromResult(pushed.get());
+            }
+            return DataFilePlan.fromResult(startingScanner.scan(snapshotReader));
         } else {
             throw new EndOfScanException();
         }
@@ -110,51 +120,77 @@ public class DataTableBatchScan extends AbstractDataTableScan {
         return startingScanner.scanPartitions(snapshotReader);
     }
 
-    private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result) {
-        if (pushDownLimit != null && result instanceof ScannedResult) {
-            long scannedRowCount = 0;
-            SnapshotReader.Plan plan = ((ScannedResult) result).plan();
-            List<DataSplit> splits = plan.dataSplits();
-            if (splits.isEmpty()) {
-                return result;
-            }
+    private Optional<StartingScanner.Result> applyPushDownLimit() {
+        if (pushDownLimit == null) {
+            return Optional.empty();
+        }
+
+        StartingScanner.Result result = startingScanner.scan(snapshotReader);
+        if (!(result instanceof ScannedResult)) {
+            return Optional.of(result);
+        }
+
+        long scannedRowCount = 0;
+        SnapshotReader.Plan plan = ((ScannedResult) result).plan();
+        List<DataSplit> splits = plan.dataSplits();
+        if (splits.isEmpty()) {
+            return Optional.of(result);
+        }
 
-            List<Split> limitedSplits = new ArrayList<>();
-            for (DataSplit dataSplit : splits) {
-                if (dataSplit.rawConvertible()) {
-                    long partialMergedRowCount = dataSplit.partialMergedRowCount();
-                    limitedSplits.add(dataSplit);
-                    scannedRowCount += partialMergedRowCount;
-                    if (scannedRowCount >= pushDownLimit) {
-                        SnapshotReader.Plan newPlan =
-                                new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits);
-                        return new ScannedResult(newPlan);
-                    }
+        List<Split> limitedSplits = new ArrayList<>();
+        for (DataSplit dataSplit : splits) {
+            if (dataSplit.rawConvertible()) {
+                long partialMergedRowCount = dataSplit.partialMergedRowCount();
+                limitedSplits.add(dataSplit);
+                scannedRowCount += partialMergedRowCount;
+                if (scannedRowCount >= pushDownLimit) {
+                    SnapshotReader.Plan newPlan =
+                            new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits);
+                    return Optional.of(new ScannedResult(newPlan));
                 }
             }
         }
-        return result;
+        return Optional.of(result);
     }
 
-    private StartingScanner.Result applyPushDownTopN(StartingScanner.Result result) {
+    private Optional<StartingScanner.Result> applyPushDownTopN() {
         if (topN == null
                 || pushDownLimit != null
-                || !(result instanceof ScannedResult)
                 || !schema.primaryKeys().isEmpty()
                 || options().deletionVectorsEnabled()) {
-            return result;
+            return Optional.empty();
+        }
+
+        List<SortValue> orders = topN.orders();
+        if (orders.size() != 1) {
+            return Optional.empty();
+        }
+
+        if (topN.limit() > 100) {
+            return Optional.empty();
+        }
+
+        SortValue order = orders.get(0);
+        DataType type = order.field().type();
+        if (!minmaxAvailable(type)) {
+            return Optional.empty();
+        }
+
+        StartingScanner.Result result = startingScanner.scan(snapshotReader.keepStats());
+        if (!(result instanceof ScannedResult)) {
+            return Optional.of(result);
         }
 
         SnapshotReader.Plan plan = ((ScannedResult) result).plan();
         List<DataSplit> splits = plan.dataSplits();
         if (splits.isEmpty()) {
-            return result;
+            return Optional.of(result);
         }
 
         TopNDataSplitEvaluator evaluator = new TopNDataSplitEvaluator(schema, schemaManager);
-        List<Split> topNSplits = new ArrayList<>(evaluator.evaluate(topN, splits));
+        List<Split> topNSplits = new ArrayList<>(evaluator.evaluate(order, topN.limit(), splits));
         SnapshotReader.Plan newPlan = new PlanImpl(plan.watermark(), plan.snapshotId(), topNSplits);
-        return new ScannedResult(newPlan);
+        return Optional.of(new ScannedResult(newPlan));
     }
 
     @Override
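
A condensed sketch of the guard order the refactored applyPushDownTopN() now
applies before any scan is issued (this helper is illustrative, not a method
from the commit; it restates the early returns above):

    private boolean topNPushDownApplicable() {
        return topN != null
                && pushDownLimit == null                  // limit pushdown wins
                && schema.primaryKeys().isEmpty()         // never for pk tables
                && !options().deletionVectorsEnabled()
                && topN.orders().size() == 1              // single sort key only
                && topN.limit() <= 100
                && minmaxAvailable(topN.orders().get(0).field().type());
    }

Only when all of these hold does the scan run with keepStats(), so the plan
retains the min/max values that TopNDataSplitEvaluator needs.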
diff --git a/paimon-core/src/main/java/org/apache/paimon/stats/StatsUtils.java b/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java
similarity index 68%
rename from paimon-core/src/main/java/org/apache/paimon/stats/StatsUtils.java
rename to paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java
index ec52cda0d9..833d6422fb 100644
--- a/paimon-core/src/main/java/org/apache/paimon/stats/StatsUtils.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java
@@ -16,8 +16,9 @@
  * limitations under the License.
  */
 
-package org.apache.paimon.stats;
+package org.apache.paimon.table.source;
 
+import org.apache.paimon.io.DataFileMeta;
 import org.apache.paimon.types.BigIntType;
 import org.apache.paimon.types.BooleanType;
 import org.apache.paimon.types.DataType;
@@ -28,8 +29,13 @@ import org.apache.paimon.types.IntType;
 import org.apache.paimon.types.SmallIntType;
 import org.apache.paimon.types.TinyIntType;
 
-/** Utils for Stats. */
-public class StatsUtils {
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.paimon.utils.ListUtils.isNullOrEmpty;
+
+/** Utils for pushdown. */
+public class PushDownUtils {
 
     public static boolean minmaxAvailable(DataType type) {
         // not push down complex type
@@ -48,4 +54,23 @@ public class StatsUtils {
                 || type instanceof DoubleType
                 || type instanceof DateType;
     }
+
+    public static boolean minmaxAvailable(DataSplit split, Set<String> columns) {
+        if (isNullOrEmpty(columns)) {
+            return false;
+        }
+
+        if (!split.rawConvertible()) {
+            return false;
+        }
+
+        return split.dataFiles().stream()
+                .map(DataFileMeta::valueStatsCols)
+                .allMatch(
+                        valueStatsCols ->
+                                // It means there are all column statistics when valueStatsCols ==
+                                // null
+                                valueStatsCols == null
+                                        || new HashSet<>(valueStatsCols).containsAll(columns));
+    }
 }
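
A hypothetical helper showing how the two relocated overloads compose: the
type-level check whitelists the fixed-width types listed above (complex types
are excluded), while the split-level check additionally requires a
raw-convertible split whose files carry stats for every requested column
(canUseMinMax is an illustrative name, not part of the commit):

    static boolean canUseMinMax(DataSplit split, DataField field) {
        return PushDownUtils.minmaxAvailable(field.type())
                && PushDownUtils.minmaxAvailable(
                        split, Collections.singleton(field.name()));
    }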
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java b/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java
index 87c6b23cfd..68deb5830b 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java
@@ -19,15 +19,12 @@
 package org.apache.paimon.table.source;
 
 import org.apache.paimon.predicate.CompareUtils;
-import org.apache.paimon.predicate.FieldRef;
 import org.apache.paimon.predicate.SortValue;
-import org.apache.paimon.predicate.TopN;
 import org.apache.paimon.schema.SchemaManager;
 import org.apache.paimon.schema.TableSchema;
 import org.apache.paimon.stats.SimpleStatsEvolutions;
 import org.apache.paimon.types.DataField;
 import org.apache.paimon.types.DataType;
-import org.apache.paimon.utils.Pair;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -35,14 +32,11 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.apache.paimon.predicate.SortValue.NullOrdering.NULLS_FIRST;
-import static org.apache.paimon.predicate.SortValue.NullOrdering.NULLS_LAST;
 import static org.apache.paimon.predicate.SortValue.SortDirection.ASCENDING;
-import static org.apache.paimon.predicate.SortValue.SortDirection.DESCENDING;
-import static org.apache.paimon.stats.StatsUtils.minmaxAvailable;
+import static org.apache.paimon.table.source.PushDownUtils.minmaxAvailable;
 
 /** Evaluate DataSplit TopN result. */
 public class TopNDataSplitEvaluator {
@@ -57,47 +51,25 @@ public class TopNDataSplitEvaluator {
         this.schemaManager = schemaManager;
     }
 
-    public List<DataSplit> evaluate(TopN topN, List<DataSplit> splits) {
-        // todo: we can support all the sort columns.
-        List<SortValue> orders = topN.orders();
-        if (orders.size() != 1) {
+    public List<DataSplit> evaluate(SortValue order, int limit, List<DataSplit> splits) {
+        if (limit > splits.size()) {
             return splits;
         }
-
-        int limit = topN.limit();
-        if (limit >= splits.size()) {
-            return splits;
-        }
-
-        SortValue order = orders.get(0);
-        DataType type = order.field().type();
-        if (!minmaxAvailable(type)) {
-            return splits;
-        }
-
         return getTopNSplits(order, limit, splits);
     }
 
     private List<DataSplit> getTopNSplits(SortValue order, int limit, List<DataSplit> splits) {
-        FieldRef ref = order.field();
-        SortValue.SortDirection direction = order.direction();
-        SortValue.NullOrdering nullOrdering = order.nullOrdering();
-
-        int index = ref.index();
+        int index = order.field().index();
         DataField field = schema.fields().get(index);
         SimpleStatsEvolutions evolutions =
                 new SimpleStatsEvolutions((id) -> scanTableSchema(id).fields(), schema.id());
 
         // extract the stats
         List<DataSplit> results = new ArrayList<>();
-        List<Pair<Stats, DataSplit>> pairs = new ArrayList<>();
+        List<RichSplit> richSplits = new ArrayList<>();
         for (DataSplit split : splits) {
-            if (!split.rawConvertible()) {
-                return splits;
-            }
-
-            Set<String> cols = Collections.singleton(field.name());
-            if (!split.statsAvailable(cols)) {
+            if (!minmaxAvailable(split, Collections.singleton(field.name()))) {
+                // unknown split, read it
                 results.add(split);
                 continue;
             }
@@ -105,109 +77,66 @@ public class TopNDataSplitEvaluator {
             Object min = split.minValue(index, field, evolutions);
             Object max = split.maxValue(index, field, evolutions);
             Long nullCount = split.nullCount(index, evolutions);
-            Stats stats = new Stats(min, max, nullCount);
-            pairs.add(Pair.of(stats, split));
+            richSplits.add(new RichSplit(split, min, max, nullCount));
         }
 
         // pick the TopN splits
-        if (NULLS_FIRST.equals(nullOrdering)) {
-            results.addAll(pickNullFirstSplits(pairs, ref, direction, limit));
-        } else if (NULLS_LAST.equals(nullOrdering)) {
-            results.addAll(pickNullLastSplits(pairs, ref, direction, limit));
-        } else {
-            return splits;
-        }
-
+        boolean nullFirst = NULLS_FIRST.equals(order.nullOrdering());
+        boolean ascending = ASCENDING.equals(order.direction());
+        results.addAll(pickTopNSplits(richSplits, field.type(), ascending, nullFirst, limit));
         return results;
     }
 
-    private List<DataSplit> pickNullFirstSplits(
-            List<Pair<Stats, DataSplit>> pairs,
-            FieldRef field,
-            SortValue.SortDirection direction,
+    private List<DataSplit> pickTopNSplits(
+            List<RichSplit> splits,
+            DataType fieldType,
+            boolean ascending,
+            boolean nullFirst,
             int limit) {
-        Comparator<Pair<Stats, DataSplit>> comparator;
-        if (ASCENDING.equals(direction)) {
-            comparator =
-                    (x, y) -> {
-                        Stats left = x.getKey();
-                        Stats right = y.getKey();
-                        int result = nullsFirst(left.nullCount, right.nullCount);
-                        if (result == 0) {
-                            result = asc(field, left.min, right.min);
-                        }
-                        return result;
-                    };
-        } else if (DESCENDING.equals(direction)) {
+        Comparator<RichSplit> comparator;
+        if (ascending) {
             comparator =
                     (x, y) -> {
-                        Stats left = x.getKey();
-                        Stats right = y.getKey();
-                        int result = nullsFirst(left.nullCount, right.nullCount);
-                        if (result == 0) {
-                            result = desc(field, left.max, right.max);
+                        int result;
+                        if (nullFirst) {
+                            result = nullsFirstCompare(x.nullCount, y.nullCount);
+                            if (result == 0) {
+                                result = ascCompare(fieldType, x.min, y.min);
+                            }
+                        } else {
+                            result = ascCompare(fieldType, x.min, y.min);
+                            if (result == 0) {
+                                result = nullsLastCompare(x.nullCount, y.nullCount);
+                            }
                         }
                         return result;
                     };
         } else {
-            return pairs.stream().map(Pair::getValue).collect(Collectors.toList());
-        }
-        pairs.sort(comparator);
-
-        long scanned = 0;
-        List<DataSplit> splits = new ArrayList<>();
-        for (Pair<Stats, DataSplit> pair : pairs) {
-            Stats stats = pair.getKey();
-            DataSplit split = pair.getValue();
-            splits.add(split);
-            scanned += Math.max(stats.nullCount, 1);
-            if (scanned >= limit) {
-                break;
-            }
-        }
-        return splits;
-    }
-
-    private List<DataSplit> pickNullLastSplits(
-            List<Pair<Stats, DataSplit>> pairs,
-            FieldRef field,
-            SortValue.SortDirection direction,
-            int limit) {
-        Comparator<Pair<Stats, DataSplit>> comparator;
-        if (ASCENDING.equals(direction)) {
-            comparator =
-                    (x, y) -> {
-                        Stats left = x.getKey();
-                        Stats right = y.getKey();
-                        int result = asc(field, left.min, right.min);
-                        if (result == 0) {
-                            result = nullsLast(left.nullCount, right.nullCount);
-                        }
-                        return result;
-                    };
-        } else if (DESCENDING.equals(direction)) {
             comparator =
                     (x, y) -> {
-                        Stats left = x.getKey();
-                        Stats right = y.getKey();
-                        int result = desc(field, left.max, right.max);
-                        if (result == 0) {
-                            result = nullsLast(left.nullCount, right.nullCount);
+                        int result;
+                        if (nullFirst) {
+                            result = nullsFirstCompare(x.nullCount, y.nullCount);
+                            if (result == 0) {
+                                result = descCompare(fieldType, x.max, y.max);
+                            }
+                        } else {
+                            result = descCompare(fieldType, x.max, y.max);
+                            if (result == 0) {
+                            result = nullsLastCompare(x.nullCount, y.nullCount);
+                            }
                         }
                         return result;
                     };
-        } else {
-            return pairs.stream().map(Pair::getValue).collect(Collectors.toList());
         }
-
-        return pairs.stream()
+        return splits.stream()
                 .sorted(comparator)
-                .map(Pair::getValue)
+                .map(RichSplit::split)
                 .limit(limit)
                 .collect(Collectors.toList());
     }
 
-    private int nullsFirst(Long left, Long right) {
+    private int nullsFirstCompare(Long left, Long right) {
         if (left == null) {
             return -1;
         } else if (right == null) {
@@ -217,7 +146,7 @@ public class TopNDataSplitEvaluator {
         }
     }
 
-    private int nullsLast(Long left, Long right) {
+    private int nullsLastCompare(Long left, Long right) {
         if (left == null) {
             return -1;
         } else if (right == null) {
@@ -227,23 +156,23 @@ public class TopNDataSplitEvaluator {
         }
     }
 
-    private int asc(FieldRef field, Object left, Object right) {
+    private int ascCompare(DataType type, Object left, Object right) {
         if (left == null) {
             return -1;
         } else if (right == null) {
             return 1;
         } else {
-            return CompareUtils.compareLiteral(field.type(), left, right);
+            return CompareUtils.compareLiteral(type, left, right);
         }
     }
 
-    private int desc(FieldRef field, Object left, Object right) {
+    private int descCompare(DataType type, Object left, Object right) {
         if (left == null) {
             return -1;
         } else if (right == null) {
             return 1;
         } else {
-            return -CompareUtils.compareLiteral(field.type(), left, right);
+            return -CompareUtils.compareLiteral(type, left, right);
         }
     }
 
@@ -252,16 +181,23 @@ public class TopNDataSplitEvaluator {
                 id, key -> key == schema.id() ? schema : schemaManager.schema(id));
     }
 
-    /** The DataSplit's stats. */
-    private static class Stats {
-        Object min;
-        Object max;
-        Long nullCount;
+    /** DataSplit with stats. */
+    private static class RichSplit {
+
+        private final DataSplit split;
+        private final Object min;
+        private final Object max;
+        private final Long nullCount;
 
-        public Stats(Object min, Object max, Long nullCount) {
+        private RichSplit(DataSplit split, Object min, Object max, Long nullCount) {
+            this.split = split;
             this.min = min;
             this.max = max;
             this.nullCount = nullCount;
         }
+
+        private DataSplit split() {
+            return split;
+        }
     }
 }
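
Note the shared null rule in ascCompare/descCompare: a missing (null) stats
literal always ranks lowest, so a split whose min/max is unknown sorts toward
the front rather than being compared against concrete values. A standalone
illustration of the same rule on plain Integers (not code from the commit):

    static int ascCompare(Integer left, Integer right) {
        if (left == null) {
            return -1;                        // unknown stat sorts first
        } else if (right == null) {
            return 1;
        } else {
            return Integer.compare(left, right);
        }
    }
    // ascCompare(null, 5) == -1  -> split with unknown min stays up front
    // ascCompare(3, 5)    <  0   -> ordinary ascending comparison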
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java
index 79a1c94c32..82112f5fd4 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java
@@ -101,6 +101,8 @@ public interface SnapshotReader {
 
     SnapshotReader dropStats();
 
+    SnapshotReader keepStats();
+
     SnapshotReader withShard(int indexOfThisSubtask, int numberOfParallelSubtasks);
 
     SnapshotReader withMetricRegistry(MetricRegistry registry);
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java
index a51c1516cb..d979b91b2a 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java
@@ -311,6 +311,12 @@ public class SnapshotReaderImpl implements SnapshotReader {
         return this;
     }
 
+    @Override
+    public SnapshotReader keepStats() {
+        scan.keepStats();
+        return this;
+    }
+
     @Override
     public SnapshotReader withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) {
         if (splitGenerator.alwaysRawConvertible()) {
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java b/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java
index b9afacec2b..ab3b13a5ab 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java
@@ -400,6 +400,12 @@ public class AuditLogTable implements DataTable, ReadonlyTable {
             return this;
         }
 
+        @Override
+        public SnapshotReader keepStats() {
+            wrapped.keepStats();
+            return this;
+        }
+
         @Override
         public SnapshotReader withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) {
             wrapped.withShard(indexOfThisSubtask, numberOfParallelSubtasks);
diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java
index 0fd381f918..c0295a3a5a 100644
--- a/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java
@@ -186,7 +186,7 @@ public class TableScanTest extends ScannerTestBase {
         TableScan.Plan plan4 =
                 table.newScan().withTopN(new TopN(ref, ASCENDING, NULLS_FIRST, 5)).plan();
         List<Split> splits4 = plan4.splits();
-        assertThat(splits4.size()).isEqualTo(5);
+        assertThat(splits4.size()).isEqualTo(7);
         assertThat(((DataSplit) splits4.get(0)).minValue(field.id(), field, evolutions))
                 .isEqualTo(null);
         assertThat(((DataSplit) splits4.get(1)).minValue(field.id(), field, evolutions))
@@ -249,7 +249,7 @@ public class TableScanTest extends ScannerTestBase {
         TableScan.Plan plan8 =
                 table.newScan().withTopN(new TopN(ref, DESCENDING, NULLS_FIRST, 5)).plan();
         List<Split> splits8 = plan8.splits();
-        assertThat(splits8.size()).isEqualTo(5);
+        assertThat(splits8.size()).isEqualTo(7);
         assertThat(((DataSplit) splits8.get(0)).maxValue(field.id(), field, evolutions))
                 .isEqualTo(null);
         assertThat(((DataSplit) splits8.get(1)).maxValue(field.id(), field, evolutions))
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala
index 1e3bb846cf..ce1fe2e25f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala
@@ -73,13 +73,7 @@ trait ColumnPruningAndPushDown extends Scan with Logging {
     }
     pushDownLimit.foreach(_readBuilder.withLimit)
     pushDownTopN.foreach(_readBuilder.withTopN)
-
-    // when TopN is not empty, we need the stats to pick the TopN DataSplits
-    if (pushDownTopN.nonEmpty) {
-      _readBuilder
-    } else {
-      _readBuilder.dropStats()
-    }
+    _readBuilder.dropStats()
   }
 
   final def metadataColumns: Seq[PaimonMetadataColumn] = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index aa50f76043..f021feab95 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -20,9 +20,8 @@ package org.apache.paimon.spark
 
 import org.apache.paimon.predicate._
 import org.apache.paimon.predicate.SortValue.{NullOrdering, SortDirection}
-import org.apache.paimon.spark.aggregate.{AggregatePushDownUtils, LocalAggregator}
-import org.apache.paimon.table.{AppendOnlyFileStoreTable, FileStoreTable, InnerTable}
-import org.apache.paimon.table.source.DataSplit
+import org.apache.paimon.spark.aggregate.AggregatePushDownUtils.tryPushdownAggregation
+import org.apache.paimon.table.{FileStoreTable, InnerTable}
+import org.apache.paimon.table.{FileStoreTable, InnerTable}
 
 import org.apache.spark.sql.PaimonUtils
 import org.apache.spark.sql.connector.expressions
@@ -168,25 +167,14 @@ class PaimonScanBuilder(table: InnerTable)
       val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava)
       readBuilder.withFilter(pushedPartitionPredicate)
     }
-    val dataSplits = if (AggregatePushDownUtils.hasMinMaxAggregation(aggregation)) {
-      readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
-    } else {
-      readBuilder.dropStats().newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
-    }
-    if (AggregatePushDownUtils.canPushdownAggregation(table, aggregation, dataSplits.toSeq)) {
-      val aggregator = new LocalAggregator(table.asInstanceOf[FileStoreTable])
-      aggregator.initialize(aggregation)
-      dataSplits.foreach(aggregator.update)
-      localScan = Some(
-        PaimonLocalScan(
-          aggregator.result(),
-          aggregator.resultSchema(),
-          table,
-          pushedPaimonPredicates)
-      )
-      true
-    } else {
-      false
+
+    tryPushdownAggregation(table.asInstanceOf[FileStoreTable], aggregation, readBuilder) match {
+      case Some(agg) =>
+        localScan = Some(
+          PaimonLocalScan(agg.result(), agg.resultSchema(), table, pushedPaimonPredicates)
+        )
+        true
+      case None => false
     }
   }
 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
index fd1ee59b5a..e07db9515e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
@@ -18,91 +18,98 @@
 
 package org.apache.paimon.spark.aggregate
 
-import org.apache.paimon.stats.StatsUtils.minmaxAvailable
-import org.apache.paimon.table.Table
-import org.apache.paimon.table.source.DataSplit
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.table.source.{DataSplit, ReadBuilder}
+import org.apache.paimon.table.source.PushDownUtils.minmaxAvailable
 import org.apache.paimon.types._
 
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar, Max, Min}
-import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, CountStar, Max, Min}
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils.extractV2Column
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 object AggregatePushDownUtils {
 
-  def canPushdownAggregation(
-      table: Table,
+  def tryPushdownAggregation(
+      table: FileStoreTable,
       aggregation: Aggregation,
-      dataSplits: Seq[DataSplit]): Boolean = {
+      readBuilder: ReadBuilder): Option[LocalAggregator] = {
+    val options = table.coreOptions()
+    val rowType = table.rowType
+    val partitionKeys = table.partitionKeys()
 
-    var hasMinMax = false
-    val minmaxColumns = mutable.HashSet.empty[String]
-    var hasCount = false
-
-    def getDataFieldForCol(colName: String): DataField = {
-      table.rowType.getField(colName)
+    aggregation.groupByExpressions.map(extractV2Column).foreach {
+      colName =>
+        // don't push down if the group by columns are not the same as the partition columns (orders
+        // doesn't matter because reorder can be done at data source layer)
+        if (colName.isEmpty || !partitionKeys.contains(colName.get)) return None
     }
 
-    def isPartitionCol(colName: String) = {
-      table.partitionKeys.contains(colName)
+    val splits = extractMinMaxColumns(rowType, aggregation) match {
+      case Some(columns) =>
+        if (columns.isEmpty) {
+          generateSplits(readBuilder.dropStats())
+        } else {
+          if (options.deletionVectorsEnabled() || !table.primaryKeys().isEmpty) {
+            return None
+          }
+          val splits = generateSplits(readBuilder)
+          if (!splits.forall(minmaxAvailable(_, columns.asJava))) {
+            return None
+          }
+          splits
+        }
+      case None => return None
     }
 
-    def processMinOrMax(agg: AggregateFunc): Boolean = {
-      val columnName = agg match {
-        case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
-          V2ColumnUtils.extractV2Column(max.column).get
-        case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
-          V2ColumnUtils.extractV2Column(min.column).get
-        case _ => return false
-      }
-
-      val dataField = getDataFieldForCol(columnName)
-
-      if (minmaxAvailable(dataField.`type`())) {
-        minmaxColumns.add(columnName)
-        hasMinMax = true
-        true
-      } else {
-        false
-      }
+    if (!splits.forall(_.mergedRowCountAvailable())) {
+      return None
     }
 
-    aggregation.groupByExpressions.map(V2ColumnUtils.extractV2Column).foreach {
-      colName =>
-        // don't push down if the group by columns are not the same as the partition columns (orders
-        // doesn't matter because reorder can be done at data source layer)
-        if (colName.isEmpty || !isPartitionCol(colName.get)) return false
-    }
+    val aggregator = new LocalAggregator(table)
+    aggregator.initialize(aggregation)
+    splits.foreach(aggregator.update)
+    Option(aggregator)
+  }
+
+  private def generateSplits(readBuilder: ReadBuilder): mutable.Seq[DataSplit] = {
+    readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
+  }
 
+  private def extractMinMaxColumns(
+      rowType: RowType,
+      aggregation: Aggregation): Option[Set[String]] = {
+    val columns = mutable.HashSet.empty[String]
     aggregation.aggregateExpressions.foreach {
-      case max: Max =>
-        if (!processMinOrMax(max)) return false
-      case min: Min =>
-        if (!processMinOrMax(min)) return false
+      case e if e.isInstanceOf[Min] || e.isInstanceOf[Max] =>
+        extractMinMaxColumn(rowType, e) match {
+          case Some(colName) => columns.add(colName)
+          case None => return None
+        }
       case _: CountStar =>
-        hasCount = true
-      case _ =>
-        return false
+      case _ => return None
     }
+    Option(columns.toSet)
+  }
 
-    if (hasMinMax) {
-      dataSplits.forall(_.statsAvailable(minmaxColumns.toSet.asJava))
-    } else if (hasCount) {
-      dataSplits.forall(_.mergedRowCountAvailable())
-    } else {
-      true
+  private def extractMinMaxColumn(rowType: RowType, minOrMax: Expression): Option[String] = {
+    val column = minOrMax match {
+      case min: Min => min.column()
+      case max: Max => max.column()
+    }
+    val extractColumn = extractV2Column(column)
+    if (extractColumn.isEmpty) {
+      return None
     }
-  }
 
-  def hasMinMaxAggregation(aggregation: Aggregation): Boolean = {
-    var hasMinMax = false;
-    aggregation.aggregateExpressions().foreach {
-      case _: Min | _: Max =>
-        hasMinMax = true
-      case _ =>
+    val columnName = extractColumn.get
+    val dataType = rowType.getField(columnName).`type`()
+    if (minmaxAvailable(dataType)) {
+      Option(columnName)
+    } else {
+      None
     }
-    hasMinMax
   }
-
 }
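
For readers scanning the Scala above, a condensed Java paraphrase of the new
tryPushdownAggregation flow (every helper name here is hypothetical; the real
logic is the diff itself):

    Optional<LocalAggregator> tryPushdown() {
        if (!groupByColumnsAreAllPartitionKeys()) {
            return Optional.empty();
        }
        // empty set = COUNT(*) only; absent = an unsupported aggregate
        Optional<Set<String>> minMaxCols = extractMinMaxColumns();
        if (!minMaxCols.isPresent()) {
            return Optional.empty();
        }
        List<DataSplit> splits;
        if (minMaxCols.get().isEmpty()) {
            splits = plan(readBuilder.dropStats());
        } else {
            if (deletionVectorsEnabled() || hasPrimaryKeys()) {
                return Optional.empty();      // pk/DV tables: file stats unsafe
            }
            splits = plan(readBuilder);       // stats retained for min/max
            if (!allSplitsCarryStats(splits, minMaxCols.get())) {
                return Optional.empty();
            }
        }
        if (!allSplitsHaveMergedRowCount(splits)) {
            return Optional.empty();
        }
        return Optional.of(aggregateLocally(splits));
    }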
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
index 26c19ecc27..a60d88aef9 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
@@ -230,6 +230,8 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
           spark.sql("INSERT INTO T VALUES(1, 'x_1')")
           if (deletionVectorsEnabled) {
             runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 0)
+            // should not push down min max for primary key table
+            runAndCheckAggregate("SELECT MIN(c1) FROM T", Row(1) :: Nil, 2)
           } else {
             runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 2)
           }

