This is an automated email from the ASF dual-hosted git repository. lzljs3620320 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/paimon.git
commit 931b4e775b403bc455edf6362f8ca039ada05106 Author: JingsongLi <jingsongl...@gmail.com> AuthorDate: Wed Aug 27 15:18:36 2025 +0800 [core] Refactor topN and minmax pushdown and should not work for pk table --- .../paimon/operation/AbstractFileStoreScan.java | 6 + .../org/apache/paimon/operation/FileStoreScan.java | 2 + .../org/apache/paimon/table/source/DataSplit.java | 18 -- .../paimon/table/source/DataTableBatchScan.java | 96 +++++++---- .../source/PushDownUtils.java} | 31 +++- .../table/source/TopNDataSplitEvaluator.java | 186 +++++++-------------- .../table/source/snapshot/SnapshotReader.java | 2 + .../table/source/snapshot/SnapshotReaderImpl.java | 6 + .../apache/paimon/table/system/AuditLogTable.java | 6 + .../apache/paimon/table/source/TableScanTest.java | 4 +- .../paimon/spark/ColumnPruningAndPushDown.scala | 8 +- .../apache/paimon/spark/PaimonScanBuilder.scala | 32 ++-- .../spark/aggregate/AggregatePushDownUtils.scala | 131 ++++++++------- .../paimon/spark/sql/PushDownAggregatesTest.scala | 2 + 14 files changed, 261 insertions(+), 269 deletions(-) diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java b/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java index f8cd260f52..3ce8d46e40 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/AbstractFileStoreScan.java @@ -224,6 +224,12 @@ public abstract class AbstractFileStoreScan implements FileStoreScan { return this; } + @Override + public FileStoreScan keepStats() { + this.dropStats = false; + return this; + } + @Nullable @Override public Integer parallelism() { diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java b/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java index 739e8b916f..5fa3127237 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreScan.java @@ -85,6 +85,8 @@ public interface FileStoreScan { FileStoreScan dropStats(); + FileStoreScan keepStats(); + @Nullable Integer parallelism(); diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java index 276cf5fe50..8bdace010d 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java @@ -46,16 +46,13 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.OptionalLong; -import java.util.Set; import java.util.stream.Collectors; import static org.apache.paimon.io.DataFilePathFactory.INDEX_PATH_SUFFIX; -import static org.apache.paimon.utils.ListUtils.isNullOrEmpty; import static org.apache.paimon.utils.Preconditions.checkArgument; import static org.apache.paimon.utils.Preconditions.checkState; @@ -158,21 +155,6 @@ public class DataSplit implements Split { return partialMergedRowCount(); } - public boolean statsAvailable(Set<String> columns) { - if (isNullOrEmpty(columns)) { - return false; - } - - return dataFiles.stream() - .map(DataFileMeta::valueStatsCols) - .allMatch( - valueStatsCols -> - // It means there are all column statistics when valueStatsCols == - // null - valueStatsCols == null - || new HashSet<>(valueStatsCols).containsAll(columns)); - } - public Object minValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) { Object minValue = null; for (DataFileMeta dataFile : dataFiles) { diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java index ec8eb8ee0b..98ac38725b 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java @@ -21,6 +21,7 @@ package org.apache.paimon.table.source; import org.apache.paimon.CoreOptions; import org.apache.paimon.manifest.PartitionEntry; import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.SortValue; import org.apache.paimon.predicate.TopN; import org.apache.paimon.schema.SchemaManager; import org.apache.paimon.schema.TableSchema; @@ -28,9 +29,13 @@ import org.apache.paimon.table.BucketMode; import org.apache.paimon.table.source.snapshot.SnapshotReader; import org.apache.paimon.table.source.snapshot.StartingScanner; import org.apache.paimon.table.source.snapshot.StartingScanner.ScannedResult; +import org.apache.paimon.types.DataType; import java.util.ArrayList; import java.util.List; +import java.util.Optional; + +import static org.apache.paimon.table.source.PushDownUtils.minmaxAvailable; /** {@link TableScan} implementation for batch planning. */ public class DataTableBatchScan extends AbstractDataTableScan { @@ -93,10 +98,15 @@ public class DataTableBatchScan extends AbstractDataTableScan { if (hasNext) { hasNext = false; - StartingScanner.Result result = startingScanner.scan(snapshotReader); - result = applyPushDownLimit(result); - result = applyPushDownTopN(result); - return DataFilePlan.fromResult(result); + Optional<StartingScanner.Result> pushed = applyPushDownLimit(); + if (pushed.isPresent()) { + return DataFilePlan.fromResult(pushed.get()); + } + pushed = applyPushDownTopN(); + if (pushed.isPresent()) { + return DataFilePlan.fromResult(pushed.get()); + } + return DataFilePlan.fromResult(startingScanner.scan(snapshotReader)); } else { throw new EndOfScanException(); } @@ -110,51 +120,77 @@ public class DataTableBatchScan extends AbstractDataTableScan { return startingScanner.scanPartitions(snapshotReader); } - private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result) { - if (pushDownLimit != null && result instanceof ScannedResult) { - long scannedRowCount = 0; - SnapshotReader.Plan plan = ((ScannedResult) result).plan(); - List<DataSplit> splits = plan.dataSplits(); - if (splits.isEmpty()) { - return result; - } + private Optional<StartingScanner.Result> applyPushDownLimit() { + if (pushDownLimit == null) { + return Optional.empty(); + } + + StartingScanner.Result result = startingScanner.scan(snapshotReader); + if (!(result instanceof ScannedResult)) { + return Optional.of(result); + } + + long scannedRowCount = 0; + SnapshotReader.Plan plan = ((ScannedResult) result).plan(); + List<DataSplit> splits = plan.dataSplits(); + if (splits.isEmpty()) { + return Optional.of(result); + } - List<Split> limitedSplits = new ArrayList<>(); - for (DataSplit dataSplit : splits) { - if (dataSplit.rawConvertible()) { - long partialMergedRowCount = dataSplit.partialMergedRowCount(); - limitedSplits.add(dataSplit); - scannedRowCount += partialMergedRowCount; - if (scannedRowCount >= pushDownLimit) { - SnapshotReader.Plan newPlan = - new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits); - return new ScannedResult(newPlan); - } + List<Split> limitedSplits = new ArrayList<>(); + for (DataSplit dataSplit : splits) { + if (dataSplit.rawConvertible()) { + long partialMergedRowCount = dataSplit.partialMergedRowCount(); + limitedSplits.add(dataSplit); + scannedRowCount += partialMergedRowCount; + if (scannedRowCount >= pushDownLimit) { + SnapshotReader.Plan newPlan = + new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits); + return Optional.of(new ScannedResult(newPlan)); } } } - return result; + return Optional.of(result); } - private StartingScanner.Result applyPushDownTopN(StartingScanner.Result result) { + private Optional<StartingScanner.Result> applyPushDownTopN() { if (topN == null || pushDownLimit != null - || !(result instanceof ScannedResult) || !schema.primaryKeys().isEmpty() || options().deletionVectorsEnabled()) { - return result; + return Optional.empty(); + } + + List<SortValue> orders = topN.orders(); + if (orders.size() != 1) { + return Optional.empty(); + } + + if (topN.limit() > 100) { + return Optional.empty(); + } + + SortValue order = orders.get(0); + DataType type = order.field().type(); + if (!minmaxAvailable(type)) { + return Optional.empty(); + } + + StartingScanner.Result result = startingScanner.scan(snapshotReader.keepStats()); + if (!(result instanceof ScannedResult)) { + return Optional.of(result); } SnapshotReader.Plan plan = ((ScannedResult) result).plan(); List<DataSplit> splits = plan.dataSplits(); if (splits.isEmpty()) { - return result; + return Optional.of(result); } TopNDataSplitEvaluator evaluator = new TopNDataSplitEvaluator(schema, schemaManager); - List<Split> topNSplits = new ArrayList<>(evaluator.evaluate(topN, splits)); + List<Split> topNSplits = new ArrayList<>(evaluator.evaluate(order, topN.limit(), splits)); SnapshotReader.Plan newPlan = new PlanImpl(plan.watermark(), plan.snapshotId(), topNSplits); - return new ScannedResult(newPlan); + return Optional.of(new ScannedResult(newPlan)); } @Override diff --git a/paimon-core/src/main/java/org/apache/paimon/stats/StatsUtils.java b/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java similarity index 68% rename from paimon-core/src/main/java/org/apache/paimon/stats/StatsUtils.java rename to paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java index ec52cda0d9..833d6422fb 100644 --- a/paimon-core/src/main/java/org/apache/paimon/stats/StatsUtils.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java @@ -16,8 +16,9 @@ * limitations under the License. */ -package org.apache.paimon.stats; +package org.apache.paimon.table.source; +import org.apache.paimon.io.DataFileMeta; import org.apache.paimon.types.BigIntType; import org.apache.paimon.types.BooleanType; import org.apache.paimon.types.DataType; @@ -28,8 +29,13 @@ import org.apache.paimon.types.IntType; import org.apache.paimon.types.SmallIntType; import org.apache.paimon.types.TinyIntType; -/** Utils for Stats. */ -public class StatsUtils { +import java.util.HashSet; +import java.util.Set; + +import static org.apache.paimon.utils.ListUtils.isNullOrEmpty; + +/** Utils for pushing downs. */ +public class PushDownUtils { public static boolean minmaxAvailable(DataType type) { // not push down complex type @@ -48,4 +54,23 @@ public class StatsUtils { || type instanceof DoubleType || type instanceof DateType; } + + public static boolean minmaxAvailable(DataSplit split, Set<String> columns) { + if (isNullOrEmpty(columns)) { + return false; + } + + if (!split.rawConvertible()) { + return false; + } + + return split.dataFiles().stream() + .map(DataFileMeta::valueStatsCols) + .allMatch( + valueStatsCols -> + // It means there are all column statistics when valueStatsCols == + // null + valueStatsCols == null + || new HashSet<>(valueStatsCols).containsAll(columns)); + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java b/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java index 87c6b23cfd..68deb5830b 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/TopNDataSplitEvaluator.java @@ -19,15 +19,12 @@ package org.apache.paimon.table.source; import org.apache.paimon.predicate.CompareUtils; -import org.apache.paimon.predicate.FieldRef; import org.apache.paimon.predicate.SortValue; -import org.apache.paimon.predicate.TopN; import org.apache.paimon.schema.SchemaManager; import org.apache.paimon.schema.TableSchema; import org.apache.paimon.stats.SimpleStatsEvolutions; import org.apache.paimon.types.DataField; import org.apache.paimon.types.DataType; -import org.apache.paimon.utils.Pair; import java.util.ArrayList; import java.util.Collections; @@ -35,14 +32,11 @@ import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import static org.apache.paimon.predicate.SortValue.NullOrdering.NULLS_FIRST; -import static org.apache.paimon.predicate.SortValue.NullOrdering.NULLS_LAST; import static org.apache.paimon.predicate.SortValue.SortDirection.ASCENDING; -import static org.apache.paimon.predicate.SortValue.SortDirection.DESCENDING; -import static org.apache.paimon.stats.StatsUtils.minmaxAvailable; +import static org.apache.paimon.table.source.PushDownUtils.minmaxAvailable; /** Evaluate DataSplit TopN result. */ public class TopNDataSplitEvaluator { @@ -57,47 +51,25 @@ public class TopNDataSplitEvaluator { this.schemaManager = schemaManager; } - public List<DataSplit> evaluate(TopN topN, List<DataSplit> splits) { - // todo: we can support all the sort columns. - List<SortValue> orders = topN.orders(); - if (orders.size() != 1) { + public List<DataSplit> evaluate(SortValue order, int limit, List<DataSplit> splits) { + if (limit > splits.size()) { return splits; } - - int limit = topN.limit(); - if (limit >= splits.size()) { - return splits; - } - - SortValue order = orders.get(0); - DataType type = order.field().type(); - if (!minmaxAvailable(type)) { - return splits; - } - return getTopNSplits(order, limit, splits); } private List<DataSplit> getTopNSplits(SortValue order, int limit, List<DataSplit> splits) { - FieldRef ref = order.field(); - SortValue.SortDirection direction = order.direction(); - SortValue.NullOrdering nullOrdering = order.nullOrdering(); - - int index = ref.index(); + int index = order.field().index(); DataField field = schema.fields().get(index); SimpleStatsEvolutions evolutions = new SimpleStatsEvolutions((id) -> scanTableSchema(id).fields(), schema.id()); // extract the stats List<DataSplit> results = new ArrayList<>(); - List<Pair<Stats, DataSplit>> pairs = new ArrayList<>(); + List<RichSplit> richSplits = new ArrayList<>(); for (DataSplit split : splits) { - if (!split.rawConvertible()) { - return splits; - } - - Set<String> cols = Collections.singleton(field.name()); - if (!split.statsAvailable(cols)) { + if (!minmaxAvailable(split, Collections.singleton(field.name()))) { + // unknown split, read it results.add(split); continue; } @@ -105,109 +77,66 @@ public class TopNDataSplitEvaluator { Object min = split.minValue(index, field, evolutions); Object max = split.maxValue(index, field, evolutions); Long nullCount = split.nullCount(index, evolutions); - Stats stats = new Stats(min, max, nullCount); - pairs.add(Pair.of(stats, split)); + richSplits.add(new RichSplit(split, min, max, nullCount)); } // pick the TopN splits - if (NULLS_FIRST.equals(nullOrdering)) { - results.addAll(pickNullFirstSplits(pairs, ref, direction, limit)); - } else if (NULLS_LAST.equals(nullOrdering)) { - results.addAll(pickNullLastSplits(pairs, ref, direction, limit)); - } else { - return splits; - } - + boolean nullFirst = NULLS_FIRST.equals(order.nullOrdering()); + boolean ascending = ASCENDING.equals(order.direction()); + results.addAll(pickTopNSplits(richSplits, field.type(), ascending, nullFirst, limit)); return results; } - private List<DataSplit> pickNullFirstSplits( - List<Pair<Stats, DataSplit>> pairs, - FieldRef field, - SortValue.SortDirection direction, + private List<DataSplit> pickTopNSplits( + List<RichSplit> splits, + DataType fieldType, + boolean ascending, + boolean nullFirst, int limit) { - Comparator<Pair<Stats, DataSplit>> comparator; - if (ASCENDING.equals(direction)) { - comparator = - (x, y) -> { - Stats left = x.getKey(); - Stats right = y.getKey(); - int result = nullsFirst(left.nullCount, right.nullCount); - if (result == 0) { - result = asc(field, left.min, right.min); - } - return result; - }; - } else if (DESCENDING.equals(direction)) { + Comparator<RichSplit> comparator; + if (ascending) { comparator = (x, y) -> { - Stats left = x.getKey(); - Stats right = y.getKey(); - int result = nullsFirst(left.nullCount, right.nullCount); - if (result == 0) { - result = desc(field, left.max, right.max); + int result; + if (nullFirst) { + result = nullsFirstCompare(x.nullCount, y.nullCount); + if (result == 0) { + result = ascCompare(fieldType, x.min, y.min); + } + } else { + result = ascCompare(fieldType, x.min, y.min); + if (result == 0) { + result = nullsLastCompare(x.nullCount, y.nullCount); + } } return result; }; } else { - return pairs.stream().map(Pair::getValue).collect(Collectors.toList()); - } - pairs.sort(comparator); - - long scanned = 0; - List<DataSplit> splits = new ArrayList<>(); - for (Pair<Stats, DataSplit> pair : pairs) { - Stats stats = pair.getKey(); - DataSplit split = pair.getValue(); - splits.add(split); - scanned += Math.max(stats.nullCount, 1); - if (scanned >= limit) { - break; - } - } - return splits; - } - - private List<DataSplit> pickNullLastSplits( - List<Pair<Stats, DataSplit>> pairs, - FieldRef field, - SortValue.SortDirection direction, - int limit) { - Comparator<Pair<Stats, DataSplit>> comparator; - if (ASCENDING.equals(direction)) { - comparator = - (x, y) -> { - Stats left = x.getKey(); - Stats right = y.getKey(); - int result = asc(field, left.min, right.min); - if (result == 0) { - result = nullsLast(left.nullCount, right.nullCount); - } - return result; - }; - } else if (DESCENDING.equals(direction)) { comparator = (x, y) -> { - Stats left = x.getKey(); - Stats right = y.getKey(); - int result = desc(field, left.max, right.max); - if (result == 0) { - result = nullsLast(left.nullCount, right.nullCount); + int result; + if (nullFirst) { + result = nullsFirstCompare(x.nullCount, y.nullCount); + if (result == 0) { + result = descCompare(fieldType, x.max, y.max); + } + } else { + result = descCompare(fieldType, x.max, y.max); + if (result == 0) { + result = nullsLastCompare(x.nullCount, y.nullCount); + } } return result; }; - } else { - return pairs.stream().map(Pair::getValue).collect(Collectors.toList()); } - - return pairs.stream() + return splits.stream() .sorted(comparator) - .map(Pair::getValue) + .map(RichSplit::split) .limit(limit) .collect(Collectors.toList()); } - private int nullsFirst(Long left, Long right) { + private int nullsFirstCompare(Long left, Long right) { if (left == null) { return -1; } else if (right == null) { @@ -217,7 +146,7 @@ public class TopNDataSplitEvaluator { } } - private int nullsLast(Long left, Long right) { + private int nullsLastCompare(Long left, Long right) { if (left == null) { return -1; } else if (right == null) { @@ -227,23 +156,23 @@ public class TopNDataSplitEvaluator { } } - private int asc(FieldRef field, Object left, Object right) { + private int ascCompare(DataType type, Object left, Object right) { if (left == null) { return -1; } else if (right == null) { return 1; } else { - return CompareUtils.compareLiteral(field.type(), left, right); + return CompareUtils.compareLiteral(type, left, right); } } - private int desc(FieldRef field, Object left, Object right) { + private int descCompare(DataType type, Object left, Object right) { if (left == null) { return -1; } else if (right == null) { return 1; } else { - return -CompareUtils.compareLiteral(field.type(), left, right); + return -CompareUtils.compareLiteral(type, left, right); } } @@ -252,16 +181,23 @@ public class TopNDataSplitEvaluator { id, key -> key == schema.id() ? schema : schemaManager.schema(id)); } - /** The DataSplit's stats. */ - private static class Stats { - Object min; - Object max; - Long nullCount; + /** DataSplit with stats. */ + private static class RichSplit { + + private final DataSplit split; + private final Object min; + private final Object max; + private final Long nullCount; - public Stats(Object min, Object max, Long nullCount) { + private RichSplit(DataSplit split, Object min, Object max, Long nullCount) { + this.split = split; this.min = min; this.max = max; this.nullCount = nullCount; } + + private DataSplit split() { + return split; + } } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java index 79a1c94c32..82112f5fd4 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReader.java @@ -101,6 +101,8 @@ public interface SnapshotReader { SnapshotReader dropStats(); + SnapshotReader keepStats(); + SnapshotReader withShard(int indexOfThisSubtask, int numberOfParallelSubtasks); SnapshotReader withMetricRegistry(MetricRegistry registry); diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java index a51c1516cb..d979b91b2a 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java @@ -311,6 +311,12 @@ public class SnapshotReaderImpl implements SnapshotReader { return this; } + @Override + public SnapshotReader keepStats() { + scan.keepStats(); + return this; + } + @Override public SnapshotReader withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) { if (splitGenerator.alwaysRawConvertible()) { diff --git a/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java b/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java index b9afacec2b..ab3b13a5ab 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/system/AuditLogTable.java @@ -400,6 +400,12 @@ public class AuditLogTable implements DataTable, ReadonlyTable { return this; } + @Override + public SnapshotReader keepStats() { + wrapped.keepStats(); + return this; + } + @Override public SnapshotReader withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) { wrapped.withShard(indexOfThisSubtask, numberOfParallelSubtasks); diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java index 0fd381f918..c0295a3a5a 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/source/TableScanTest.java @@ -186,7 +186,7 @@ public class TableScanTest extends ScannerTestBase { TableScan.Plan plan4 = table.newScan().withTopN(new TopN(ref, ASCENDING, NULLS_FIRST, 5)).plan(); List<Split> splits4 = plan4.splits(); - assertThat(splits4.size()).isEqualTo(5); + assertThat(splits4.size()).isEqualTo(7); assertThat(((DataSplit) splits4.get(0)).minValue(field.id(), field, evolutions)) .isEqualTo(null); assertThat(((DataSplit) splits4.get(1)).minValue(field.id(), field, evolutions)) @@ -249,7 +249,7 @@ public class TableScanTest extends ScannerTestBase { TableScan.Plan plan8 = table.newScan().withTopN(new TopN(ref, DESCENDING, NULLS_FIRST, 5)).plan(); List<Split> splits8 = plan8.splits(); - assertThat(splits8.size()).isEqualTo(5); + assertThat(splits8.size()).isEqualTo(7); assertThat(((DataSplit) splits8.get(0)).maxValue(field.id(), field, evolutions)) .isEqualTo(null); assertThat(((DataSplit) splits8.get(1)).maxValue(field.id(), field, evolutions)) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala index 1e3bb846cf..ce1fe2e25f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala @@ -73,13 +73,7 @@ trait ColumnPruningAndPushDown extends Scan with Logging { } pushDownLimit.foreach(_readBuilder.withLimit) pushDownTopN.foreach(_readBuilder.withTopN) - - // when TopN is not empty, we need the stats to pick the TopN DataSplits - if (pushDownTopN.nonEmpty) { - _readBuilder - } else { - _readBuilder.dropStats() - } + _readBuilder.dropStats() } final def metadataColumns: Seq[PaimonMetadataColumn] = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index aa50f76043..f021feab95 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -20,9 +20,8 @@ package org.apache.paimon.spark import org.apache.paimon.predicate._ import org.apache.paimon.predicate.SortValue.{NullOrdering, SortDirection} -import org.apache.paimon.spark.aggregate.{AggregatePushDownUtils, LocalAggregator} -import org.apache.paimon.table.{AppendOnlyFileStoreTable, FileStoreTable, InnerTable} -import org.apache.paimon.table.source.DataSplit +import org.apache.paimon.spark.aggregate.AggregatePushDownUtils.tryPushdownAggregation +import org.apache.paimon.table.{FileStoreTable, InnerTable} import org.apache.spark.sql.PaimonUtils import org.apache.spark.sql.connector.expressions @@ -168,25 +167,14 @@ class PaimonScanBuilder(table: InnerTable) val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava) readBuilder.withFilter(pushedPartitionPredicate) } - val dataSplits = if (AggregatePushDownUtils.hasMinMaxAggregation(aggregation)) { - readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit]) - } else { - readBuilder.dropStats().newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit]) - } - if (AggregatePushDownUtils.canPushdownAggregation(table, aggregation, dataSplits.toSeq)) { - val aggregator = new LocalAggregator(table.asInstanceOf[FileStoreTable]) - aggregator.initialize(aggregation) - dataSplits.foreach(aggregator.update) - localScan = Some( - PaimonLocalScan( - aggregator.result(), - aggregator.resultSchema(), - table, - pushedPaimonPredicates) - ) - true - } else { - false + + tryPushdownAggregation(table.asInstanceOf[FileStoreTable], aggregation, readBuilder) match { + case Some(agg) => + localScan = Some( + PaimonLocalScan(agg.result(), agg.resultSchema(), table, pushedPaimonPredicates) + ) + true + case None => false } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala index fd1ee59b5a..e07db9515e 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala @@ -18,91 +18,98 @@ package org.apache.paimon.spark.aggregate -import org.apache.paimon.stats.StatsUtils.minmaxAvailable -import org.apache.paimon.table.Table -import org.apache.paimon.table.source.DataSplit +import org.apache.paimon.table.FileStoreTable +import org.apache.paimon.table.source.{DataSplit, ReadBuilder} +import org.apache.paimon.table.source.PushDownUtils.minmaxAvailable import org.apache.paimon.types._ -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar, Max, Min} -import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils +import org.apache.spark.sql.connector.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, CountStar, Max, Min} +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils.extractV2Column import scala.collection.JavaConverters._ import scala.collection.mutable object AggregatePushDownUtils { - def canPushdownAggregation( - table: Table, + def tryPushdownAggregation( + table: FileStoreTable, aggregation: Aggregation, - dataSplits: Seq[DataSplit]): Boolean = { + readBuilder: ReadBuilder): Option[LocalAggregator] = { + val options = table.coreOptions() + val rowType = table.rowType + val partitionKeys = table.partitionKeys() - var hasMinMax = false - val minmaxColumns = mutable.HashSet.empty[String] - var hasCount = false - - def getDataFieldForCol(colName: String): DataField = { - table.rowType.getField(colName) + aggregation.groupByExpressions.map(extractV2Column).foreach { + colName => + // don't push down if the group by columns are not the same as the partition columns (orders + // doesn't matter because reorder can be done at data source layer) + if (colName.isEmpty || !partitionKeys.contains(colName.get)) return None } - def isPartitionCol(colName: String) = { - table.partitionKeys.contains(colName) + val splits = extractMinMaxColumns(rowType, aggregation) match { + case Some(columns) => + if (columns.isEmpty) { + generateSplits(readBuilder.dropStats()) + } else { + if (options.deletionVectorsEnabled() || !table.primaryKeys().isEmpty) { + return None + } + val splits = generateSplits(readBuilder) + if (!splits.forall(minmaxAvailable(_, columns.asJava))) { + return None + } + splits + } + case None => return None } - def processMinOrMax(agg: AggregateFunc): Boolean = { - val columnName = agg match { - case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => - V2ColumnUtils.extractV2Column(max.column).get - case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => - V2ColumnUtils.extractV2Column(min.column).get - case _ => return false - } - - val dataField = getDataFieldForCol(columnName) - - if (minmaxAvailable(dataField.`type`())) { - minmaxColumns.add(columnName) - hasMinMax = true - true - } else { - false - } + if (!splits.forall(_.mergedRowCountAvailable())) { + return None } - aggregation.groupByExpressions.map(V2ColumnUtils.extractV2Column).foreach { - colName => - // don't push down if the group by columns are not the same as the partition columns (orders - // doesn't matter because reorder can be done at data source layer) - if (colName.isEmpty || !isPartitionCol(colName.get)) return false - } + val aggregator = new LocalAggregator(table) + aggregator.initialize(aggregation) + splits.foreach(aggregator.update) + Option(aggregator) + } + + private def generateSplits(readBuilder: ReadBuilder): mutable.Seq[DataSplit] = { + readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit]) + } + private def extractMinMaxColumns( + rowType: RowType, + aggregation: Aggregation): Option[Set[String]] = { + val columns = mutable.HashSet.empty[String] aggregation.aggregateExpressions.foreach { - case max: Max => - if (!processMinOrMax(max)) return false - case min: Min => - if (!processMinOrMax(min)) return false + case e if e.isInstanceOf[Min] || e.isInstanceOf[Max] => + extractMinMaxColumn(rowType, e) match { + case Some(colName) => columns.add(colName) + case None => return None + } case _: CountStar => - hasCount = true - case _ => - return false + case _ => return None } + Option(columns.toSet) + } - if (hasMinMax) { - dataSplits.forall(_.statsAvailable(minmaxColumns.toSet.asJava)) - } else if (hasCount) { - dataSplits.forall(_.mergedRowCountAvailable()) - } else { - true + private def extractMinMaxColumn(rowType: RowType, minOrMax: Expression): Option[String] = { + val column = minOrMax match { + case min: Min => min.column() + case max: Max => max.column() + } + val extractColumn = extractV2Column(column) + if (extractColumn.isEmpty) { + return None } - } - def hasMinMaxAggregation(aggregation: Aggregation): Boolean = { - var hasMinMax = false; - aggregation.aggregateExpressions().foreach { - case _: Min | _: Max => - hasMinMax = true - case _ => + val columnName = extractColumn.get + val dataType = rowType.getField(columnName).`type`() + if (minmaxAvailable(dataType)) { + Option(columnName) + } else { + None } - hasMinMax } - } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala index 26c19ecc27..a60d88aef9 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala @@ -230,6 +230,8 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH spark.sql("INSERT INTO T VALUES(1, 'x_1')") if (deletionVectorsEnabled) { runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 0) + // should not push down min max for primary key table + runAndCheckAggregate("SELECT MIN(c1) FROM T", Row(1) :: Nil, 2) } else { runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 2) }