This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 20cc78e5f2 [CALCITE-6715] Enhance RelFieldTrimmer to trim LogicalCorrelate nodes
20cc78e5f2 is described below

commit 20cc78e5f2424aed6a364afe4fb76e30cb96fef1
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Wed Dec 4 15:07:12 2024 +0000

    [CALCITE-6715] Enhance RelFieldTrimmer to trim LogicalCorrelate nodes
---
 .../apache/calcite/sql2rel/RelFieldTrimmer.java    | 107 +++++++++++++++++++++
 .../calcite/sql2rel/RexRewritingRelShuttle.java    |  37 +++++++
 .../org/apache/calcite/util/mapping/Mappings.java  |  33 ++++++-
 .../calcite/sql2rel/RelFieldTrimmerTest.java       | 100 +++++++++++++++++++
 core/src/test/resources/sql/sub-query.iq           |   2 +-
 5 files changed, 277 insertions(+), 2 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index 46a2b085a1..69e9ae90bb 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -41,6 +41,7 @@ import org.apache.calcite.rel.core.Snapshot;
 import org.apache.calcite.rel.core.Sort;
 import org.apache.calcite.rel.core.SortExchange;
 import org.apache.calcite.rel.core.TableScan;
+import org.apache.calcite.rel.logical.LogicalCorrelate;
 import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
 import org.apache.calcite.rel.logical.LogicalTableModify;
 import org.apache.calcite.rel.logical.LogicalValues;
@@ -56,6 +57,7 @@ import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexPermuteInputsShuttle;
 import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexShuttle;
 import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.rex.RexVisitor;
@@ -1336,6 +1338,111 @@ public class RelFieldTrimmer implements ReflectiveVisitor {
     return result(newSnapshot, inputMapping, snapshot);
   }
 
