This is an automated email from the ASF dual-hosted git repository. zhihao pushed a commit to branch perf/szh/push_limit_to_table_scan in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit a6431a26eeb54882ebed3d6b9f505a0045e8cad9 Author: Sh-Zh-7 <[email protected]> AuthorDate: Mon Mar 2 23:30:46 2026 +0800 Add push predicate through project optimization. --- .../distribute/TableDistributedPlanGenerator.java | 60 +++++ .../iterative/rule/PruneTopKRankingColumns.java | 35 +++ .../iterative/rule/PushFilterIntoRowNumber.java | 137 +++++++++++ .../PushPredicateThroughProjectIntoRowNumber.java | 196 ++++++++++++++++ .../PushPredicateThroughProjectIntoWindow.java | 210 +++++++++++++++++ .../optimizations/LogicalOptimizeFactory.java | 10 + .../planner/WindowFunctionOptimizationTest.java | 256 ++++++++++++++++++++- 7 files changed, 903 insertions(+), 1 deletion(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java index 52dad04df42..26d301605b2 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java @@ -54,6 +54,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.OrderingScheme; import org.apache.iotdb.db.queryengine.plan.relational.planner.SortOrder; import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator; +import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolsExtractor; import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode; import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTreeDeviceViewScanNode; @@ -306,7 +307,11 
@@ public class TableDistributedPlanGenerator @Override public List<PlanNode> visitProject(ProjectNode node, PlanContext context) { + Set<Symbol> savedParentRefs = context.getParentReferencedSymbols(); + context.setParentReferencedSymbols( + SymbolsExtractor.extractUnique(node.getAssignments().getExpressions())); List<PlanNode> childrenNodes = node.getChild().accept(this, context); + context.setParentReferencedSymbols(savedParentRefs); OrderingScheme childOrdering = nodeOrderingMap.get(childrenNodes.get(0).getPlanNodeId()); boolean containAllSortItem = false; if (childOrdering != null) { @@ -595,7 +600,15 @@ public class TableDistributedPlanGenerator @Override public List<PlanNode> visitFilter(FilterNode node, PlanContext context) { + Set<Symbol> savedParentRefs = context.getParentReferencedSymbols(); + if (savedParentRefs != null) { + ImmutableSet.Builder<Symbol> merged = ImmutableSet.builder(); + merged.addAll(savedParentRefs); + merged.addAll(SymbolsExtractor.extractUnique(node.getPredicate())); + context.setParentReferencedSymbols(merged.build()); + } List<PlanNode> childrenNodes = node.getChild().accept(this, context); + context.setParentReferencedSymbols(savedParentRefs); OrderingScheme childOrdering = nodeOrderingMap.get(childrenNodes.get(0).getPlanNodeId()); if (childOrdering != null) { nodeOrderingMap.put(node.getPlanNodeId(), childOrdering); @@ -1890,6 +1903,40 @@ public class TableDistributedPlanGenerator node.setChild(((SortNode) node.getChild()).getChild()); } List<PlanNode> childrenNodes = node.getChild().accept(this, context); + + Set<Symbol> parentRefs = context.getParentReferencedSymbols(); + if (parentRefs != null && !parentRefs.contains(node.getRowNumberSymbol())) { + // If maxRowCountPerPartition is set, push it as a per-device limit to each + // DeviceTableScanNode so that only the required number of rows are scanned. 
+ node + .getMaxRowCountPerPartition() + .ifPresent( + limit -> { + for (PlanNode child : childrenNodes) { + if (child instanceof DeviceTableScanNode + && !(child instanceof AggregationTableScanNode)) { + DeviceTableScanNode scanNode = (DeviceTableScanNode) child; + scanNode.setPushLimitToEachDevice(true); + if (scanNode.getPushDownLimit() <= 0) { + scanNode.setPushDownLimit(limit); + } else { + scanNode.setPushDownLimit(Math.min(limit, scanNode.getPushDownLimit())); + } + } + } + }); + // Eliminate RowNumberNode entirely - return children directly. + if (childrenNodes.size() == 1 || canSplitPushDown) { + return childrenNodes; + } else { + CollectNode collectNode = + new CollectNode( + queryId.genPlanNodeId(), childrenNodes.get(0).getOutputSymbols()); + childrenNodes.forEach(collectNode::addChild); + return Collections.singletonList(collectNode); + } + } + if (childrenNodes.size() == 1) { node.setChild(childrenNodes.get(0)); return Collections.singletonList(node); @@ -1931,6 +1978,10 @@ public class TableDistributedPlanGenerator if (canSplitPushDown && orderingScheme.isPresent()) { if (tryPushTopKRankingLimitToScan(node, childrenNodes, orderingScheme.get())) { node.setDataPreSortedAndLimited(true); + Set<Symbol> parentRefs = context.getParentReferencedSymbols(); + if (parentRefs != null && !parentRefs.contains(node.getRankingSymbol())) { + return childrenNodes; + } } } @@ -1992,6 +2043,7 @@ public class TableDistributedPlanGenerator OrderingScheme expectedOrderingScheme; TRegionReplicaSet mostUsedRegion; boolean deviceCrossRegion; + Set<Symbol> parentReferencedSymbols; public PlanContext() { this.nodeDistributionMap = new HashMap<>(); @@ -2018,5 +2070,13 @@ public class TableDistributedPlanGenerator public boolean isPushDownGrouping() { return pushDownGrouping; } + + public Set<Symbol> getParentReferencedSymbols() { + return parentReferencedSymbols; + } + + public void setParentReferencedSymbols(Set<Symbol> parentReferencedSymbols) { + this.parentReferencedSymbols 
= parentReferencedSymbols; + } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneTopKRankingColumns.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneTopKRankingColumns.java new file mode 100644 index 00000000000..5bd1ed47239 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneTopKRankingColumns.java @@ -0,0 +1,35 @@ +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import com.google.common.collect.Streams; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode; + +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.restrictChildOutputs; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.topNRanking; + +public class PruneTopKRankingColumns + extends ProjectOffPushDownRule<TopKRankingNode> +{ + public PruneTopKRankingColumns() + { + super(topNRanking()); + } + + @Override + protected Optional<PlanNode> pushDownProjectOff(Context context, TopKRankingNode topNRankingNode, Set<Symbol> referencedOutputs) + { + Set<Symbol> requiredInputs = Streams.concat( + referencedOutputs.stream() + .filter(symbol -> !symbol.equals(topNRankingNode.getRankingSymbol())), + topNRankingNode.getSpecification().getPartitionBy().stream(), + topNRankingNode.getSpecification().getOrderingScheme().get().getOrderBy().stream()) + .collect(toImmutableSet()); + + return restrictChildOutputs(context.getIdAllocator(), topNRankingNode, requiredInputs); + } +} diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushFilterIntoRowNumber.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushFilterIntoRowNumber.java new file mode 100644 index 00000000000..9f1514b199a --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushFilterIntoRowNumber.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import java.util.Optional; +import java.util.OptionalInt; + +import static java.lang.Math.toIntExact; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.rowNumber; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; + +/** + * Pushes a row-number upper-bound filter (e.g. {@code rn <= N}) into {@link RowNumberNode} by + * setting {@code maxRowCountPerPartition}. The filter is eliminated because the row-number node + * guarantees that no partition will emit more than {@code N} rows. 
+ * + * <p>Before: + * + * <pre> + * FilterNode(rn <= N) + * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=empty) + * </pre> + * + * After: + * + * <pre> + * RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=N) + * </pre> + */ +public class PushFilterIntoRowNumber implements Rule<FilterNode> { + private static final Capture<RowNumberNode> CHILD = newCapture(); + + private static final Pattern<FilterNode> PATTERN = + filter() + .with( + source() + .matching( + rowNumber() + .matching(rowNumber -> !rowNumber.getMaxRowCountPerPartition().isPresent()) + .capturedAs(CHILD))); + + @Override + public Pattern<FilterNode> getPattern() { + return PATTERN; + } + + @Override + public Result apply(FilterNode node, Captures captures, Context context) { + RowNumberNode rowNumberNode = captures.get(CHILD); + Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol(); + + OptionalInt upperBound = extractUpperBound(node.getPredicate(), rowNumberSymbol); + if (!upperBound.isPresent()) { + return Result.empty(); + } + + if (upperBound.getAsInt() <= 0) { + return Result.empty(); + } + + return Result.ofPlanNode( + new RowNumberNode( + rowNumberNode.getPlanNodeId(), + rowNumberNode.getChild(), + rowNumberNode.getPartitionBy(), + rowNumberNode.isOrderSensitive(), + rowNumberSymbol, + Optional.of(upperBound.getAsInt()))); + } + + private OptionalInt extractUpperBound(Expression predicate, Symbol rowNumberSymbol) { + if (!(predicate instanceof ComparisonExpression)) { + return OptionalInt.empty(); + } + + ComparisonExpression comparison = (ComparisonExpression) predicate; + Expression left = comparison.getLeft(); + Expression right = comparison.getRight(); + + if (!(left instanceof SymbolReference) || !(right instanceof Literal)) { + return OptionalInt.empty(); + } + + SymbolReference symbolRef = (SymbolReference) left; + if (!symbolRef.getName().equals(rowNumberSymbol.getName())) { + return OptionalInt.empty(); + } + + Literal literal = (Literal) right; + Object 
value = literal.getTsValue(); + if (!(value instanceof Number)) { + return OptionalInt.empty(); + } + + long constantValue = ((Number) value).longValue(); + + switch (comparison.getOperator()) { + case LESS_THAN: + return OptionalInt.of(toIntExact(constantValue - 1)); + case LESS_THAN_OR_EQUAL: + return OptionalInt.of(toIntExact(constantValue)); + default: + return OptionalInt.empty(); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java new file mode 100644 index 00000000000..55e7da20675 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ValuesNode; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; + +import java.util.Optional; +import java.util.OptionalInt; + +import static java.lang.Math.toIntExact; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.rowNumber; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; + +/** + * Pushes a row-number upper-bound filter through an identity projection into {@link + * RowNumberNode} by setting {@code maxRowCountPerPartition}. 
+ * + * <p>Before: + * + * <pre> + * FilterNode(rn <= N) + * └── ProjectNode (identity) + * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=empty) + * </pre> + * + * After (for LESS_THAN / LESS_THAN_OR_EQUAL — filter fully absorbed): + * + * <pre> + * ProjectNode (identity) + * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=N) + * </pre> + * + * After (for EQUAL — filter kept): + * + * <pre> + * FilterNode(rn = N) + * └── ProjectNode (identity) + * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=N) + * </pre> + */ +public class PushPredicateThroughProjectIntoRowNumber implements Rule<FilterNode> { + private static final Capture<ProjectNode> PROJECT = newCapture(); + private static final Capture<RowNumberNode> ROW_NUMBER = newCapture(); + + private static final Pattern<FilterNode> PATTERN = + filter() + .with( + source() + .matching( + project() + .matching(ProjectNode::isIdentity) + .capturedAs(PROJECT) + .with( + source() + .matching( + rowNumber() + .matching( + rn -> !rn.getMaxRowCountPerPartition().isPresent()) + .capturedAs(ROW_NUMBER))))); + + @Override + public Pattern<FilterNode> getPattern() { + return PATTERN; + } + + @Override + public Result apply(FilterNode filter, Captures captures, Context context) { + ProjectNode project = captures.get(PROJECT); + RowNumberNode rowNumberNode = captures.get(ROW_NUMBER); + + Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol(); + if (!project.getAssignments().getSymbols().contains(rowNumberSymbol)) { + return Result.empty(); + } + + OptionalInt upperBound = extractUpperBound(filter.getPredicate(), rowNumberSymbol); + if (!upperBound.isPresent()) { + return Result.empty(); + } + if (upperBound.getAsInt() <= 0) { + return Result.ofPlanNode( + new ValuesNode(filter.getPlanNodeId(), filter.getOutputSymbols(), ImmutableList.of())); + } + + project = + (ProjectNode) + project.replaceChildren( + ImmutableList.of( + new RowNumberNode( + rowNumberNode.getPlanNodeId(), + 
rowNumberNode.getChild(), + rowNumberNode.getPartitionBy(), + rowNumberNode.isOrderSensitive(), + rowNumberSymbol, + Optional.of(upperBound.getAsInt())))); + + if (needToKeepFilter(filter.getPredicate())) { + return Result.ofPlanNode( + new FilterNode(filter.getPlanNodeId(), project, filter.getPredicate())); + } + return Result.ofPlanNode(project); + } + + private OptionalInt extractUpperBound(Expression predicate, Symbol rowNumberSymbol) { + if (!(predicate instanceof ComparisonExpression)) { + return OptionalInt.empty(); + } + + ComparisonExpression comparison = (ComparisonExpression) predicate; + Expression left = comparison.getLeft(); + Expression right = comparison.getRight(); + + if (!(left instanceof SymbolReference) || !(right instanceof Literal)) { + return OptionalInt.empty(); + } + + SymbolReference symbolRef = (SymbolReference) left; + if (!symbolRef.getName().equals(rowNumberSymbol.getName())) { + return OptionalInt.empty(); + } + + Literal literal = (Literal) right; + Object value = literal.getTsValue(); + if (!(value instanceof Number)) { + return OptionalInt.empty(); + } + + long constantValue = ((Number) value).longValue(); + + switch (comparison.getOperator()) { + case LESS_THAN: + return OptionalInt.of(toIntExact(constantValue - 1)); + case LESS_THAN_OR_EQUAL: + case EQUAL: + return OptionalInt.of(toIntExact(constantValue)); + default: + return OptionalInt.empty(); + } + } + + /** + * For {@code LESS_THAN} and {@code LESS_THAN_OR_EQUAL}, the RowNumberNode with + * maxRowCountPerPartition produces exactly the rows that satisfy the predicate (row numbers + * 1..N), so the filter can be removed. For {@code EQUAL} (e.g. {@code rn = 5}), the + * RowNumberNode produces rows 1..5 but only rows where {@code rn = 5} are wanted, so the filter + * must be kept. 
+ */ + private static boolean needToKeepFilter(Expression predicate) { + if (!(predicate instanceof ComparisonExpression)) { + return true; + } + + ComparisonExpression comparison = (ComparisonExpression) predicate; + switch (comparison.getOperator()) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return false; + default: + return true; + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java new file mode 100644 index 00000000000..aa401de2a46 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ValuesNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.WindowNode; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; + +import java.util.OptionalInt; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.toTopNRankingType; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.window; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; + +/** + * Converts a filter on a ranking function (e.g. {@code rn <= N}) into a {@link TopKRankingNode} + * when there is an identity projection between the filter and window node. + * + * <p>Before: + * + * <pre> + * FilterNode(rn <= N) + * └── ProjectNode (identity) + * └── WindowNode(row_number/rank) + * </pre> + * + * After (for LESS_THAN / LESS_THAN_OR_EQUAL — filter fully absorbed): + * + * <pre> + * ProjectNode (identity) + * └── TopKRankingNode(maxRanking=N) + * </pre> + * + * After (for EQUAL — filter kept): + * + * <pre> + * FilterNode(rn = N) + * └── ProjectNode (identity) + * └── TopKRankingNode(maxRanking=N) + * </pre> + */ +public class PushPredicateThroughProjectIntoWindow implements Rule<FilterNode> { + private static final Capture<ProjectNode> PROJECT = newCapture(); + private static final Capture<WindowNode> WINDOW = newCapture(); + + private final PlannerContext plannerContext; + private final Pattern<FilterNode> pattern; + + public PushPredicateThroughProjectIntoWindow(PlannerContext plannerContext) { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.pattern = + filter() + .with( + source() + .matching( + project() + .matching(ProjectNode::isIdentity) + .capturedAs(PROJECT) + .with( + source() + .matching( + window() + .matching( + window -> + toTopNRankingType(window).isPresent()) + .capturedAs(WINDOW))))); + } + + @Override + public Pattern<FilterNode> getPattern() { + return pattern; + } + + @Override + public Result apply(FilterNode filter, Captures captures, Context context) { + ProjectNode project = captures.get(PROJECT); + WindowNode window = captures.get(WINDOW); + + Symbol rankingSymbol = getOnlyElement(window.getWindowFunctions().keySet()); + if (!project.getAssignments().getSymbols().contains(rankingSymbol)) { + return 
Result.empty(); + } + + OptionalInt upperBound = + extractUpperBoundFromComparison(filter.getPredicate(), rankingSymbol); + if (!upperBound.isPresent()) { + return Result.empty(); + } + if (upperBound.getAsInt() <= 0) { + return Result.ofPlanNode( + new ValuesNode(filter.getPlanNodeId(), filter.getOutputSymbols(), ImmutableList.of())); + } + + TopKRankingNode.RankingType rankingType = toTopNRankingType(window).get(); + project = + (ProjectNode) + project.replaceChildren( + ImmutableList.of( + new TopKRankingNode( + window.getPlanNodeId(), + window.getChild(), + window.getSpecification(), + rankingType, + rankingSymbol, + upperBound.getAsInt(), + false))); + + if (needToKeepFilter(filter.getPredicate())) { + return Result.ofPlanNode( + new FilterNode(filter.getPlanNodeId(), project, filter.getPredicate())); + } + return Result.ofPlanNode(project); + } + + private OptionalInt extractUpperBoundFromComparison( + Expression predicate, Symbol rankingSymbol) { + if (!(predicate instanceof ComparisonExpression)) { + return OptionalInt.empty(); + } + + ComparisonExpression comparison = (ComparisonExpression) predicate; + Expression left = comparison.getLeft(); + Expression right = comparison.getRight(); + + if (!(left instanceof SymbolReference) || !(right instanceof Literal)) { + return OptionalInt.empty(); + } + + SymbolReference symbolRef = (SymbolReference) left; + if (!symbolRef.getName().equals(rankingSymbol.getName())) { + return OptionalInt.empty(); + } + + Literal literal = (Literal) right; + Object value = literal.getTsValue(); + if (!(value instanceof Number)) { + return OptionalInt.empty(); + } + + long constantValue = ((Number) value).longValue(); + + switch (comparison.getOperator()) { + case LESS_THAN: + return OptionalInt.of(toIntExact(constantValue - 1)); + case LESS_THAN_OR_EQUAL: + case EQUAL: + return OptionalInt.of(toIntExact(constantValue)); + default: + return OptionalInt.empty(); + } + } + + /** + * For {@code LESS_THAN} and {@code 
LESS_THAN_OR_EQUAL}, the TopKRankingNode produces exactly + * the rows that satisfy the predicate (ranking values 1..N), so the filter can be removed. For + * {@code EQUAL} (e.g. {@code rn = 5}), TopKRankingNode produces rows 1..5 but only rows where + * {@code rn = 5} are wanted, so the filter must be kept. + */ + private static boolean needToKeepFilter(Expression predicate) { + if (!(predicate instanceof ComparisonExpression)) { + return true; + } + + ComparisonExpression comparison = (ComparisonExpression) predicate; + switch (comparison.getOperator()) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return false; + default: + return true; + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java index 864b9a987ab..b3f398089a1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java @@ -73,11 +73,15 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Pr import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTableScanColumns; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTopKColumns; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTopKRankingColumns; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneUnionColumns; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneUnionSourceColumns; import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneWindowColumns; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushDownFilterIntoWindow; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushDownLimitIntoWindow; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushFilterIntoRowNumber; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushPredicateThroughProjectIntoRowNumber; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushPredicateThroughProjectIntoWindow; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushLimitThroughOffset; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushLimitThroughProject; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushLimitThroughUnion; @@ -153,6 +157,7 @@ public class LogicalOptimizeFactory { new PruneTableFunctionProcessorSourceColumns(), new PruneTableScanColumns(plannerContext.getMetadata()), new PruneTopKColumns(), + new PruneTopKRankingColumns(), new PruneWindowColumns(), new PruneJoinColumns(), new PruneJoinChildrenColumns(), @@ -375,9 +380,14 @@ public class LogicalOptimizeFactory { ImmutableSet.<Rule<?>>builder() .add(new PushDownLimitIntoWindow()) .add(new PushDownFilterIntoWindow(plannerContext)) + .add(new PushPredicateThroughProjectIntoWindow(plannerContext)) .add(new ReplaceWindowWithRowNumber(metadata)) + .add(new PushFilterIntoRowNumber()) + .add(new PushPredicateThroughProjectIntoRowNumber()) .addAll(GatherAndMergeWindows.rules()) .build()), + inlineProjectionLimitFiltersOptimizer, + columnPruningOptimizer, new TransformAggregationToStreamable(), new PushAggregationIntoTableScan(), new TransformSortToStreamSort(), diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java index a2243557d2f..aed16f77c6e 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java @@ -23,19 +23,24 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; import org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern; import org.apache.iotdb.db.queryengine.plan.relational.planner.node.DeviceTableScanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode; import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.junit.Test; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.mergeSort; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.filter; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.group; import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.limit; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.output; @@ -160,7 +165,9 @@ public class WindowFunctionOptimizationTest { planTester.getFragmentPlan(0), output((collect(exchange(), exchange(), exchange())))); assertPlan(planTester.getFragmentPlan(1), topKRanking(tableScan)); assertPlan(planTester.getFragmentPlan(2), topKRanking(tableScan)); - assertPlan(planTester.getFragmentPlan(3), topKRanking(sort(tableScan))); + assertPlan(planTester.getFragmentPlan(3), topKRanking(mergeSort(exchange(), exchange()))); + assertPlan(planTester.getFragmentPlan(4), sort(tableScan)); + assertPlan(planTester.getFragmentPlan(5), sort(tableScan)); } @Test @@ -343,4 +350,251 @@ public class WindowFunctionOptimizationTest { assertEquals("pushDownLimit should be 2", 2, dts.getPushDownLimit()); } } + /** + * Plans a {@code row_number()} query partitioned by tag1, tag2, tag3 and ordered by time, with + * {@code rn <= 2} in the outer WHERE and {@code rn} absent from the outer SELECT list. Asserts + * that the logical plan matches output -> project -> topKRanking -> group -> tableScan, that + * fragment 0 matches output -> collect over three exchanges, and that fragments 1 and 2 match + * project -> tableScan with the scan reporting pushLimitToEachDevice = true and + * pushDownLimit = 2. + */ + @Test + public void testTopKRankingEliminatedWhenRankSymbolNotOutput() { + PlanTester planTester = new PlanTester(); + + String sql = + "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY time) as rn FROM table1) WHERE rn <= 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // Logical plan: OutputNode -> ProjectNode -> TopKRankingNode -> GroupNode -> TableScanNode + assertPlan(logicalQueryPlan, output(project(topKRanking(group(tableScan))))); + + // Distributed plan: TopKRankingNode eliminated since rn is not in the output. + // Limit is pushed to DeviceTableScanNode.
+ // Fragment 0: OutputNode -> CollectNode -> ExchangeNodes + assertPlan( + planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange()))); + + // Worker fragments: ProjectNode -> DeviceTableScanNode (no TopKRankingNode) + for (int i = 1; i <= 2; i++) { + PlanNode fragmentRoot = planTester.getFragmentPlan(i); + assertFalse( + "Fragment " + i + " root should NOT be TopKRankingNode", + fragmentRoot instanceof TopKRankingNode); + assertPlan(planTester.getFragmentPlan(i), project(tableScan)); + + assertTrue( + "Fragment " + i + " root should be ProjectNode", + fragmentRoot instanceof ProjectNode); + PlanNode scanChild = fragmentRoot.getChildren().get(0); + assertTrue( + "Child should be DeviceTableScanNode", scanChild instanceof DeviceTableScanNode); + DeviceTableScanNode dts = (DeviceTableScanNode) scanChild; + assertTrue("pushLimitToEachDevice should be true", dts.isPushLimitToEachDevice()); + assertEquals("pushDownLimit should be 2", 2, dts.getPushDownLimit()); + } + } + /** + * Plans the time-ordered {@code row_number()} query partitioned by tag1, tag2, tag3 with + * {@code rn <= 2}, but with {@code SELECT *} so that {@code rn} is selected. Asserts that the + * logical plan matches output -> topKRanking -> group -> tableScan and that the roots of + * fragments 1 and 2 are TopKRankingNode instances. + */ + @Test + public void testTopKRankingKeptWhenRankSymbolIsOutput() { + PlanTester planTester = new PlanTester(); + + // Same query but SELECT * includes rn - TopKRankingNode should NOT be eliminated + String sql = + "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY time) as rn FROM table1) WHERE rn <= 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + assertPlan(logicalQueryPlan, output(topKRanking(group(tableScan)))); + + // Worker fragments should still have TopKRankingNode + for (int i = 1; i <= 2; i++) { + PlanNode fragmentRoot = planTester.getFragmentPlan(i); + assertTrue( + "Fragment " + i + " root should be TopKRankingNode", + fragmentRoot instanceof TopKRankingNode); + } + } + /** + * Plans a {@code row_number()} query partitioned by tag1, tag2, tag3 with no ORDER BY and no + * filter, where the outer SELECT list omits {@code rn}. Asserts that the logical plan matches + * output -> project -> group -> tableScan. + */ + @Test + public void testRowNumberEliminatedWhenRowNumberNotOutput() { + PlanTester planTester = new PlanTester(); + + // RowNumber with all IDs as partition - row number not in
output + String sql = + "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1)"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // Logical plan: row_number is pruned at the window level by PruneWindowColumns + // since rn is not referenced anywhere. The plan should not contain RowNumberNode. + assertPlan(logicalQueryPlan, output(project(group(tableScan)))); + } + /** + * Plans the unordered {@code row_number()} query partitioned by tag1, tag2, tag3 with + * {@code rn} in the outer SELECT list and no filter. Asserts that the logical plan matches + * output -> project -> rowNumber -> group -> tableScan, that fragment 0 matches + * output -> collect over three exchanges, and that fragments 1 and 2 match + * project -> rowNumber -> tableScan. + */ + @Test + public void testRowNumberPushDownWhenRowNumberIsOutput() { + PlanTester planTester = new PlanTester(); + + // RowNumber with all IDs as partition and rn IS referenced in output + String sql = + "SELECT tag1, s1, rn FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1)"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // Logical plan: OutputNode -> ProjectNode -> RowNumberNode -> GroupNode -> TableScanNode + // (project is inlined since it selects a subset including rn) + assertPlan(logicalQueryPlan, output(project(rowNumber(group(tableScan))))); + + // RowNumberNode is pushed down to each partition (not eliminated, since rn IS in the output) + assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange()))); + assertPlan(planTester.getFragmentPlan(1), project(rowNumber(tableScan))); + assertPlan(planTester.getFragmentPlan(2), project(rowNumber(tableScan))); + } + /** + * Plans the unordered {@code row_number()} query partitioned by tag1, tag2, tag3 with + * {@code rn <= 2} and {@code rn} absent from the outer SELECT list. Asserts that the logical + * plan matches output -> project -> rowNumber -> group -> tableScan, that fragment 0 matches + * output -> collect over three exchanges, and that fragments 1 and 2 match + * project -> tableScan with the scan reporting pushLimitToEachDevice = true and + * pushDownLimit = 2. + */ + @Test + public void testRowNumberWithMaxCountEliminatedWhenRowNumberNotOutput() { + PlanTester planTester = new PlanTester(); + + // rn <= 2 pushes the limit into RowNumberNode (maxRowCountPerPartition=2), and since rn is + // not in the outer SELECT, the RowNumberNode is eliminated in the distributed plan.
+ String sql = + "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1) WHERE rn <= 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // Logical plan: PushFilterIntoRowNumber absorbs rn<=2, leaving RowNumberNode with maxRowCount=2 + // No filter remains above RowNumberNode. + assertPlan(logicalQueryPlan, output(project(rowNumber(group(tableScan))))); + + // Distributed plan: RowNumberNode eliminated since rn is not in the output. + // Limit (maxRowCountPerPartition=2) is pushed to each DeviceTableScanNode. + assertPlan( + planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange()))); + + // Worker fragments: ProjectNode -> DeviceTableScanNode (no RowNumberNode) + for (int i = 1; i <= 2; i++) { + PlanNode fragmentRoot = planTester.getFragmentPlan(i); + assertFalse( + "Fragment " + i + " root should NOT be RowNumberNode", + fragmentRoot instanceof RowNumberNode); + assertPlan(planTester.getFragmentPlan(i), project(tableScan)); + + assertTrue( + "Fragment " + i + " root should be ProjectNode", + fragmentRoot instanceof ProjectNode); + PlanNode scanChild = fragmentRoot.getChildren().get(0); + assertTrue( + "Child should be DeviceTableScanNode", scanChild instanceof DeviceTableScanNode); + DeviceTableScanNode dts = (DeviceTableScanNode) scanChild; + assertTrue("pushLimitToEachDevice should be true", dts.isPushLimitToEachDevice()); + assertEquals("pushDownLimit should be 2", 2, dts.getPushDownLimit()); + } + } + /** + * Plans the unordered {@code row_number()} query partitioned by tag1, tag2, tag3 with + * {@code rn <= 2} and {@code SELECT *}, so that {@code rn} is selected. Asserts that the + * logical plan matches output -> rowNumber -> group -> tableScan and that the roots of + * fragments 1 and 2 are RowNumberNode instances. + */ + @Test + public void testRowNumberWithMaxCountKeptWhenRowNumberIsOutput() { + PlanTester planTester = new PlanTester(); + + // Same query but SELECT * includes rn - RowNumberNode should NOT be eliminated + String sql = + "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1) WHERE rn <= 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern
tableScan = tableScan("testdb.table1"); + + // Logical plan: RowNumberNode with maxRowCount=2 (filter absorbed), no outer project removing rn + assertPlan(logicalQueryPlan, output(rowNumber(group(tableScan)))); + + // Worker fragments should still have RowNumberNode since rn IS in the output + for (int i = 1; i <= 2; i++) { + PlanNode fragmentRoot = planTester.getFragmentPlan(i); + assertTrue( + "Fragment " + i + " root should be RowNumberNode", + fragmentRoot instanceof RowNumberNode); + } + } + /** + * Plans {@code row_number()} partitioned by tag1 only and ordered by s1, filtered with + * {@code rn = 2}. Asserts that the logical plan matches + * output -> filter -> topKRanking -> sort -> tableScan. + */ + @Test + public void testTopKRankingWithEqualPredicate() { + PlanTester planTester = new PlanTester(); + + String sql = + "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1 ORDER BY s1) as rn FROM table1) WHERE rn = 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // TopKRanking created with maxRanking=2, but filter(rn = 2) is kept because + // ranking values 1..2 do not all satisfy rn = 2 + /* + * └──OutputNode + * └──FilterNode(rn = 2) + * └──TopKRankingNode + * └──SortNode + * └──TableScanNode + */ + assertPlan(logicalQueryPlan, output(filter(topKRanking(sort(tableScan))))); + } + /** + * Plans {@code row_number()} partitioned by tag1, tag2, tag3 and ordered by s1, filtered with + * {@code rn = 2}. Asserts that the logical plan matches + * output -> filter -> topKRanking -> group -> tableScan, that fragment 0 matches + * output -> collect over three exchanges, and that fragments 1 and 2 match + * filter -> topKRanking -> tableScan. + */ + @Test + public void testTopKRankingWithEqualPredicateAllPartitions() { + PlanTester planTester = new PlanTester(); + + String sql = + "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY s1) as rn FROM table1) WHERE rn = 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // TopKRanking created with maxRanking=2, filter(rn = 2) is kept + /* + * └──OutputNode + * └──FilterNode(rn = 2) + * └──TopKRankingNode + * └──GroupNode + * └──TableScanNode + */ + assertPlan(logicalQueryPlan, output(filter(topKRanking(group(tableScan))))); + + // Distributed plan: TopKRanking and filter pushed down + assertPlan( + planTester.getFragmentPlan(0), + output(collect(exchange(), exchange(), exchange())));
+ assertPlan(planTester.getFragmentPlan(1), filter(topKRanking(tableScan))); + assertPlan(planTester.getFragmentPlan(2), filter(topKRanking(tableScan))); + } + /** + * Plans {@code row_number()} partitioned by tag1 only and ordered by s1, filtered with + * {@code rn < 3}. Asserts that the logical plan matches + * output -> topKRanking -> sort -> tableScan, i.e. the expected pattern has no filter node. + */ + @Test + public void testTopKRankingWithLessThanPredicate() { + PlanTester planTester = new PlanTester(); + + // rn < 3 is equivalent to rn <= 2, so the filter should be fully absorbed + String sql = + "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1 ORDER BY s1) as rn FROM table1) WHERE rn < 3"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // Filter absorbed into TopKRankingNode (maxRanking=2) + /* + * └──OutputNode + * └──TopKRankingNode + * └──SortNode + * └──TableScanNode + */ + assertPlan(logicalQueryPlan, output(topKRanking(sort(tableScan)))); + } + /** + * Plans {@code row_number()} partitioned by tag1, tag2, tag3 and ordered by s1, filtered with + * {@code rn = 2}, with {@code rn} absent from the outer SELECT list. Asserts that the logical + * plan matches output -> project -> filter -> project -> topKRanking -> group -> tableScan. + */ + @Test + public void testTopKRankingWithEqualPredicateColumnPruned() { + PlanTester planTester = new PlanTester(); + + // rn = 2 with rn not in output: filter kept, rn pruned by project + String sql = + "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY s1) as rn FROM table1) WHERE rn = 2"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + PlanMatchPattern tableScan = tableScan("testdb.table1"); + + // Filter(rn = 2) kept, project prunes rn from output + /* + * └──OutputNode + * └──ProjectNode + * └──FilterNode(rn = 2) + * └──ProjectNode + * └──TopKRankingNode + * └──GroupNode + * └──TableScanNode + */ + assertPlan(logicalQueryPlan, output(project(filter(project(topKRanking(group(tableScan))))))); + } }
