This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 20cc78e5f2 [CALCITE-6715] Enhance RelFieldTrimmer to trim
LogicalCorrelate nodes
20cc78e5f2 is described below
commit 20cc78e5f2424aed6a364afe4fb76e30cb96fef1
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Wed Dec 4 15:07:12 2024 +0000
[CALCITE-6715] Enhance RelFieldTrimmer to trim LogicalCorrelate nodes
---
.../apache/calcite/sql2rel/RelFieldTrimmer.java | 107 +++++++++++++++++++++
.../calcite/sql2rel/RexRewritingRelShuttle.java | 37 +++++++
.../org/apache/calcite/util/mapping/Mappings.java | 33 ++++++-
.../calcite/sql2rel/RelFieldTrimmerTest.java | 100 +++++++++++++++++++
core/src/test/resources/sql/sub-query.iq | 2 +-
5 files changed, 277 insertions(+), 2 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index 46a2b085a1..69e9ae90bb 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -41,6 +41,7 @@ import org.apache.calcite.rel.core.Snapshot;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.SortExchange;
import org.apache.calcite.rel.core.TableScan;
+import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.logical.LogicalValues;
@@ -56,6 +57,7 @@ import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
@@ -1336,6 +1338,111 @@ public class RelFieldTrimmer implements
ReflectiveVisitor {
return result(newSnapshot, inputMapping, snapshot);
}
+ /**
+ * Trims {@link LogicalCorrelate} nodes.
+ *
+ * <p>When {@code extraFields} is non-empty this delegates to the generic
+ * {@code trimFields(RelNode, ...)} overload. Otherwise the correlate's
+ * required columns are added to {@code fieldsUsed}, each input is trimmed
+ * against the slice of {@code fieldsUsed} that falls inside its own field
+ * range, and, if at least one input was replaced, a new correlate is built
+ * over the trimmed inputs. Field accesses on this correlate's correlation
+ * variable inside the new right input are re-indexed through the combined
+ * mapping, and the required columns are permuted by the same mapping.
+ *
+ * <p>If no input changes, the original node is returned together with an
+ * identity mapping over its row type.
+ *
+ * <p>NOTE(review): the per-input slicing and the concatenated mapping treat
+ * the correlate's fields as "all left fields followed by all right fields".
+ * Confirm the behaviour for join types whose row type does not include the
+ * right input's fields.
+ *
+ * @param correlate the correlate node to trim
+ * @param fieldsUsed ordinals of the correlate's output fields that the
+ * consumer needs
+ * @param extraFields extra fields requested by the consumer; any non-empty
+ * set triggers the generic fallback
+ * @return the (possibly new) node and the mapping from old to new field
+ * ordinals
+ */
+ public TrimResult trimFields(LogicalCorrelate correlate,
+ ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
+ if (!extraFields.isEmpty()) {
+ // Extra fields are not handled here; bail out with the generic trim.
+ return trimFields((RelNode) correlate, fieldsUsed, extraFields);
+ }
+
+ fieldsUsed = fieldsUsed.union(correlate.getRequiredColumns()); // keep the required columns in the used set
+
+ List<RelNode> newInputs = new ArrayList<>();
+ List<Mapping> inputMappings = new ArrayList<>();
+ int changeCount = 0;
+ int offset = 0; // ordinal of the current input's first field
+ for (RelNode input : correlate.getInputs()) {
+ final RelDataType inputRowType = input.getRowType();
+ final int inputFieldCount = inputRowType.getFieldCount();
+
+ ImmutableBitSet currentInputFieldsUsed = fieldsUsed
+ .intersect(ImmutableBitSet.range(offset, offset + inputFieldCount))
+ .shift(-offset); // re-base to input-relative ordinals
+
+ TrimResult trimResult =
+ dispatchTrimFields(input, currentInputFieldsUsed, extraFields);
+
+ newInputs.add(trimResult.left);
+ inputMappings.add(trimResult.right);
+
+ offset += inputFieldCount;
+
+ if (trimResult.left != input) { // reference comparison: was the input replaced?
+ changeCount++;
+ }
+ }
+
+ if (changeCount == 0) { // nothing was trimmed: keep the node, identity mapping
+ return result(correlate,
+ Mappings.createIdentity(correlate.getRowType().getFieldCount()));
+ }
+
+ Mapping mapping = Mappings.concatenateMappings(inputMappings);
+ RexBuilder rexBuilder = relBuilder.getRexBuilder();
+
+ RelNode newLeft = newInputs.get(0);
+ RexCorrelVariableMapShuttle rexVisitor =
+ new RexCorrelVariableMapShuttle(correlate.getCorrelationId(),
+ newLeft.getRowType(), mapping, rexBuilder);
+ RelNode newRight =
+ newInputs.get(1).accept(new RexRewritingRelShuttle(rexVisitor)); // re-index correlation field accesses in the right subtree
+ final LogicalCorrelate newCorrelate =
+ correlate
+ .copy(correlate.getTraitSet(),
+ newLeft,
+ newRight,
+ correlate.getCorrelationId(),
+ correlate.getRequiredColumns().permute(mapping),
+ correlate.getJoinType());
+
+ return result(newCorrelate, mapping);
+ }
+
+ /**
+ * Updates correlate references in {@link RexNode} expressions.
+ *
+ * <p>{@code visitFieldAccess} rewrites a {@link RexFieldAccess} only when
+ * its reference expression is a {@link RexCorrelVariable} whose id equals
+ * the configured correlation id. In that case it builds a fresh correlation
+ * variable with the new row type and the same id, translates the accessed
+ * field's index through the mapping, and returns a field access on the new
+ * variable at the translated index. Every other field access is handled by
+ * the inherited {@link RexShuttle} behaviour.
+ */
+ static class RexCorrelVariableMapShuttle extends RexShuttle {
+ private final CorrelationId correlationId;
+ private final Mapping mapping;
+ private final RelDataType newCorrelRowType;
+ private final RexBuilder rexBuilder;
+
+
+ /**
+ * Constructs a RexCorrelVariableMapShuttle.
+ *
+ * @param correlationId The ID of the correlation variable to update.
+ * @param newCorrelRowType The new row type for the correlate reference.
+ * @param mapping Mapping to transform field indices (old index to new
+ * index, looked up with {@code getTarget}).
+ * @param rexBuilder A builder for constructing new RexNodes.
+ */
+ RexCorrelVariableMapShuttle(final CorrelationId correlationId,
+ RelDataType newCorrelRowType, Mapping mapping, RexBuilder rexBuilder) {
+ this.correlationId = correlationId;
+ this.newCorrelRowType = newCorrelRowType;
+ this.mapping = mapping;
+ this.rexBuilder = rexBuilder;
+ }
+
+ @Override public RexNode visitFieldAccess(final RexFieldAccess
fieldAccess) {
+ if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+ RexCorrelVariable referenceExpr =
+ (RexCorrelVariable) fieldAccess.getReferenceExpr();
+ if (referenceExpr.id.equals(correlationId)) { // other correlation variables are left untouched
+ int oldIndex = fieldAccess.getField().getIndex();
+ RexNode newCorrel =
+ rexBuilder.makeCorrel(newCorrelRowType, referenceExpr.id);
+ int newIndex = mapping.getTarget(oldIndex); // position of the field after trimming
+ return rexBuilder.makeFieldAccess(newCorrel, newIndex);
+ }
+ }
+ return super.visitFieldAccess(fieldAccess);
+ }
+ }
+
protected Mapping createMapping(ImmutableBitSet fieldsUsed, int fieldCount) {
final Mapping mapping =
Mappings.create(
diff --git
a/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
b/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
new file mode 100644
index 0000000000..f2c64cdd27
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql2rel;
+
+import org.apache.calcite.rel.RelHomogeneousShuttle;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rex.RexShuttle;
+
+/**
+ * A relational shuttle that applies a single {@link RexShuttle} to every
+ * {@link RelNode} it visits.
+ *
+ * <p>Each visit first delegates to the superclass {@code visit} and then
+ * calls {@code accept(rexVisitor)} on the node that comes back, returning
+ * the result of that call. Presumably the superclass step is what walks the
+ * node's inputs; confirm against {@link RelHomogeneousShuttle}.
+ *
+ * <p>The constructor has no access modifier, so although the class is
+ * public it can only be instantiated from within its own package.
+ */
+public class RexRewritingRelShuttle extends RelHomogeneousShuttle {
+ private final RexShuttle rexVisitor; // applied to each node after the superclass visit
+
+ RexRewritingRelShuttle(RexShuttle rexVisitor) {
+ this.rexVisitor = rexVisitor;
+ }
+
+ @Override public RelNode visit(RelNode other) {
+ RelNode next = super.visit(other);
+ return next.accept(rexVisitor);
+ }
+}
diff --git a/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
b/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
index 8b000032ac..82a3c74039 100644
--- a/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
+++ b/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
@@ -1560,7 +1560,7 @@ public abstract class Mappings {
@Override public Mapping inverse() {
return new OverridingTargetMapping(
- (TargetMapping) parent.inverse(),
+ parent.inverse(),
target,
source);
}
@@ -1861,4 +1861,35 @@ public abstract class Mappings {
parent.set(target, source);
}
}
+
+ /**
+ * Concatenates multiple mappings.
+ *
+ * <p>The result is a new {@code INVERSE_SURJECTION} mapping whose source
+ * count and target count are the sums of the inputs' source counts and
+ * target counts. The pairs of each input mapping are copied in list order,
+ * with sources shifted by the total source count of the preceding mappings
+ * and targets shifted by their total target count.
+ *
+ * <p>Example, concatenating two mappings that each have two pairs, a
+ * source count of 100 and a target count of 2:
+ *
+ * <pre>
+ * [ 1:0, 2:1 ] // sourceCount:100, targetCount:2
+ * [ 1:0, 2:1 ] // sourceCount:100, targetCount:2
+ * output:
+ * [ 1:0, 2:1, 101:2, 102:3 ] // sourceCount:200, targetCount:4
+ * </pre>
+ *
+ * @param inputMappings mappings to concatenate, in order
+ * @return the combined mapping
+ */
+ public static Mapping concatenateMappings(List<Mapping> inputMappings) {
+ int fieldCount = 0; // total source count
+ int newFieldCount = 0; // total target count
+ for (Mapping inputMapping : inputMappings) {
+ fieldCount += inputMapping.getSourceCount();
+ newFieldCount += inputMapping.getTargetCount();
+ }
+ Mapping mapping =
+ create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
+ int offset = 0; // source counts consumed by earlier mappings
+ int newOffset = 0; // target counts consumed by earlier mappings
+ for (Mapping inputMapping : inputMappings) {
+ for (IntPair pair : inputMapping) {
+ mapping.set(pair.source + offset, pair.target + newOffset);
+ }
+ offset += inputMapping.getSourceCount();
+ newOffset += inputMapping.getTargetCount();
+ }
+ return mapping;
+ }
}
diff --git
a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
index 2e3b509f5a..ebe65d93d0 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
@@ -32,17 +32,22 @@ import org.apache.calcite.rel.hint.HintPredicates;
import org.apache.calcite.rel.hint.HintStrategyTable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.rules.CoreRules;
+import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.Programs;
import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.Holder;
import com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.jupiter.api.Test;
+import java.util.Collections;
import java.util.List;
import static org.apache.calcite.test.Matchers.hasTree;
@@ -544,4 +549,99 @@ class RelFieldTrimmerTest {
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(trimmed, hasTree(expected));
}
+
+ /**
+ * Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6715">[CALCITE-6715]
+ * Enhance RelFieldTrimmer to trim LogicalCorrelate nodes</a>.
+ *
+ * <p>Builds a LEFT correlate whose left side is EMP with one extra computed
+ * column and whose right side is an Uncollect over an ARRAY built from
+ * {@code $cor0.DEPTNO}, then aggregates MAX(EMPNO) grouped by ENAME. After
+ * trimming, the left projection is expected to keep only EMPNO, ENAME and
+ * DEPTNO, the correlate's required columns to move from {7} to {2}, and the
+ * right side to print unchanged.
+ */
+ @Test void testLogicalCorrelateFieldTrimmer() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
+ RelNode root = builder.scan("EMP")
+ .projectPlus(builder.call(SqlStdOperatorTable.PLUS, builder.field(0),
builder.field(0)))
+ .variable(v::set)
+ .values(new String[] {"dummy"}, true)
+ .project(
+ builder.call(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+ builder.field(v.get(), "DEPTNO"), builder.field(v.get(),
"DEPTNO")))
+ .uncollect(Collections.emptyList(), false)
+ .correlate(JoinRelType.LEFT, v.get().id, builder.field(2, 0, "DEPTNO"))
+ .aggregate(builder.groupKey("ENAME"),
builder.max(builder.field("EMPNO")))
+ .build();
+
+ String origTree = ""
+ + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+ + " LogicalCorrelate(correlation=[$cor0], joinType=[left],
requiredColumns=[{7}])\n"
+ + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3],
HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[+($0, $0)])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n"
+ + " Uncollect\n"
+ + " LogicalProject($f0=[ARRAY($cor0.DEPTNO, $cor0.DEPTNO)])\n"
+ + " LogicalValues(tuples=[[{ true }]])\n";
+ assertThat(root, hasTree(origTree));
+
+ final RelFieldTrimmer fieldTrimmer = new RelFieldTrimmer(null, builder);
+ final RelNode trimmed = fieldTrimmer.trim(root);
+ final String expected = ""
+ + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+ + " LogicalCorrelate(correlation=[$cor0], joinType=[left],
requiredColumns=[{2}])\n"
+ + " LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n"
+ + " Uncollect\n"
+ + " LogicalProject($f0=[ARRAY($cor0.DEPTNO, $cor0.DEPTNO)])\n"
+ + " LogicalValues(tuples=[[{ true }]])\n";
+
+ assertThat(trimmed, hasTree(expected));
+ }
+
+ /**
+ * Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6715">[CALCITE-6715]
+ * Enhance RelFieldTrimmer to trim LogicalCorrelate nodes</a>.
+ *
+ * <p>Builds a LEFT correlate whose left side is EMP with one extra computed
+ * column and whose right side is a filter on {@code $cor0.DEPTNO} over DEPT
+ * with one extra computed column, then aggregates MAX(EMPNO) grouped by
+ * ENAME. After trimming, both sides are expected to shrink: the left
+ * projection to EMPNO, ENAME and DEPTNO, the right projection to DEPTNO
+ * only, with the correlate's required columns moving from {7} to {2}.
+ */
+ @Test void testLogicalCorrelateFieldTrimmer2() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
+ RelNode root = builder.scan("EMP")
+ .projectPlus(builder.call(SqlStdOperatorTable.PLUS, builder.field(0),
builder.field(0)))
+ .variable(v::set)
+ .scan("DEPT")
+ .projectPlus(
+ builder.call(SqlStdOperatorTable.PLUS,
+ builder.field(v.get(), "DEPTNO"), builder.field(v.get(),
"DEPTNO")))
+ .filter(
+ builder.equals(builder.field(0),
+ builder.call(
+ SqlStdOperatorTable.PLUS,
+ builder.literal(10),
+ builder.field(v.get(), "DEPTNO"))))
+ .correlate(JoinRelType.LEFT, v.get().id, builder.field(2, 0, "DEPTNO"))
+ .aggregate(builder.groupKey("ENAME"),
builder.max(builder.field("EMPNO")))
+ .build();
+
+ String origTree = ""
+ + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+ + " LogicalCorrelate(correlation=[$cor0], joinType=[left],
requiredColumns=[{7}])\n"
+ + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3],
HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[+($0, $0)])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n"
+ + " LogicalFilter(condition=[=($0, +(10, $cor0.DEPTNO))])\n"
+ + " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2],
$f3=[+($cor0.DEPTNO, $cor0.DEPTNO)])\n"
+ + " LogicalTableScan(table=[[scott, DEPT]])\n";
+ assertThat(root, hasTree(origTree));
+
+ final RelFieldTrimmer fieldTrimmer = new RelFieldTrimmer(null, builder);
+ final RelNode trimmed = fieldTrimmer.trim(root);
+ final String expected = ""
+ + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+ + " LogicalCorrelate(correlation=[$cor0], joinType=[left],
requiredColumns=[{2}])\n"
+ + " LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n"
+ + " LogicalFilter(condition=[=($0, +(10, $cor0.DEPTNO))])\n"
+ + " LogicalProject(DEPTNO=[$0])\n"
+ + " LogicalTableScan(table=[[scott, DEPT]])\n";
+
+ assertThat(trimmed, hasTree(expected));
+ }
+
+
}
diff --git a/core/src/test/resources/sql/sub-query.iq
b/core/src/test/resources/sql/sub-query.iq
index dc06578779..45edb3e4c9 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -481,7 +481,7 @@ EnumerableCalc(expr#0..2=[{inputs}], proj#0..1=[{exprs}])
EnumerableCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{0}])
EnumerableValues(tuples=[[{ 1, 2 }]])
EnumerableAggregate(group=[{0}])
- EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true],
expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.A], expr#12=[=($t9,
$t11)], i=[$t8], $condition=[$t12])
+ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true],
expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.EXPR$0],
expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12])
EnumerableTableScan(table=[[scott, EMP]])
!plan