+  /**
+   * Trims {@link LogicalCorrelate} nodes.
+   *
+   * <p>When {@code extraFields} is non-empty this method delegates to the
+   * generic {@code trimFields(RelNode, ...)} overload. Otherwise the
+   * correlate's required columns are added to {@code fieldsUsed}, and each
+   * input is trimmed via {@code dispatchTrimFields} using the slice of
+   * {@code fieldsUsed} that falls within that input's range of field
+   * ordinals, shifted so that it is relative to the input.
+   *
+   * <p>If no input is replaced by a different node, the original correlate is
+   * returned with an identity mapping. Otherwise the per-input mappings are
+   * concatenated, field accesses on this correlate's correlation variable
+   * inside the new right input are rewritten against the new left row type,
+   * and a copy of the correlate is created whose required columns are
+   * permuted through the concatenated mapping.
+   *
+   * @param correlate Correlate to trim; its inputs at positions 0 and 1 are
+   *     used as the left and right side of the new correlate
+   * @param fieldsUsed Bit set of the field ordinals that are in use
+   * @param extraFields Extra fields; any non-empty set disables the
+   *     correlate-specific trimming
+   * @return The original or the newly created correlate, together with the
+   *     field mapping that applies to it
+   */
+  public TrimResult trimFields(LogicalCorrelate correlate,
+      ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
+    if (!extraFields.isEmpty()) {
+      // bail out with generic trim
+      return trimFields((RelNode) correlate, fieldsUsed, extraFields);
+    }
+
+    fieldsUsed = fieldsUsed.union(correlate.getRequiredColumns());
+
+    List<RelNode> newInputs = new ArrayList<>();
+    List<Mapping> inputMappings = new ArrayList<>();
+    int changeCount = 0;
+    int offset = 0;
+    for (RelNode input : correlate.getInputs()) {
+      final RelDataType inputRowType = input.getRowType();
+      final int inputFieldCount = inputRowType.getFieldCount();
+
+      ImmutableBitSet currentInputFieldsUsed = fieldsUsed
+          .intersect(ImmutableBitSet.range(offset, offset + inputFieldCount))
+          .shift(-offset);
+
+      TrimResult trimResult =
+          dispatchTrimFields(input, currentInputFieldsUsed, extraFields);
+
+      newInputs.add(trimResult.left);
+      inputMappings.add(trimResult.right);
+
+      offset += inputFieldCount;
+
+      if (trimResult.left != input) {
+        changeCount++;
+      }
+    }
+
+    if (changeCount == 0) {
+      return result(correlate,
+          Mappings.createIdentity(correlate.getRowType().getFieldCount()));
+    }
+
+    Mapping mapping = Mappings.concatenateMappings(inputMappings);
+    RexBuilder rexBuilder = relBuilder.getRexBuilder();
+
+    RelNode newLeft = newInputs.get(0);
+    RexCorrelVariableMapShuttle rexVisitor =
+        new RexCorrelVariableMapShuttle(correlate.getCorrelationId(),
+            newLeft.getRowType(), mapping, rexBuilder);
+    RelNode newRight =
+        newInputs.get(1).accept(new RexRewritingRelShuttle(rexVisitor));
+    final LogicalCorrelate newCorrelate =
+        correlate
+            .copy(correlate.getTraitSet(),
+                newLeft,
+                newRight,
+                correlate.getCorrelationId(),
+                correlate.getRequiredColumns().permute(mapping),
+                correlate.getJoinType());
+
+    return result(newCorrelate, mapping);
+  }
+
+  /**
+   * Updates correlate references in {@link RexNode} expressions.
+   *
+   * <p>Every field access whose reference expression is a
+   * {@link RexCorrelVariable} carrying the configured correlation id is
+   * replaced by a field access on a new correlation variable that has the new
+   * row type, using the field index that the mapping gives for the old index.
+   * All other field accesses are left to {@link RexShuttle}.
+   */
+  static class RexCorrelVariableMapShuttle extends RexShuttle {
+    private final CorrelationId correlationId;
+    private final Mapping mapping;
+    private final RelDataType newCorrelRowType;
+    private final RexBuilder rexBuilder;
+
+
+    /**
+     * Constructs a RexCorrelVariableMapShuttle.
+     *
+     * @param correlationId The ID of the correlation variable to update.
+     * @param newCorrelRowType The new row type for the correlate reference.
+     * @param mapping Mapping to transform field indices.
+     * @param rexBuilder A builder for constructing new RexNodes.
+     */
+    RexCorrelVariableMapShuttle(final CorrelationId correlationId,
+        RelDataType newCorrelRowType, Mapping mapping, RexBuilder rexBuilder) {
+      this.correlationId = correlationId;
+      this.newCorrelRowType = newCorrelRowType;
+      this.mapping = mapping;
+      this.rexBuilder = rexBuilder;
+    }
+
+    @Override public RexNode visitFieldAccess(final RexFieldAccess fieldAccess) {
+      if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+        RexCorrelVariable referenceExpr =
+            (RexCorrelVariable) fieldAccess.getReferenceExpr();
+        if (referenceExpr.id.equals(correlationId)) {
+          int oldIndex = fieldAccess.getField().getIndex();
+          RexNode newCorrel =
+              rexBuilder.makeCorrel(newCorrelRowType, referenceExpr.id);
+          int newIndex = mapping.getTarget(oldIndex);
+          return rexBuilder.makeFieldAccess(newCorrel, newIndex);
+        }
+      }
+      return super.visitFieldAccess(fieldAccess);
+    }
+  }
+
   protected Mapping createMapping(ImmutableBitSet fieldsUsed, int fieldCount) {
     final Mapping mapping =
         Mappings.create(
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java b/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
new file mode 100644
index 0000000000..f2c64cdd27
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql2rel;
+
+import org.apache.calcite.rel.RelHomogeneousShuttle;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rex.RexShuttle;
+
+/**
+ * Dispatches a {@link RexShuttle} for all {@link RelNode}-s.
+ *
+ * <p>For each node passed to {@link #visit(RelNode)}, the superclass visit
+ * is performed first; the {@link RexShuttle} is then applied to the node that
+ * visit returned, by calling {@code accept} on it.
+ *
+ * <p>The constructor is package-private, so instances can only be created
+ * from within {@code org.apache.calcite.sql2rel}.
+ */
+public class RexRewritingRelShuttle extends RelHomogeneousShuttle {
+  private final RexShuttle rexVisitor;
+
+  RexRewritingRelShuttle(RexShuttle rexVisitor) {
+    this.rexVisitor = rexVisitor;
+  }
+
+  @Override public RelNode visit(RelNode other) {
+    RelNode next = super.visit(other);
+    return next.accept(rexVisitor);
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java b/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
index 8b000032ac..82a3c74039 100644
--- a/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
+++ b/core/src/main/java/org/apache/calcite/util/mapping/Mappings.java
@@ -1560,7 +1560,7 @@ public abstract class Mappings {
 
     @Override public Mapping inverse() {
       return new OverridingTargetMapping(
-          (TargetMapping) parent.inverse(),
+          parent.inverse(),
           target,
           source);
     }
@@ -1861,4 +1861,35 @@ public abstract class Mappings {
       parent.set(target, source);
     }
   }
+
+  /**
+   * Concatenates multiple mappings.
+   *
+   * <p>The result has a source count equal to the sum of the inputs' source
+   * counts and a target count equal to the sum of their target counts. Each
+   * (source, target) pair of an input mapping is copied with its source
+   * shifted by the total source count of the preceding mappings and its
+   * target shifted by their total target count.
+   *
+   * <p>For example, concatenating two mappings that each map source 1 to
+   * target 0 and source 2 to target 1, out of 100 sources:
+   *
+   * <pre>
+   * [ 1:0, 2:1] // sourceCount:100
+   * [ 1:0, 2:1] // sourceCount:100
+   * output:
+   * [ 1:0, 2:1, 101:2, 102:3 ] ; sourceCount:200
+   * </pre>
+   *
+   * @param inputMappings Mappings to concatenate, in order
+   * @return A new mapping of type {@link MappingType#INVERSE_SURJECTION} that
+   *     contains the shifted pairs of all input mappings
+   */
+  public static Mapping concatenateMappings(List<Mapping> inputMappings) {
+    int fieldCount = 0;
+    int newFieldCount = 0;
+    for (Mapping inputMapping : inputMappings) {
+      fieldCount += inputMapping.getSourceCount();
+      newFieldCount += inputMapping.getTargetCount();
+    }
+    Mapping mapping =
+        create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
+    int offset = 0;
+    int newOffset = 0;
+    for (Mapping inputMapping : inputMappings) {
+      for (IntPair pair : inputMapping) {
+        mapping.set(pair.source + offset, pair.target + newOffset);
+      }
+      offset += inputMapping.getSourceCount();
+      newOffset += inputMapping.getTargetCount();
+    }
+    return mapping;
+  }
 }
diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
index 2e3b509f5a..ebe65d93d0 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
@@ -32,17 +32,22 @@ import org.apache.calcite.rel.hint.HintPredicates;
 import org.apache.calcite.rel.hint.HintStrategyTable;
 import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.rel.rules.CoreRules;
+import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.test.CalciteAssert;
 import org.apache.calcite.tools.Frameworks;
 import org.apache.calcite.tools.Programs;
 import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.Holder;
 
 import com.google.common.collect.Lists;
 
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.junit.jupiter.api.Test;
 
+import java.util.Collections;
 import java.util.List;
 
 import static org.apache.calcite.test.Matchers.hasTree;
@@ -544,4 +549,99 @@ class RelFieldTrimmerTest {
         + "        LogicalTableScan(table=[[scott, EMP]])\n";
     assertThat(trimmed, hasTree(expected));
   }
+
+  /**
+   * Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6715">[CALCITE-6715]
+   * Enhance RelFieldTrimmer to trim LogicalCorrelate nodes</a>.
+   *
+   * <p>Builds a left correlate whose left side is EMP with one extra computed
+   * column and whose right side is an Uncollect over an array built from the
+   * correlation variable's DEPTNO, topped by an aggregate on ENAME and EMPNO.
+   * After trimming, the left projection is expected to keep only EMPNO, ENAME
+   * and DEPTNO, and the required column is expected to move from 7 to 2.
+   */
+  @Test void testLogicalCorrelateFieldTrimmer() {
+    final RelBuilder builder = RelBuilder.create(config().build());
+    final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
+    RelNode root = builder.scan("EMP")
+        .projectPlus(builder.call(SqlStdOperatorTable.PLUS, builder.field(0), builder.field(0)))
+        .variable(v::set)
+        .values(new String[] {"dummy"}, true)
+        .project(
+            builder.call(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            builder.field(v.get(), "DEPTNO"), builder.field(v.get(), "DEPTNO")))
+        .uncollect(Collections.emptyList(), false)
+        .correlate(JoinRelType.LEFT, v.get().id, builder.field(2, 0, "DEPTNO"))
+        .aggregate(builder.groupKey("ENAME"), builder.max(builder.field("EMPNO")))
+        .build();
+
+    String origTree = ""
+        + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n"
+        + "    LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[+($0, $0)])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    Uncollect\n"
+        + "      LogicalProject($f0=[ARRAY($cor0.DEPTNO, $cor0.DEPTNO)])\n"
+        + "        LogicalValues(tuples=[[{ true }]])\n";
+    assertThat(root, hasTree(origTree));
+
+    final RelFieldTrimmer fieldTrimmer = new RelFieldTrimmer(null, builder);
+    final RelNode trimmed = fieldTrimmer.trim(root);
+    final String expected = ""
+        + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{2}])\n"
+        + "    LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    Uncollect\n"
+        + "      LogicalProject($f0=[ARRAY($cor0.DEPTNO, $cor0.DEPTNO)])\n"
+        + "        LogicalValues(tuples=[[{ true }]])\n";
+
+    assertThat(trimmed, hasTree(expected));
+  }
+
+  /**
+   * Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6715">[CALCITE-6715]
+   * Enhance RelFieldTrimmer to trim LogicalCorrelate nodes</a>.
+   *
+   * <p>Builds a left correlate whose left side is EMP with one extra computed
+   * column and whose right side is a filter over a projection of DEPT, both
+   * of which reference the correlation variable's DEPTNO. After trimming, the
+   * left projection is expected to keep only EMPNO, ENAME and DEPTNO, the
+   * right projection only DEPTNO, and the required column is expected to move
+   * from 7 to 2.
+   */
+  @Test void testLogicalCorrelateFieldTrimmer2() {
+    final RelBuilder builder = RelBuilder.create(config().build());
+    final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
+    RelNode root = builder.scan("EMP")
+        .projectPlus(builder.call(SqlStdOperatorTable.PLUS, builder.field(0), builder.field(0)))
+        .variable(v::set)
+        .scan("DEPT")
+        .projectPlus(
+            builder.call(SqlStdOperatorTable.PLUS,
+            builder.field(v.get(), "DEPTNO"), builder.field(v.get(), "DEPTNO")))
+        .filter(
+            builder.equals(builder.field(0),
+                builder.call(
+                    SqlStdOperatorTable.PLUS,
+                    builder.literal(10),
+                    builder.field(v.get(), "DEPTNO"))))
+        .correlate(JoinRelType.LEFT, v.get().id, builder.field(2, 0, "DEPTNO"))
+        .aggregate(builder.groupKey("ENAME"), builder.max(builder.field("EMPNO")))
+        .build();
+
+    String origTree = ""
+        + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n"
+        + "    LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[+($0, $0)])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    LogicalFilter(condition=[=($0, +(10, $cor0.DEPTNO))])\n"
+        + "      LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], $f3=[+($cor0.DEPTNO, $cor0.DEPTNO)])\n"
+        + "        LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(root, hasTree(origTree));
+
+    final RelFieldTrimmer fieldTrimmer = new RelFieldTrimmer(null, builder);
+    final RelNode trimmed = fieldTrimmer.trim(root);
+    final String expected = ""
+        + "LogicalAggregate(group=[{1}], agg#0=[MAX($0)])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{2}])\n"
+        + "    LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    LogicalFilter(condition=[=($0, +(10, $cor0.DEPTNO))])\n"
+        + "      LogicalProject(DEPTNO=[$0])\n"
+        + "        LogicalTableScan(table=[[scott, DEPT]])\n";
+
+    assertThat(trimmed, hasTree(expected));
+  }
+
+
 }
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index dc06578779..45edb3e4c9 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -481,7 +481,7 @@ EnumerableCalc(expr#0..2=[{inputs}], proj#0..1=[{exprs}])
   EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
     EnumerableValues(tuples=[[{ 1, 2 }]])
     EnumerableAggregate(group=[{0}])
-      EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true], expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.A], expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12])
+      EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true], expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.EXPR$0], expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12])
         EnumerableTableScan(table=[[scott, EMP]])
 !plan
 

Reply via email to