This is an automated email from the ASF dual-hosted git repository. lancelly pushed a commit to branch support_exists_and_correlate in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit dc4cd7bbb21700d28fb1a2da448b9263276f1d04 Author: lancelly <[email protected]> AuthorDate: Fri Jan 10 22:24:52 2025 +0800 add correlated agg related rules --- .../iterative/rule/AggregationDecorrelation.java | 83 ++++++ ...orrelatedDistinctAggregationWithProjection.java | 186 ++++++++++++ ...elatedDistinctAggregationWithoutProjection.java | 167 +++++++++++ ...mCorrelatedGlobalAggregationWithProjection.java | 318 +++++++++++++++++++++ ...rrelatedGlobalAggregationWithoutProjection.java | 304 ++++++++++++++++++++ ...CorrelatedGroupedAggregationWithProjection.java | 252 ++++++++++++++++ ...relatedGroupedAggregationWithoutProjection.java | 233 +++++++++++++++ .../rule/TransformCorrelatedJoinToJoin.java | 99 +++++++ .../plan/relational/planner/node/Patterns.java | 19 +- .../optimizations/LogicalOptimizeFactory.java | 30 +- 10 files changed, 1678 insertions(+), 13 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/AggregationDecorrelation.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/AggregationDecorrelation.java new file mode 100644 index 00000000000..e61bf6f7be2 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/AggregationDecorrelation.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+/** Static helpers for rules that rewrite aggregations found inside correlated subqueries. */
+class AggregationDecorrelation {
+  private AggregationDecorrelation() {}
+
+  /**
+   * Tells whether {@code node} is a "distinct operator": an {@link AggregationNode} that has no
+   * aggregation assignments, exactly one grouping set, and a non-empty grouping set.
+   *
+   * @param node plan node to inspect
+   * @return {@code true} only when all of the conditions above hold
+   */
+  public static boolean isDistinctOperator(PlanNode node) {
+    if (!(node instanceof AggregationNode)) {
+      return false;
+    }
+    AggregationNode candidate = (AggregationNode) node;
+    return candidate.getAggregations().isEmpty()
+        && candidate.getGroupingSetCount() == 1
+        && candidate.hasNonEmptyGroupingSet();
+  }
+
+  /**
+   * Copies every aggregation, replacing its mask with the symbol registered in {@code masks} under
+   * the aggregation's output symbol. Function, arguments, distinct flag, filter and ordering scheme
+   * are carried over unchanged.
+   *
+   * @param aggregations aggregations keyed by their output symbol
+   * @param masks mask symbol per output symbol; a key without an entry makes {@code Optional.of}
+   *     throw {@link NullPointerException}
+   * @return an immutable map with the same keys and the re-masked aggregations
+   */
+  public static Map<Symbol, AggregationNode.Aggregation> rewriteWithMasks(
+      Map<Symbol, AggregationNode.Aggregation> aggregations, Map<Symbol, Symbol> masks) {
+    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> masked = ImmutableMap.builder();
+    aggregations.forEach(
+        (output, original) ->
+            masked.put(
+                output,
+                new AggregationNode.Aggregation(
+                    original.getResolvedFunction(),
+                    original.getArguments(),
+                    original.isDistinct(),
+                    original.getFilter(),
+                    original.getOrderingScheme(),
+                    Optional.of(masks.get(output)))));
+    return masked.buildOrThrow();
+  }
+
+  /**
+   * Builds a new distinct operator on top of {@code source}, reusing the plan node id and the step
+   * of an existing distinct operator.
+   *
+   * @param distinct existing distinct operator; rejected with {@link IllegalArgumentException}
+   *     when it does not satisfy {@link #isDistinctOperator(PlanNode)}
+   * @param source child of the new aggregation node
+   * @param groupingKeys symbols forming the single grouping set of the new node
+   * @return an aggregation node without aggregation assignments, grouped by {@code groupingKeys}
+   * @see #isDistinctOperator(PlanNode)
+   */
+  public static AggregationNode restoreDistinctAggregation(
+      AggregationNode distinct, PlanNode source, List<Symbol> groupingKeys) {
+    checkArgument(isDistinctOperator(distinct));
+    AggregationNode restored =
+        new AggregationNode(
+            distinct.getPlanNodeId(),
+            source,
+            ImmutableMap.of(),
+            AggregationNode.singleGroupingSet(groupingKeys),
+            ImmutableList.of(),
+            distinct.getStep(),
+            Optional.empty(),
+            Optional.empty());
+    return restored;
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java
new file mode 100644
index 00000000000..9e9423fea4e
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; +import org.apache.tsfile.read.common.type.LongType; + +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleGroupingSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT; +import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.subquery;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.type;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
+import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty;
+
+/**
+ * Decorrelates the subquery of a LEFT correlated join when that subquery is a projection over a
+ * distinct operator (a grouped aggregation with no aggregation assignments).
+ *
+ * <p>Transforms:
+ *
+ * <pre>
+ * - CorrelatedJoin LEFT (correlation: [c], filter: true, output: a, x)
+ *      - Input (a, c)
+ *      - Project (x <- b + 100)
+ *           - Aggregation "distinct operator" group by [b]
+ *                - Source (b) with correlated filter (b > c)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - Project (a <- a, x <- b + 100)
+ *      - Aggregation "distinct operator" group by [a, c, unique, b]
+ *           - LEFT join (filter: b > c)
+ *                - UniqueId (unique)
+ *                     - Input (a, c)
+ *                - Source (b) decorrelated
+ * </pre>
+ */
+public class TransformCorrelatedDistinctAggregationWithProjection
+    implements Rule<CorrelatedJoinNode> {
+  private static final Capture<ProjectNode> PROJECTION = newCapture();
+  private static final Capture<AggregationNode> AGGREGATION = newCapture();
+
+  // LEFT correlated join, non-empty correlation, TRUE filter, subquery = Project over distinct
+  private static final Pattern<CorrelatedJoinNode> PATTERN =
+      correlatedJoin()
+          .with(type().equalTo(LEFT))
+          .with(nonEmpty(Patterns.CorrelatedJoin.correlation()))
+          .with(filter().equalTo(TRUE_LITERAL))
+          .with(
+              subquery()
+                  .matching(
+                      project()
+                          .capturedAs(PROJECTION)
+                          .with(
+                              source()
+                                  .matching(
+                                      aggregation()
+                                          .matching(AggregationDecorrelation::isDistinctOperator)
+                                          .capturedAs(AGGREGATION)))));
+
+  private final PlannerContext plannerContext;
+
+  /**
+   * @param plannerContext planner context handed to the {@link PlanNodeDecorrelator}; must not be
+   *     null
+   */
+  public TransformCorrelatedDistinctAggregationWithProjection(PlannerContext plannerContext) {
+    this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+  }
+
+  @Override
+  public Pattern<CorrelatedJoinNode> getPattern() {
+    return PATTERN;
+  }
+
+  /**
+   * Rewrites the matched correlated join into Project - distinct Aggregation - LEFT join.
+   *
+   * @return the rewritten plan, or an empty result when the plan below the distinct operator
+   *     cannot be decorrelated
+   */
+  @Override
+  public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
+    AggregationNode distinct = captures.get(AGGREGATION);
+
+    // pull the correlated filters out of the plan below the distinct operator
+    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated =
+        new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup())
+            .decorrelateFilters(distinct.getChild(), correlatedJoinNode.getCorrelation());
+    if (!decorrelated.isPresent()) {
+      return Result.empty();
+    }
+    PlanNodeDecorrelator.DecorrelatedNode decorrelatedNode = decorrelated.get();
+    PlanNode decorrelatedSubquery = decorrelatedNode.getNode();
+
+    // tag every input row with a unique id so the original rows can still be told apart after
+    // the join
+    PlanNode uniqueInput =
+        new AssignUniqueId(
+            context.getIdAllocator().genPlanNodeId(),
+            correlatedJoinNode.getInput(),
+            context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));
+
+    JoinNode leftJoin =
+        new JoinNode(
+            context.getIdAllocator().genPlanNodeId(),
+            LEFT,
+            uniqueInput,
+            decorrelatedSubquery,
+            ImmutableList.of(),
+            uniqueInput.getOutputSymbols(),
+            decorrelatedSubquery.getOutputSymbols(),
+            decorrelatedNode.getCorrelatedPredicates(),
+            Optional.empty());
+
+    // re-create the distinct operator above the join, grouped additionally by the join's left
+    // side
+    List<Symbol> groupingKeys =
+        ImmutableList.<Symbol>builder()
+            .addAll(leftJoin.getLeftOutputSymbols())
+            .addAll(distinct.getGroupingKeys())
+            .build();
+    AggregationNode restoredDistinct =
+        new AggregationNode(
+            distinct.getPlanNodeId(),
+            leftJoin,
+            distinct.getAggregations(),
+            singleGroupingSet(groupingKeys),
+            ImmutableList.of(),
+            distinct.getStep(),
+            Optional.empty(),
+            Optional.empty());
+
+    // keep only the aggregation outputs the correlated join exposed, then add the subquery's
+    // projection
+    Set<Symbol> requiredOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
+    List<Symbol> retainedAggregationOutputs =
+        restoredDistinct.getOutputSymbols().stream()
+            .filter(requiredOutputs::contains)
+            .collect(toImmutableList());
+
+    Assignments projections =
+        Assignments.builder()
+            .putIdentities(retainedAggregationOutputs)
+            .putAll(captures.get(PROJECTION).getAssignments())
+            .build();
+
+    return Result.ofPlanNode(
+        new ProjectNode(context.getIdAllocator().genPlanNodeId(), restoredDistinct, projections));
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java
new file mode 100644
index 00000000000..9625736a8f3
---
/dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.apache.tsfile.read.common.type.BooleanType;
+import org.apache.tsfile.read.common.type.LongType;
+
+import java.util.Optional;
+
+import static java.util.Objects.requireNonNull;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.restrictOutputs;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleGroupingSet;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.subquery;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.type;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin;
+import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty;
+
+/**
+ * This rule decorrelates a correlated subquery of LEFT correlated join with distinct operator
+ * (grouped aggregation with no aggregation assignments). It is similar to the
+ * TransformCorrelatedDistinctAggregationWithProjection rule, but does not support projection over
+ * aggregation in the subquery.
+ *
+ * <p>Transforms:
+ *
+ * <pre>
+ * - CorrelatedJoin LEFT (correlation: [c], filter: true, output: a, b)
+ *      - Input (a, c)
+ *      - Aggregation "distinct operator" group by [b]
+ *           - Source (b) with correlated filter (b > c)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - Project (a <- a, b <- b)
+ *      - Aggregation "distinct operator" group by [a, c, unique, b]
+ *           - LEFT join (filter: b > c)
+ *                - UniqueId (unique)
+ *                     - Input (a, c)
+ *                - Source (b) decorrelated
+ * </pre>
+ */
+public class TransformCorrelatedDistinctAggregationWithoutProjection
+    implements Rule<CorrelatedJoinNode> {
+  private static final Capture<AggregationNode> AGGREGATION = newCapture();
+
+  // LEFT correlated join, non-empty correlation, TRUE filter, subquery = distinct operator
+  private static final Pattern<CorrelatedJoinNode> PATTERN =
+      correlatedJoin()
+          .with(type().equalTo(LEFT))
+          .with(nonEmpty(Patterns.CorrelatedJoin.correlation()))
+          .with(filter().equalTo(TRUE_LITERAL))
+          .with(
+              subquery()
+                  .matching(
+                      aggregation()
+                          .matching(AggregationDecorrelation::isDistinctOperator)
+                          .capturedAs(AGGREGATION)));
+
+  private final PlannerContext plannerContext;
+
+  /**
+   * @param plannerContext planner context handed to the {@link PlanNodeDecorrelator}; must not be
+   *     null
+   */
+  public TransformCorrelatedDistinctAggregationWithoutProjection(PlannerContext plannerContext) {
+    this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+  }
+
+  @Override
+  public Pattern<CorrelatedJoinNode> getPattern() {
+    return PATTERN;
+  }
+
+  /**
+   * Rewrites the matched correlated join into a distinct aggregation over a LEFT join, restricted
+   * to the output symbols of the original correlated join.
+   *
+   * @return the rewritten plan, or an empty result when the plan below the distinct operator
+   *     cannot be decorrelated
+   */
+  @Override
+  public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
+    // decorrelate nested plan
+    PlanNodeDecorrelator decorrelator =
+        new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
+    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource =
+        decorrelator.decorrelateFilters(
+            captures.get(AGGREGATION).getChild(), correlatedJoinNode.getCorrelation());
+    if (!decorrelatedSource.isPresent()) {
+      return Result.empty();
+    }
+
+    PlanNode source = decorrelatedSource.get().getNode();
+
+    // assign unique id on correlated join's input. It will be used to distinguish between original
+    // input rows after join. The symbol is typed as a long, the same type the sibling
+    // decorrelation rules use for their "unique" symbol.
+    PlanNode inputWithUniqueId =
+        new AssignUniqueId(
+            context.getIdAllocator().genPlanNodeId(),
+            correlatedJoinNode.getInput(),
+            context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));
+
+    JoinNode join =
+        new JoinNode(
+            context.getIdAllocator().genPlanNodeId(),
+            JoinNode.JoinType.LEFT,
+            inputWithUniqueId,
+            source,
+            ImmutableList.of(),
+            inputWithUniqueId.getOutputSymbols(),
+            source.getOutputSymbols(),
+            decorrelatedSource.get().getCorrelatedPredicates(),
+            Optional.empty());
+
+    // restore aggregation: same distinct operator, now above the join and additionally grouped by
+    // every symbol of the join's left side (input symbols plus the unique id)
+    AggregationNode aggregation = captures.get(AGGREGATION);
+    aggregation =
+        AggregationNode.builderFrom(aggregation)
+            .setSource(join)
+            .setGroupingSets(
+                singleGroupingSet(
+                    ImmutableList.<Symbol>builder()
+                        .addAll(join.getLeftOutputSymbols())
+                        .addAll(aggregation.getGroupingKeys())
+                        .build()))
+            .setPreGroupedSymbols(ImmutableList.of())
+            .setHashSymbol(Optional.empty())
+            .setGroupIdSymbol(Optional.empty())
+            .build();
+
+    // restrict outputs; when restrictOutputs yields no projection the aggregation is returned
+    // as is
+    Optional<PlanNode> project =
+        restrictOutputs(
+            context.getIdAllocator(),
+            aggregation,
+            ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
+
+    return Result.ofPlanNode(project.orElse(aggregation));
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java
new file mode 100644
index 00000000000..021de9fe9bd
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.tsfile.read.common.type.BooleanType; +import org.apache.tsfile.read.common.type.LongType; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.and; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleGroupingSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.Aggregation.groupingColumns; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.subquery; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.empty; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty; + +/** + * This rule decorrelates a correlated subquery of LEFT or INNER correlated join with: - single + * 
global aggregation, or - global aggregation over distinct operator (grouped aggregation with no
+ * aggregation assignments), in case when the distinct operator cannot be de-correlated by
+ * PlanNodeDecorrelator
+ *
+ * <p>In the case of single aggregation, it transforms:
+ *
+ * <pre>
+ * - CorrelatedJoin LEFT or INNER (correlation: [c], filter: true, output: a, x, y)
+ *      - Input (a, c)
+ *      - Project (x <- f(count), y <- f'(agg))
+ *           - Aggregation global
+ *             count <- count(*)
+ *             agg <- agg(b)
+ *                - Source (b) with correlated filter (b > c)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - Project (a <- a, x <- f(count), y <- f'(agg))
+ *      - Aggregation (group by [a, c, unique])
+ *        count <- count(*) mask(non_null)
+ *        agg <- agg(b) mask(non_null)
+ *           - LEFT join (filter: b > c)
+ *                - UniqueId (unique)
+ *                     - Input (a, c)
+ *                - Project (non_null <- TRUE)
+ *                     - Source (b) decorrelated
+ * </pre>
+ *
+ * <p>In the case of global aggregation over distinct operator, it transforms:
+ *
+ * <pre>
+ * - CorrelatedJoin LEFT or INNER (correlation: [c], filter: true, output: a, x, y)
+ *      - Input (a, c)
+ *      - Project (x <- f(count), y <- f'(agg))
+ *           - Aggregation global
+ *             count <- count(*)
+ *             agg <- agg(b)
+ *                - Aggregation "distinct operator" group by [b]
+ *                     - Source (b) with correlated filter (b > c)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - Project (a <- a, x <- f(count), y <- f'(agg))
+ *      - Aggregation (group by [a, c, unique])
+ *        count <- count(*) mask(non_null)
+ *        agg <- agg(b) mask(non_null)
+ *           - Aggregation "distinct operator" group by [a, c, unique, non_null, b]
+ *                - LEFT join (filter: b > c)
+ *                     - UniqueId (unique)
+ *                          - Input (a, c)
+ *                     - Project (non_null <- TRUE)
+ *                          - Source (b) decorrelated
+ * </pre>
+ */
+public class TransformCorrelatedGlobalAggregationWithProjection
+    implements Rule<CorrelatedJoinNode> {
+  // Captures filled in by PATTERN: the subquery's top projection, the global aggregation below it,
+  // and that aggregation's source.
+  private static final Capture<ProjectNode> PROJECTION = newCapture();
+  private static final Capture<AggregationNode> AGGREGATION = newCapture();
+  private static final Capture<PlanNode> SOURCE = newCapture();
+
+  // Correlated join with non-empty correlation and TRUE filter whose subquery is a Project over an
+  // aggregation with no grouping columns. The join type is not constrained here; apply() checks it.
+  private static final Pattern<CorrelatedJoinNode> PATTERN =
+      correlatedJoin()
+          .with(nonEmpty(Patterns.CorrelatedJoin.correlation()))
+          .with(filter().equalTo(TRUE_LITERAL))
+          .with(
+              subquery()
+                  .matching(
+                      project()
+                          .capturedAs(PROJECTION)
+                          .with(
+                              source()
+                                  .matching(
+                                      aggregation()
+                                          .with(empty(groupingColumns()))
+                                          .with(source().capturedAs(SOURCE))
+                                          .capturedAs(AGGREGATION)))));
+
+  private final PlannerContext plannerContext;
+
+  /**
+   * @param plannerContext planner context handed to the {@link PlanNodeDecorrelator}; must not be
+   *     null
+   */
+  public TransformCorrelatedGlobalAggregationWithProjection(PlannerContext plannerContext) {
+    this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+  }
+
+  @Override
+  public Pattern<CorrelatedJoinNode> getPattern() {
+    return PATTERN;
+  }
+
+  /**
+   * Rewrites the matched correlated join into a projection over a grouped, masked aggregation over
+   * a LEFT join (optionally with a restored distinct operator in between).
+   *
+   * @return the rewritten plan, or an empty result when the nested plan cannot be decorrelated,
+   *     even after stepping below a distinct operator
+   * @throws IllegalArgumentException if the correlated join type is neither INNER nor LEFT
+   */
+  @Override
+  public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
+    checkArgument(
+        correlatedJoinNode.getJoinType() == INNER || correlatedJoinNode.getJoinType() == LEFT,
+        "unexpected correlated join type: %s",
+        correlatedJoinNode.getJoinType());
+
+    // if there is another aggregation below the AggregationNode, handle both
+    PlanNode source = captures.get(SOURCE);
+
+    // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can
+    // extract and special-handle the distinct operator
+    AggregationNode distinct = null;
+
+    // decorrelate nested plan
+    PlanNodeDecorrelator decorrelator =
+        new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
+    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource =
+        decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+    if (!decorrelatedSource.isPresent()) {
+      // we failed to decorrelate the nested plan, so check if we can extract a distinct operator
+      // from the nested plan
+      if (isDistinctOperator(source)) {
+        // retry on the distinct operator's child; the operator itself is restored above the join
+        distinct = (AggregationNode) source;
+        source = distinct.getChild();
+        decorrelatedSource =
+            decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+      }
+      if (!decorrelatedSource.isPresent()) {
+        return Result.empty();
+      }
+    }
+
+    source = decorrelatedSource.get().getNode();
+
+    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive
+    // aggregations after LEFT join
+    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BooleanType.getInstance());
+    source =
+        new ProjectNode(
+            context.getIdAllocator().genPlanNodeId(),
+            source,
+            Assignments.builder()
+                .putIdentities(source.getOutputSymbols())
+                .put(nonNull, TRUE_LITERAL)
+                .build());
+
+    // assign unique id on correlated join's input. It will be used to distinguish between original
+    // input rows after join
+    PlanNode inputWithUniqueId =
+        new AssignUniqueId(
+            context.getIdAllocator().genPlanNodeId(),
+            correlatedJoinNode.getInput(),
+            context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));
+
+    // always a LEFT join, for both INNER and LEFT correlated joins; the correlated predicates
+    // extracted by the decorrelator are passed to the join
+    JoinNode join =
+        new JoinNode(
+            context.getIdAllocator().genPlanNodeId(),
+            JoinNode.JoinType.LEFT,
+            inputWithUniqueId,
+            source,
+            ImmutableList.of(),
+            inputWithUniqueId.getOutputSymbols(),
+            source.getOutputSymbols(),
+            decorrelatedSource.get().getCorrelatedPredicates(),
+            Optional.empty());
+
+    PlanNode root = join;
+
+    // restore distinct aggregation, grouped by the join's left outputs, non_null and the original
+    // distinct keys
+    if (distinct != null) {
+      root =
+          restoreDistinctAggregation(
+              distinct,
+              join,
+              ImmutableList.<Symbol>builder()
+                  .addAll(join.getLeftOutputSymbols())
+                  .add(nonNull)
+                  .addAll(distinct.getGroupingKeys())
+                  .build());
+    }
+
+    // prepare mask symbols for aggregations
+    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the
+    // aggregation already has a mask, it will be replaced with conjunction of the existing mask
+    // and non_null.
+    // This is necessary to restore the original aggregation result in case when:
+    // - the nested lateral subquery returned empty result for some input row,
+    // - aggregation is null-sensitive, which means that its result over a single null row is
+    //   different than result for empty input (with global grouping)
+    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
+    AggregationNode globalAggregation = captures.get(AGGREGATION);
+    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
+    Assignments.Builder assignmentsBuilder = Assignments.builder();
+    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
+        globalAggregation.getAggregations().entrySet()) {
+      AggregationNode.Aggregation aggregation = entry.getValue();
+      if (aggregation.getMask().isPresent()) {
+        // existing mask: allocate a new symbol for (existing mask AND non_null)
+        Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BooleanType.getInstance());
+        Expression expression =
+            and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
+        assignmentsBuilder.put(newMask, expression);
+        masks.put(entry.getKey(), newMask);
+      } else {
+        masks.put(entry.getKey(), nonNull);
+      }
+    }
+    // only add a projection when at least one combined mask expression was created
+    Assignments maskAssignments = assignmentsBuilder.build();
+    if (!maskAssignments.isEmpty()) {
+      root =
+          new ProjectNode(
+              context.getIdAllocator().genPlanNodeId(),
+              root,
+              Assignments.builder()
+                  .putIdentities(root.getOutputSymbols())
+                  .putAll(maskAssignments)
+                  .build());
+    }
+
+    // restore global aggregation, now grouped by the join's left outputs (input symbols and the
+    // unique id) followed by the original grouping keys
+    globalAggregation =
+        new AggregationNode(
+            globalAggregation.getPlanNodeId(),
+            root,
+            rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()),
+            singleGroupingSet(
+                ImmutableList.<Symbol>builder()
+                    .addAll(join.getLeftOutputSymbols())
+                    .addAll(globalAggregation.getGroupingKeys())
+                    .build()),
+            ImmutableList.of(),
+            globalAggregation.getStep(),
+            Optional.empty(),
+            Optional.empty());
+
+    // restrict outputs and apply projection
+    Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
+    List<Symbol> expectedAggregationOutputs =
+        globalAggregation.getOutputSymbols().stream()
+            .filter(outputSymbols::contains)
+            .collect(toImmutableList());
+
+    Assignments assignments =
+        Assignments.builder()
+            .putIdentities(expectedAggregationOutputs)
+            .putAll(captures.get(PROJECTION).getAssignments())
+            .build();
+
+    return Result.ofPlanNode(
+        new ProjectNode(context.getIdAllocator().genPlanNodeId(), globalAggregation, assignments));
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
new file mode 100644
index 00000000000..ab7979f9eda
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.tsfile.read.common.type.BooleanType; +import org.apache.tsfile.read.common.type.LongType; + +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.and; +import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.restrictOutputs; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleGroupingSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.Aggregation.groupingColumns; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.subquery; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.empty; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty; + +/** + * This rule decorrelates a correlated subquery with: - single global aggregation, or - global + * aggregation over distinct operator (grouped aggregation with no aggregation 
assignments), in case
+ * when the distinct operator cannot be de-correlated by PlanNodeDecorrelator. It is similar to
+ * TransformCorrelatedGlobalAggregationWithProjection rule, but does not support projection over
+ * aggregation in the subquery
+ *
+ * <p>In the case of single aggregation, it transforms:
+ *
+ * <pre>
+ * - CorrelatedJoin LEFT or INNER (correlation: [c], filter: true, output: a, count, agg)
+ *      - Input (a, c)
+ *      - Aggregation global
+ *        count <- count(*)
+ *        agg <- agg(b)
+ *           - Source (b) with correlated filter (b > c)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - Project (a <- a, count <- count, agg <- agg)
+ *      - Aggregation (group by [a, c, unique])
+ *        count <- count(*) mask(non_null)
+ *        agg <- agg(b) mask(non_null)
+ *           - LEFT join (filter: b > c)
+ *                - UniqueId (unique)
+ *                     - Input (a, c)
+ *                - Project (non_null <- TRUE)
+ *                     - Source (b) decorrelated
+ * </pre>
+ *
+ * <p>In the case of global aggregation over distinct operator, it transforms:
+ *
+ * <pre>
+ * - CorrelatedJoin LEFT or INNER (correlation: [c], filter: true, output: a, count, agg)
+ *      - Input (a, c)
+ *      - Aggregation global
+ *        count <- count(*)
+ *        agg <- agg(b)
+ *           - Aggregation "distinct operator" group by [b]
+ *                - Source (b) with correlated filter (b > c)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - Project (a <- a, count <- count, agg <- agg)
+ *      - Aggregation (group by [a, c, unique])
+ *        count <- count(*) mask(non_null)
+ *        agg <- agg(b) mask(non_null)
+ *           - Aggregation "distinct operator" group by [a, c, unique, non_null, b]
+ *                - LEFT join (filter: b > c)
+ *                     - UniqueId (unique)
+ *                          - Input (a, c)
+ *                     - Project (non_null <- TRUE)
+ *                          - Source (b) decorrelated
+ * </pre>
+ */
+public class TransformCorrelatedGlobalAggregationWithoutProjection
+    implements Rule<CorrelatedJoinNode> {
+  /** Captures the aggregation node that forms the root of the correlated subquery. */
+  private static final Capture<AggregationNode> AGGREGATION = newCapture();
+
+  /** Captures the source (child) of the captured aggregation. */
+  private static final Capture<PlanNode> SOURCE = newCapture();
+
+  /**
+   * Matches a correlated join that has a non-empty correlation list, a literal TRUE filter, and a
+   * subquery rooted at an aggregation with empty grouping columns.
+   */
+  private static final Pattern<CorrelatedJoinNode> PATTERN =
+      correlatedJoin()
+          .with(nonEmpty(Patterns.CorrelatedJoin.correlation()))
+          .with(
+              filter()
+                  .equalTo(
+                      TRUE_LITERAL)) // todo non-trivial join filter: adding filter/project on top
+          // of aggregation
+          .with(
+              subquery()
+                  .matching(
+                      aggregation()
+                          .with(empty(groupingColumns()))
+                          .with(source().capturedAs(SOURCE))
+                          .capturedAs(AGGREGATION)));
+
+  /** Planner context handed to the {@link PlanNodeDecorrelator} created in {@code apply}. */
+  private final PlannerContext plannerContext;
+
+  /**
+   * Creates the rule.
+   *
+   * @param plannerContext planner context used when decorrelating the nested plan
+   * @throws NullPointerException if {@code plannerContext} is null
+   */
+  public TransformCorrelatedGlobalAggregationWithoutProjection(PlannerContext plannerContext) {
+    this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+  }
+
+  /** Returns the static pattern this rule is matched against. */
+  @Override
+  public Pattern<CorrelatedJoinNode> getPattern() {
+    return PATTERN;
+  }
+
+  /**
+   * Rewrites the matched correlated join into a LEFT join followed by a grouped aggregation.
+   *
+   * <p>Steps, in order: decorrelate the aggregation source (peeling off a distinct operator if
+   * the first attempt fails), append a constant {@code non_null} column to it, tag every input
+   * row with a unique id, LEFT-join both sides on the extracted correlated predicates, re-apply
+   * the distinct operator if one was peeled off, mask every aggregation with {@code non_null},
+   * and finally restrict the outputs to those of the original correlated join.
+   *
+   * @param correlatedJoinNode the matched correlated join; its type must be INNER or LEFT
+   * @param captures pattern captures holding the aggregation and its source
+   * @param context rule context supplying the symbol allocator, id allocator and lookup
+   * @return the rewritten plan, or {@code Result.empty()} if the source cannot be decorrelated
+   * @throws IllegalArgumentException if the join type is neither INNER nor LEFT
+   */
+  @Override
+  public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
+    checkArgument(
+        correlatedJoinNode.getJoinType() == INNER || correlatedJoinNode.getJoinType() == LEFT,
+        "unexpected correlated join type: %s",
+        correlatedJoinNode.getJoinType());
+
+    PlanNode source = captures.get(SOURCE);
+
+    // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can
+    // extract and special-handle the distinct operator
+    AggregationNode distinct = null;
+
+    // decorrelate nested plan
+    PlanNodeDecorrelator decorrelator =
+        new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
+    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource =
+        decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+    if (!decorrelatedSource.isPresent()) {
+      // we failed to decorrelate the nested plan, so check if we can extract a distinct operator
+      // from the nested plan
+      if (isDistinctOperator(source)) {
+        distinct = (AggregationNode) source;
+        source = distinct.getChild();
+        decorrelatedSource =
+            decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+      }
+      // still not decorrelated (with or without the distinct operator): the rule does not apply
+      if (!decorrelatedSource.isPresent()) {
+        return Result.empty();
+      }
+    }
+
+    source = decorrelatedSource.get().getNode();
+
+    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive
+    // aggregations after LEFT join
+    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BooleanType.getInstance());
+    source =
+        new ProjectNode(
+            context.getIdAllocator().genPlanNodeId(),
+            source,
+            Assignments.builder()
+                .putIdentities(source.getOutputSymbols())
+                .put(nonNull, TRUE_LITERAL)
+                .build());
+
+    // assign unique id on correlated join's input. It will be used to distinguish between original
+    // input rows after join
+    PlanNode inputWithUniqueId =
+        new AssignUniqueId(
+            context.getIdAllocator().genPlanNodeId(),
+            correlatedJoinNode.getInput(),
+            context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));
+
+    // always a LEFT join, even for an INNER correlated join; the correlated predicates extracted
+    // by the decorrelator become the join filter
+    JoinNode join =
+        new JoinNode(
+            context.getIdAllocator().genPlanNodeId(),
+            JoinNode.JoinType.LEFT,
+            inputWithUniqueId,
+            source,
+            ImmutableList.of(),
+            inputWithUniqueId.getOutputSymbols(),
+            source.getOutputSymbols(),
+            decorrelatedSource.get().getCorrelatedPredicates(),
+            Optional.empty());
+
+    PlanNode root = join;
+
+    // restore distinct aggregation
+    if (distinct != null) {
+      root =
+          restoreDistinctAggregation(
+              distinct,
+              join,
+              ImmutableList.<Symbol>builder()
+                  .addAll(join.getLeftOutputSymbols())
+                  .add(nonNull)
+                  .addAll(distinct.getGroupingKeys())
+                  .build());
+    }
+
+    // prepare mask symbols for aggregations
+    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the
+    // aggregation already has a mask, it will be replaced with conjunction of the existing mask
+    // and non_null.
+    // This is necessary to restore the original aggregation result in case when:
+    // - the nested lateral subquery returned empty result for some input row,
+    // - aggregation is null-sensitive, which means that its result over a single null row is
+    //   different than result for empty input (with global grouping)
+    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
+    AggregationNode globalAggregation = captures.get(AGGREGATION);
+    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
+    Assignments.Builder assignmentsBuilder = Assignments.builder();
+    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
+        globalAggregation.getAggregations().entrySet()) {
+      AggregationNode.Aggregation aggregation = entry.getValue();
+      if (aggregation.getMask().isPresent()) {
+        Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BooleanType.getInstance());
+        Expression expression =
+            and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
+        assignmentsBuilder.put(newMask, expression);
+        masks.put(entry.getKey(), newMask);
+      } else {
+        masks.put(entry.getKey(), nonNull);
+      }
+    }
+    // a projection is only needed when at least one combined mask expression was created
+    Assignments maskAssignments = assignmentsBuilder.build();
+    if (!maskAssignments.isEmpty()) {
+      root =
+          new ProjectNode(
+              context.getIdAllocator().genPlanNodeId(),
+              root,
+              Assignments.builder()
+                  .putIdentities(root.getOutputSymbols())
+                  .putAll(maskAssignments)
+                  .build());
+    }
+
+    // restore global aggregation
+    // (reuses the original node id; grouped by the join's left outputs plus the original keys)
+    globalAggregation =
+        new AggregationNode(
+            globalAggregation.getPlanNodeId(),
+            root,
+            rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()),
+            singleGroupingSet(
+                ImmutableList.<Symbol>builder()
+                    .addAll(join.getLeftOutputSymbols())
+                    .addAll(globalAggregation.getGroupingKeys())
+                    .build()),
+            ImmutableList.of(),
+            globalAggregation.getStep(),
+            Optional.empty(),
+            Optional.empty());
+
+    // restrict outputs
+    Optional<PlanNode> project =
+        restrictOutputs(
+            context.getIdAllocator(),
+            globalAggregation,
+            ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
+
+    return Result.ofPlanNode(project.orElse(globalAggregation));
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java
new file mode 100644
index 00000000000..70a1a4d7bc2
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.tsfile.read.common.type.BooleanType;
+import org.apache.tsfile.read.common.type.LongType;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleGroupingSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.Aggregation.groupingColumns; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.subquery; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.type; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty; + +/** + * This rule decorrelates a correlated subquery of INNER correlated join with: - single grouped + * aggregation, or - grouped aggregation over distinct operator (grouped aggregation with no + * aggregation assignments), in case when the distinct operator cannot be de-correlated by + * PlanNodeDecorrelator + * + * <p>In the case of single aggregation, it transforms: + * + * <pre> + * - CorrelatedJoin INNER (correlation: [c], filter: true, output: a, count, agg) + * - Input (a, c) + * - Project (x <- f(count), y <- f'(agg)) + * - Aggregation (group by b) + * count <- count(*) + * agg <- agg(d) + * - Source (b, d) with correlated filter (b > c) + * </pre> + * + * Into: + 
* + * <pre> + * - Project (a <- a, x <- f(count), y <- f'(agg)) + * - Aggregation (group by [a, c, unique, b]) + * count <- count(*) + * agg <- agg(d) + * - INNER join (filter: b > c) + * - UniqueId (unique) + * - Input (a, c) + * - Source (b, d) decorrelated + * </pre> + * + * <p>In the case of grouped aggregation over distinct operator, it transforms: + * + * <pre> + * - CorrelatedJoin INNER (correlation: [c], filter: true, output: a, count, agg) + * - Input (a, c) + * - Project (x <- f(count), y <- f'(agg)) + * - Aggregation (group by b) + * count <- count(*) + * agg <- agg(b) + * - Aggregation "distinct operator" group by [b] + * - Source (b) with correlated filter (b > c) + * </pre> + * + * Into: + * + * <pre> + * - Project (a <- a, x <- f(count), y <- f'(agg)) + * - Aggregation (group by [a, c, unique, b]) + * count <- count(*) + * agg <- agg(b) + * - Aggregation "distinct operator" group by [a, c, unique, b] + * - INNER join (filter: b > c) + * - UniqueId (unique) + * - Input (a, c) + * - Source (b) decorrelated + * </pre> + */ +public class TransformCorrelatedGroupedAggregationWithProjection + implements Rule<CorrelatedJoinNode> { + private static final Capture<ProjectNode> PROJECTION = newCapture(); + private static final Capture<AggregationNode> AGGREGATION = newCapture(); + private static final Capture<PlanNode> SOURCE = newCapture(); + + private static final Pattern<CorrelatedJoinNode> PATTERN = + correlatedJoin() + .with(type().equalTo(INNER)) + .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) + .with(filter().equalTo(TRUE_LITERAL)) + .with( + subquery() + .matching( + project() + .capturedAs(PROJECTION) + .with( + source() + .matching( + aggregation() + .with(nonEmpty(groupingColumns())) + .matching( + aggregation -> aggregation.getGroupingSetCount() == 1) + .with(source().capturedAs(SOURCE)) + .capturedAs(AGGREGATION))))); + + private final PlannerContext plannerContext; + + public 
TransformCorrelatedGroupedAggregationWithProjection(PlannerContext plannerContext) { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + } + + @Override + public Pattern<CorrelatedJoinNode> getPattern() { + return PATTERN; + } + + @Override + public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) { + PlanNode source = captures.get(SOURCE); + + // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can + // extract and special-handle the distinct operator + AggregationNode distinct = null; + + // decorrelate nested plan + PlanNodeDecorrelator decorrelator = + new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup()); + Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = + decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + if (!decorrelatedSource.isPresent()) { + // we failed to decorrelate the nested plan, so check if we can extract a distinct operator + // from the nested plan + if (isDistinctOperator(source)) { + distinct = (AggregationNode) source; + source = distinct.getChild(); + decorrelatedSource = + decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + } + if (!decorrelatedSource.isPresent()) { + return Result.empty(); + } + } + + source = decorrelatedSource.get().getNode(); + + // assign unique id on correlated join's input. 
It will be used to distinguish between original + // input rows after join + PlanNode inputWithUniqueId = + new AssignUniqueId( + context.getIdAllocator().genPlanNodeId(), + correlatedJoinNode.getInput(), + context.getSymbolAllocator().newSymbol("unique", BooleanType.getInstance())); + + JoinNode join = + new JoinNode( + context.getIdAllocator().genPlanNodeId(), + JoinNode.JoinType.INNER, + inputWithUniqueId, + source, + ImmutableList.of(), + inputWithUniqueId.getOutputSymbols(), + source.getOutputSymbols(), + decorrelatedSource.get().getCorrelatedPredicates(), + Optional.empty()); + + // restore distinct aggregation + if (distinct != null) { + distinct = + restoreDistinctAggregation( + distinct, + join, + ImmutableList.<Symbol>builder() + .addAll(join.getLeftOutputSymbols()) + .addAll(distinct.getGroupingKeys()) + .build()); + } + + // restore grouped aggregation + AggregationNode groupedAggregation = captures.get(AGGREGATION); + groupedAggregation = + AggregationNode.builderFrom(groupedAggregation) + .setSource(distinct != null ? 
distinct : join) + .setGroupingSets( + singleGroupingSet( + ImmutableList.<Symbol>builder() + .addAll(join.getLeftOutputSymbols()) + .addAll(groupedAggregation.getGroupingKeys()) + .build())) + .setPreGroupedSymbols(ImmutableList.of()) + .setHashSymbol(Optional.empty()) + .setGroupIdSymbol(Optional.empty()) + .build(); + + // restrict outputs and apply projection + Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols()); + List<Symbol> expectedAggregationOutputs = + groupedAggregation.getOutputSymbols().stream() + .filter(outputSymbols::contains) + .collect(toImmutableList()); + + Assignments assignments = + Assignments.builder() + .putIdentities(expectedAggregationOutputs) + .putAll(captures.get(PROJECTION).getAssignments()) + .build(); + + return Result.ofPlanNode( + new ProjectNode(context.getIdAllocator().genPlanNodeId(), groupedAggregation, assignments)); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java new file mode 100644 index 00000000000..2674632a619 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns; +import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.apache.tsfile.read.common.type.BooleanType; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.restrictOutputs; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleGroupingSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.Aggregation.groupingColumns; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.subquery; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.type; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty; + +/** + * This rule decorrelates a correlated subquery of an INNER correlated join with: - single grouped + * aggregation, or - grouped aggregation over distinct operator (grouped aggregation with no + * aggregation assignments), in the case when the distinct operator cannot be decorrelated by + * PlanNodeDecorrelator. It is similar to TransformCorrelatedGroupedAggregationWithProjection rule, + * but does not support projection over aggregation in the subquery. + * + * <p>In the case of single aggregation, it transforms: + * + * 
<pre> + * - CorrelatedJoin INNER (correlation: [c], filter: true, output: a, count, agg) + * - Input (a, c) + * - Aggregation (group by b) + * count <- count(*) + * agg <- agg(d) + * - Source (b, d) with correlated filter (b > c) + * </pre> + * + * Into: + * + * <pre> + * - Project (a <- a, count <- count, agg <- agg) + * - Aggregation (group by [a, c, unique, b]) + * count <- count(*) + * agg <- agg(d) + * - INNER join (filter: b > c) + * - UniqueId (unique) + * - Input (a, c) + * - Source (b, d) decorrelated + * </pre> + * + * <p>In the case of grouped aggregation over distinct operator, it transforms: + * + * <pre> + * - CorrelatedJoin INNER (correlation: [c], filter: true, output: a, count, agg) + * - Input (a, c) + * - Aggregation (group by b) + * count <- count(*) + * agg <- agg(b) + * - Aggregation "distinct operator" group by [b] + * - Source (b) with correlated filter (b > c) + * </pre> + * + * Into: + * + * <pre> + * - Project (a <- a, count <- count, agg <- agg) + * - Aggregation (group by [a, c, unique, b]) + * count <- count(*) + * agg <- agg(b) + * - Aggregation "distinct operator" group by [a, c, unique, b] + * - INNER join (filter: b > c) + * - UniqueId (unique) + * - Input (a, c) + * - Source (b) decorrelated + * </pre> + */ +public class TransformCorrelatedGroupedAggregationWithoutProjection + implements Rule<CorrelatedJoinNode> { + private static final Capture<AggregationNode> AGGREGATION = newCapture(); + private static final Capture<PlanNode> SOURCE = newCapture(); + /** + * Matches an INNER correlated join with a non-empty correlation list and a TRUE filter whose + * subquery is an aggregation that has non-empty grouping columns and exactly one grouping set. The + * aggregation is captured as AGGREGATION and its source as SOURCE. 
+ */ + private static final Pattern<CorrelatedJoinNode> PATTERN = + correlatedJoin() + .with(type().equalTo(INNER)) + .with(nonEmpty(Patterns.CorrelatedJoin.correlation())) + .with(filter().equalTo(TRUE_LITERAL)) + .with( + subquery() + .matching( + aggregation() + .with(nonEmpty(groupingColumns())) + .matching(aggregation -> aggregation.getGroupingSetCount() == 1) + .with(source().capturedAs(SOURCE)) + .capturedAs(AGGREGATION))); + + private final PlannerContext plannerContext; + /** + * Creates the rule. + * + * @param plannerContext planner context passed to PlanNodeDecorrelator; must not be null + */ + 
public TransformCorrelatedGroupedAggregationWithoutProjection(PlannerContext plannerContext) { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + } + /** Returns the pattern this rule is matched against. */ + @Override + public Pattern<CorrelatedJoinNode> getPattern() { + return PATTERN; + } + /** + * Rewrites the matched correlated join into a grouped aggregation over a regular INNER join. + * + * <p>The captured aggregation source is decorrelated first. If that fails and the source is a + * distinct operator, the distinct operator is set aside, its child is decorrelated instead, and + * the distinct aggregation is restored on top of the join. The grouped aggregation is then rebuilt + * with the join's left output symbols placed in front of its grouping keys, and the result is + * restricted to the correlated join's output symbols. + * + * @param correlatedJoinNode the matched correlated join + * @param captures captures holding the matched aggregation and its source + * @param context rule context supplying the id allocator, symbol allocator and lookup + * @return the rewritten plan, or an empty result when the source cannot be decorrelated + */ + @Override + public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) { + PlanNode source = captures.get(SOURCE); + + // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can + // extract and special-handle the distinct operator + AggregationNode distinct = null; + + // decorrelate nested plan + PlanNodeDecorrelator decorrelator = + new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup()); + Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = + decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + if (!decorrelatedSource.isPresent()) { + // we failed to decorrelate the nested plan, so check if we can extract a distinct operator + // from the nested plan + if (isDistinctOperator(source)) { + distinct = (AggregationNode) source; + source = distinct.getChild(); + decorrelatedSource = + decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + } + if (!decorrelatedSource.isPresent()) { + return Result.empty(); + } + } + + source = decorrelatedSource.get().getNode(); + + // assign unique id on correlated join's input. 
It will be used to distinguish between original + // input rows after join + // NOTE(review): the "unique" symbol is created with BooleanType here; confirm this matches the + // type AssignUniqueId is expected to produce + PlanNode inputWithUniqueId = + new AssignUniqueId( + context.getIdAllocator().genPlanNodeId(), + correlatedJoinNode.getInput(), + context.getSymbolAllocator().newSymbol("unique", BooleanType.getInstance())); + + JoinNode join = + new JoinNode( + context.getIdAllocator().genPlanNodeId(), + JoinNode.JoinType.INNER, + inputWithUniqueId, + source, + ImmutableList.of(), + inputWithUniqueId.getOutputSymbols(), + source.getOutputSymbols(), + decorrelatedSource.get().getCorrelatedPredicates(), + Optional.empty()); + + // restore distinct aggregation on top of the join; the symbol list passed is the join's left + // outputs followed by the original distinct grouping keys + if (distinct != null) { + distinct = + restoreDistinctAggregation( + distinct, + join, + ImmutableList.<Symbol>builder() + .addAll(join.getLeftOutputSymbols()) + .addAll(distinct.getGroupingKeys()) + .build()); + } + + // restore grouped aggregation over the restored distinct operator if there is one, otherwise + // directly over the join; the join's left outputs are placed in front of its grouping keys + AggregationNode groupedAggregation = captures.get(AGGREGATION); + groupedAggregation = + AggregationNode.builderFrom(groupedAggregation) + .setSource(distinct != null ? 
distinct : join) + .setAggregations(groupedAggregation.getAggregations()) + .setGroupingSets( + singleGroupingSet( + ImmutableList.<Symbol>builder() + .addAll(join.getLeftOutputSymbols()) + .addAll(groupedAggregation.getGroupingKeys()) + .build())) + .setPreGroupedSymbols(ImmutableList.of()) + .setHashSymbol(Optional.empty()) + .setGroupIdSymbol(Optional.empty()) + .build(); + + // restrict outputs to the correlated join's output symbols; when restrictOutputs yields no + // node, the aggregation itself is returned + Optional<PlanNode> project = + restrictOutputs( + context.getIdAllocator(), + groupedAggregation, + ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())); + + return Result.ofPlanNode(project.orElse(groupedAggregation)); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedJoinToJoin.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedJoinToJoin.java new file mode 100644 index 00000000000..cfd22fe4f12 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedJoinToJoin.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; + +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.combineConjuncts; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.correlation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty; + +/** + * Tries to decorrelate the subquery of a correlated join and rewrite it as a regular join. The + * decorrelated predicates become part of the join condition. + * + * <p>The rule matches any correlated join with a non-empty correlation list. Only INNER and LEFT + * correlated joins are accepted; when the subquery cannot be decorrelated the rule returns an + * empty result.
+ */ +public class TransformCorrelatedJoinToJoin implements Rule<CorrelatedJoinNode> { + private static final Pattern<CorrelatedJoinNode> PATTERN = + correlatedJoin().with(nonEmpty(correlation())); + + private final PlannerContext plannerContext; + /** + * Creates the rule. + * + * @param plannerContext planner context passed to PlanNodeDecorrelator; must not be null + */ + public TransformCorrelatedJoinToJoin(PlannerContext plannerContext) { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + } + /** Returns the pattern this rule is matched against. */ + @Override + public Pattern<CorrelatedJoinNode> getPattern() { + return PATTERN; + } + /** + * Replaces the correlated join with a JoinNode over the decorrelated subquery. + * + * <p>The predicates extracted while decorrelating are combined with the correlated join's own + * filter to form the join filter; a filter equal to TRUE_LITERAL is omitted. The new JoinNode + * keeps the correlated join's plan node id and join type. + * + * @param correlatedJoinNode the matched correlated join; its join type must be INNER or LEFT + * @param captures not used by this rule + * @param context rule context supplying the symbol allocator and lookup + * @return the rewritten join, or an empty result when the subquery cannot be decorrelated + * @throws IllegalArgumentException if the join type is neither INNER nor LEFT + */ + @Override + public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) { + checkArgument( + correlatedJoinNode.getJoinType() == INNER || correlatedJoinNode.getJoinType() == LEFT, + "correlation in %s JOIN", + correlatedJoinNode.getJoinType().name()); + PlanNode subquery = correlatedJoinNode.getSubquery(); + + PlanNodeDecorrelator planNodeDecorrelator = + new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup()); + Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedNodeOptional = + planNodeDecorrelator.decorrelateFilters(subquery, correlatedJoinNode.getCorrelation()); + if (!decorrelatedNodeOptional.isPresent()) { + return Result.empty(); + } + PlanNodeDecorrelator.DecorrelatedNode decorrelatedSubquery = decorrelatedNodeOptional.get(); + + Expression filter = + combineConjuncts( + decorrelatedSubquery.getCorrelatedPredicates().orElse(TRUE_LITERAL), + correlatedJoinNode.getFilter()); + + return Result.ofPlanNode( + new JoinNode( + correlatedJoinNode.getPlanNodeId(), + correlatedJoinNode.getJoinType(), + correlatedJoinNode.getInput(), + decorrelatedSubquery.getNode(), + ImmutableList.of(), + correlatedJoinNode.getInput().getOutputSymbols(), + correlatedJoinNode.getSubquery().getOutputSymbols(), + filter.equals(TRUE_LITERAL) ? 
Optional.empty() : Optional.of(filter), + Optional.empty())); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java index f676fcd0ca5..a64eea771f6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java @@ -272,18 +272,15 @@ public final class Patterns { node.getChildren().stream().map(lookup::resolve).collect(toImmutableList())); } - /*public static final class Aggregation - { - public static Property<AggregationNode, Lookup, List<Symbol>> groupingColumns() - { - return property("groupingKeys", AggregationNode::getGroupingKeys); - } + public static final class Aggregation { + public static Property<AggregationNode, Lookup, List<Symbol>> groupingColumns() { + return property("groupingKeys", AggregationNode::getGroupingKeys); + } - public static Property<AggregationNode, Lookup, AggregationNode.Step> step() - { - return property("step", AggregationNode::getStep); - } - }*/ + public static Property<AggregationNode, Lookup, AggregationNode.Step> step() { + return property("step", AggregationNode::getStep); + } + } public static final class Apply { public static Property<ApplyNode, Lookup, List<Symbol>> correlation() { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java index de2b245bfd2..ca27877079d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java @@ -58,8 +58,16 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Re import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveRedundantEnforceSingleRowNode; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveRedundantIdentityProjections; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveTrivialFilters; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarApplyNodes; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarSubqueries; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyExpressions; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedDistinctAggregationWithProjection; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedDistinctAggregationWithoutProjection; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedGlobalAggregationWithProjection; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedGlobalAggregationWithoutProjection; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedGroupedAggregationWithProjection; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedGroupedAggregationWithoutProjection; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedJoinToJoin; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin; import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformUncorrelatedSubqueryToJoin; @@ -215,9 +223,27 @@ public class LogicalOptimizeFactory { plannerContext, ruleStats, ImmutableSet.of( - new RemoveRedundantEnforceSingleRowNode(), new RemoveUnreferencedScalarSubqueries(), + new RemoveRedundantEnforceSingleRowNode(), + new RemoveUnreferencedScalarSubqueries(), new TransformUncorrelatedSubqueryToJoin(), - new TransformUncorrelatedInPredicateSubqueryToSemiJoin())), + new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), + new TransformCorrelatedJoinToJoin(plannerContext), + new TransformCorrelatedGlobalAggregationWithProjection(plannerContext), + new TransformCorrelatedGlobalAggregationWithoutProjection(plannerContext), + new TransformCorrelatedDistinctAggregationWithProjection(plannerContext), + new TransformCorrelatedDistinctAggregationWithoutProjection(plannerContext), + new TransformCorrelatedGroupedAggregationWithProjection(plannerContext), + new TransformCorrelatedGroupedAggregationWithoutProjection(plannerContext))), + new IterativeOptimizer( + plannerContext, + ruleStats, + ImmutableSet.of( + new RemoveUnreferencedScalarApplyNodes(), + // new TransformCorrelatedInPredicateToJoin(metadata), // + // must be run after columnPruningOptimizer + // new TransformCorrelatedScalarSubquery(metadata), // + // must be run after TransformCorrelatedAggregation rules + new TransformCorrelatedJoinToJoin(plannerContext))), new IterativeOptimizer( plannerContext, ruleStats,
