This is an automated email from the ASF dual-hosted git repository.

suibianwanwan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 2fb437b937 [CALCITE-7064] Test introduced in [CALCITE-7009] breaks the build for main
2fb437b937 is described below

commit 2fb437b93774d3d69df66547715b374d9281fdbb
Author: Konstantin Orlov <[email protected]>
AuthorDate: Tue Jun 24 15:35:16 2025 +0300

    [CALCITE-7064] Test introduced in [CALCITE-7009] breaks the build for main
---
 .../java/org/apache/calcite/plan/RelOptUtil.java   |  93 ++++++++---
 .../apache/calcite/sql2rel/RelFieldTrimmer.java    | 185 +++------------------
 .../calcite/sql2rel/RexRewritingRelShuttle.java    |   2 +-
 .../org/apache/calcite/plan/RelOptUtilTest.java    | 130 +++++++++++++++
 .../calcite/sql2rel/RelFieldTrimmerTest.java       |   2 +-
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  16 +-
 core/src/test/resources/sql/scalar.iq              |   2 -
 7 files changed, 233 insertions(+), 197 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java 
b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
index 88f083ea2a..2742cc45d0 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
@@ -95,6 +95,7 @@
 import org.apache.calcite.sql.type.MultisetSqlType;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeUtil;
+import org.apache.calcite.sql2rel.RexRewritingRelShuttle;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
@@ -2692,6 +2693,33 @@ public static RexNode andJoinFilters(
     return left;
   }
 
+  /**
+   * Updates instances of the correlation variable with the given {@link CorrelationId} in a
+   * subquery. That is, replaces the referenced row type with the new one, and remaps field
+   * indexes according to the provided {@link Mapping inputMapping}.
+   *
+   * @param rexBuilder A builder for constructing new RexNodes.
+   * @param node A subquery expression to update.
+   * @param correlationId The ID of the correlation variable to update.
+   * @param newRowType The new row type for the correlate reference.
+   * @param inputMapping Mapping to transform field indices.
+   * @return An updated subquery expression.
+   */
+  public static RexSubQuery remapCorrelatesInSuqQuery(
+      RexBuilder rexBuilder,
+      RexSubQuery node,
+      CorrelationId correlationId,
+      RelDataType newRowType,
+      Mapping inputMapping) {
+    RelNode subQuery = node.rel;
+    RexCorrelVariableMapShuttle rexVisitor =
+        new RexCorrelVariableMapShuttle(correlationId, newRowType, 
inputMapping, rexBuilder);
+    RelNode newSubQuery =
+        subQuery.accept(new RexRewritingRelShuttle(rexVisitor));
+
+    return node.clone(newSubQuery);
+  }
+
   /** Decomposes the WHERE clause of a view into predicates that constraint
    * a column to a particular value.
    *
@@ -4552,34 +4580,51 @@ private void acceptFields(final List<RelDataTypeField> 
fields) {
     }
   }
 
-  /** Extension of {@link RelOptUtil.InputFinder} with optional subquery 
lookup. */
-  public static class SubQueryAwareInputFinder extends RelOptUtil.InputFinder {
-    boolean visitSubQuery;
+  /**
+   * Updates correlate references in {@link RexNode} expressions.
+   */
+  public static class RexCorrelVariableMapShuttle extends RexShuttle {
+    private final CorrelationId correlationId;
+    private final Mapping mapping;
+    private final RelDataType newCorrelRowType;
+    private final RexBuilder rexBuilder;
 
-    public SubQueryAwareInputFinder(@Nullable Set<RelDataTypeField> 
extraFields,
-        boolean visitSubQuery) {
-      super(extraFields, ImmutableBitSet.builder());
-      this.visitSubQuery = visitSubQuery;
-    }
 
-    @Override public Void visitSubQuery(RexSubQuery subQuery) {
-      if (visitSubQuery && subQuery.getKind() == SqlKind.SCALAR_QUERY) {
-        subQuery.rel.accept(new RelHomogeneousShuttle() {
-          @Override public RelNode visit(LogicalProject project) {
-            project.getProjects().forEach(r -> 
r.accept(SubQueryAwareInputFinder.this));
-            return super.visit(project);
-          }
-
-          @Override public RelNode visit(LogicalFilter filter) {
-            filter.getCondition().accept(SubQueryAwareInputFinder.this);
-            return super.visit(filter);
-          }
-        });
+    /**
+     * Constructs a RexCorrelVariableMapShuttle.
+     *
+     * @param correlationId The ID of the correlation variable to update.
+     * @param newCorrelRowType The new row type for the correlate reference.
+     * @param mapping Mapping to transform field indices.
+     * @param rexBuilder A builder for constructing new RexNodes.
+     */
+    public RexCorrelVariableMapShuttle(final CorrelationId correlationId,
+        RelDataType newCorrelRowType, Mapping mapping, RexBuilder rexBuilder) {
+      this.correlationId = correlationId;
+      this.newCorrelRowType = newCorrelRowType;
+      this.mapping = mapping;
+      this.rexBuilder = rexBuilder;
+    }
 
-        return null;
-      } else {
-        return super.visitSubQuery(subQuery);
+    @Override public RexNode visitFieldAccess(final RexFieldAccess 
fieldAccess) {
+      if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+        RexCorrelVariable referenceExpr =
+            (RexCorrelVariable) fieldAccess.getReferenceExpr();
+        if (referenceExpr.id.equals(correlationId)) {
+          int oldIndex = fieldAccess.getField().getIndex();
+          int newIndex = mapping.getTarget(oldIndex);
+          RexNode newCorrel =
+              rexBuilder.makeCorrel(newCorrelRowType, referenceExpr.id);
+          return rexBuilder.makeFieldAccess(newCorrel, newIndex);
+        }
       }
+      return super.visitFieldAccess(fieldAccess);
+    }
+
+    @Override public RexNode visitSubQuery(RexSubQuery subQuery) {
+      subQuery = (RexSubQuery) super.visitSubQuery(subQuery);
+      return remapCorrelatesInSuqQuery(
+          rexBuilder, subQuery, correlationId, newCorrelRowType, mapping);
     }
   }
 
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java 
b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index 2dee1af505..be08ee7ec2 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -19,11 +19,11 @@
 import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelOptUtil.RexCorrelVariableMapShuttle;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelFieldCollation;
-import org.apache.calcite.rel.RelHomogeneousShuttle;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
@@ -58,11 +58,9 @@
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexPermuteInputsShuttle;
 import org.apache.calcite.rex.RexProgram;
-import org.apache.calcite.rex.RexShuttle;
 import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.rex.RexVisitor;
-import org.apache.calcite.rex.RexVisitorImpl;
 import org.apache.calcite.sql.SqlExplainFormat;
 import org.apache.calcite.sql.SqlExplainLevel;
 import org.apache.calcite.sql.SqlKind;
@@ -80,8 +78,6 @@
 import org.apache.calcite.util.mapping.Mappings;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
 
 import org.checkerframework.checker.nullness.qual.Nullable;
 
@@ -477,77 +473,6 @@ public TrimResult trimFields(
     return result(newCalc, mapping, calc);
   }
 
-  /**
-   * Shuttle that finds all {@link TableScan}s inside a given {@link RelNode}.
-   */
-  private static class TableScanCollector extends RelHomogeneousShuttle {
-    private ImmutableSet.Builder<List<String>> builder = 
ImmutableSet.builder();
-
-    /** Qualified names. */
-    Set<List<String>> tables() {
-      return builder.build();
-    }
-
-    @Override public RelNode visit(TableScan scan) {
-      builder.add(scan.getTable().getQualifiedName());
-      return super.visit(scan);
-    }
-  }
-
-  /**
-   * Shuttle that finds all {@link TableScan}`s inside a given {@link RexNode}.
-   */
-  private static class InputTablesVisitor extends RexVisitorImpl<Void> {
-    private ImmutableSet.Builder<List<String>> builder = 
ImmutableSet.builder();
-
-    protected InputTablesVisitor() {
-      super(false);
-    }
-
-    /** Qualified names. */
-    Set<List<String>> tables() {
-      return builder.build();
-    }
-
-    @Override public Void visitSubQuery(RexSubQuery subQuery) {
-      if (subQuery.getKind() == SqlKind.SCALAR_QUERY) {
-        subQuery.rel.accept(new RelHomogeneousShuttle() {
-          @Override public RelNode visit(TableScan scan) {
-            builder.add(scan.getTable().getQualifiedName());
-            return super.visit(scan);
-          }
-        });
-      }
-      return null;
-    }
-  }
-
-  private boolean inputContainsSubQueryTables(Project project, RelNode input) {
-    InputTablesVisitor inputSubQueryTablesCollector = new InputTablesVisitor();
-
-    RexUtil.apply(inputSubQueryTablesCollector, project.getProjects(), null);
-
-    Set<List<String>> subQueryTables = inputSubQueryTablesCollector.tables();
-
-    assert subQueryTables.isEmpty() || subQueryTables.size() == 1
-        : "unexpected different tables in subquery: " + subQueryTables;
-
-    TableScanCollector inputTablesCollector = new TableScanCollector();
-    input.accept(inputTablesCollector);
-
-    Set<List<String>> inputTables = inputTablesCollector.tables();
-    // Check for input and subquery tables intersection.
-    if (!subQueryTables.isEmpty()) {
-      for (List<String> t : inputTables) {
-        if (t.equals(Iterables.getOnlyElement(subQueryTables))) {
-          return true;
-        }
-      }
-    }
-
-    return false;
-  }
-
   /**
    * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
    * {@link org.apache.calcite.rel.logical.LogicalProject}.
@@ -563,24 +488,18 @@ public TrimResult trimFields(
     // Which fields are required from the input?
     final Set<RelDataTypeField> inputExtraFields =
         new LinkedHashSet<>(extraFields);
-
-    // Collect all the SubQueries in the projection list.
-    List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
-    // Get all the correlationIds present in the SubQueries
-    Set<CorrelationId> correlationIds = 
RelOptUtil.getVariablesUsed(subQueries);
-    // Subquery lookup is required.
-    boolean subQueryLookUp =
-        !correlationIds.isEmpty() && inputContainsSubQueryTables(project, 
input);
-
     RelOptUtil.InputFinder inputFinder =
-        new RelOptUtil.SubQueryAwareInputFinder(inputExtraFields, 
subQueryLookUp);
-
+        new RelOptUtil.InputFinder(inputExtraFields);
     for (Ord<RexNode> ord : Ord.zip(project.getProjects())) {
       if (fieldsUsed.get(ord.i)) {
         ord.e.accept(inputFinder);
       }
     }
 
+    // Collect all the SubQueries in the projection list.
+    List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
+    // Get all the correlationIds present in the SubQueries
+    Set<CorrelationId> correlationIds = 
RelOptUtil.getVariablesUsed(subQueries);
     ImmutableBitSet requiredColumns = ImmutableBitSet.of();
     if (!correlationIds.isEmpty()) {
       assert correlationIds.size() == 1;
@@ -616,9 +535,22 @@ public TrimResult trimFields(
 
     // Build new project expressions, and populate the mapping.
     final List<RexNode> newProjects = new ArrayList<>();
-    final RexVisitor<RexNode> shuttle =
-        new RexPermuteInputsShuttle(
-            inputMapping, newInput);
+    final RexVisitor<RexNode> shuttle;
+
+    if (!correlationIds.isEmpty()) {
+      assert correlationIds.size() == 1;
+      shuttle = new RexPermuteInputsShuttle(inputMapping, newInput) {
+        @Override public RexNode visitSubQuery(RexSubQuery subQuery) {
+          subQuery = (RexSubQuery) super.visitSubQuery(subQuery);
+
+          return 
RelOptUtil.remapCorrelatesInSuqQuery(relBuilder.getRexBuilder(),
+            subQuery, correlationIds.iterator().next(), newInput.getRowType(), 
inputMapping);
+        }
+      };
+    } else {
+      shuttle = new RexPermuteInputsShuttle(inputMapping, newInput);
+    }
+
     final Mapping mapping =
         Mappings.create(
             MappingType.INVERSE_SURJECTION,
@@ -628,14 +560,6 @@ public TrimResult trimFields(
       if (fieldsUsed.get(ord.i)) {
         mapping.set(ord.i, newProjects.size());
         RexNode newProjectExpr = ord.e.accept(shuttle);
-        // Subquery need to be remapped
-        if (newProjectExpr instanceof RexSubQuery
-            && newProjectExpr.getKind() == SqlKind.SCALAR_QUERY
-            && !correlationIds.isEmpty()) {
-          newProjectExpr =
-              changeCorrelateReferences((RexSubQuery) newProjectExpr,
-                  Iterables.getOnlyElement(correlationIds), 
newInput.getRowType(), inputMapping);
-        }
         newProjects.add(newProjectExpr);
       }
     }
@@ -645,29 +569,11 @@ public TrimResult trimFields(
             mapping);
 
     relBuilder.push(newInput);
-    relBuilder.project(newProjects, newRowType.getFieldNames());
+    relBuilder.project(newProjects, newRowType.getFieldNames(), false, 
correlationIds);
     final RelNode newProject = relBuilder.build();
     return result(newProject, mapping, project);
   }
 
-  private RexNode changeCorrelateReferences(
-      RexSubQuery node,
-      CorrelationId corrId,
-      RelDataType rowType,
-      Mapping inputMapping) {
-    assert node.getKind() == SqlKind.SCALAR_QUERY : "Expected a SCALAR_QUERY, 
found "
-        + node.getKind();
-    RelNode subQuery = node.rel;
-    RexBuilder rexBuilder = relBuilder.getRexBuilder();
-
-    RexCorrelVariableMapShuttle rexVisitor =
-        new RexCorrelVariableMapShuttle(corrId, rowType, inputMapping, 
rexBuilder);
-    RelNode newSubQuery =
-        subQuery.accept(new RexRewritingRelShuttle(rexVisitor));
-
-    return RexSubQuery.scalar(newSubQuery);
-  }
-
   /** Creates a project with a dummy column, to protect the parts of the system
    * that cannot handle a relational expression with no columns.
    *
@@ -1512,51 +1418,6 @@ public TrimResult trimFields(LogicalCorrelate correlate,
     return result(newCorrelate, mapping);
   }
 
-  /**
-   * Updates correlate references in {@link RexNode} expressions.
-   */
-  static class RexCorrelVariableMapShuttle extends RexShuttle {
-    private final CorrelationId correlationId;
-    private final Mapping mapping;
-    private final RelDataType newCorrelRowType;
-    private final RexBuilder rexBuilder;
-
-
-    /**
-     * Constructs a RexCorrelVariableMapShuttle.
-     *
-     * @param correlationId The ID of the correlation variable to update.
-     * @param newCorrelRowType The new row type for the correlate reference.
-     * @param mapping Mapping to transform field indices.
-     * @param rexBuilder A builder for constructing new RexNodes.
-     */
-    RexCorrelVariableMapShuttle(final CorrelationId correlationId,
-        RelDataType newCorrelRowType, Mapping mapping, RexBuilder rexBuilder) {
-      this.correlationId = correlationId;
-      this.newCorrelRowType = newCorrelRowType;
-      this.mapping = mapping;
-      this.rexBuilder = rexBuilder;
-    }
-
-    @Override public RexNode visitFieldAccess(final RexFieldAccess 
fieldAccess) {
-      if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
-        RexCorrelVariable referenceExpr =
-            (RexCorrelVariable) fieldAccess.getReferenceExpr();
-        if (referenceExpr.id.equals(correlationId)) {
-          int oldIndex = fieldAccess.getField().getIndex();
-          int newIndex = mapping.getTarget(oldIndex);
-          if (newIndex == oldIndex) {
-            return super.visitFieldAccess(fieldAccess);
-          }
-          RexNode newCorrel =
-              rexBuilder.makeCorrel(newCorrelRowType, referenceExpr.id);
-          return rexBuilder.makeFieldAccess(newCorrel, newIndex);
-        }
-      }
-      return super.visitFieldAccess(fieldAccess);
-    }
-  }
-
   protected Mapping createMapping(ImmutableBitSet fieldsUsed, int fieldCount) {
     final Mapping mapping =
         Mappings.create(
diff --git 
a/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java 
b/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
index f2c64cdd27..d486d7b256 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RexRewritingRelShuttle.java
@@ -26,7 +26,7 @@
 public class RexRewritingRelShuttle extends RelHomogeneousShuttle {
   private final RexShuttle rexVisitor;
 
-  RexRewritingRelShuttle(RexShuttle rexVisitor) {
+  public RexRewritingRelShuttle(RexShuttle rexVisitor) {
     this.rexVisitor = rexVisitor;
   }
 
diff --git a/core/src/test/java/org/apache/calcite/plan/RelOptUtilTest.java 
b/core/src/test/java/org/apache/calcite/plan/RelOptUtilTest.java
index 63086de4c7..3bdfbd2625 100644
--- a/core/src/test/java/org/apache/calcite/plan/RelOptUtilTest.java
+++ b/core/src/test/java/org/apache/calcite/plan/RelOptUtilTest.java
@@ -25,6 +25,7 @@
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.convert.ConverterRule;
+import org.apache.calcite.rel.core.CorrelationId;
 import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.Project;
@@ -35,19 +36,27 @@
 import org.apache.calcite.rel.type.RelDataTypeSystem;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.schema.SchemaPlus;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
 import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql2rel.RexRewritingRelShuttle;
 import org.apache.calcite.test.CalciteAssert;
 import org.apache.calcite.tools.Frameworks;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.TestUtil;
 import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mapping;
+import org.apache.calcite.util.mapping.MappingType;
+import org.apache.calcite.util.mapping.Mappings;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
@@ -60,6 +69,7 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
 
 import static org.apache.calcite.test.Matchers.isListOf;
@@ -68,6 +78,7 @@
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.hasToString;
 import static org.junit.jupiter.api.Assertions.fail;
 
@@ -992,6 +1003,125 @@ private void splitJoinConditionHelper(RexNode joinCond, 
List<Integer> expLeftKey
     assertThat(castNode2.explain(), is(expectNode2.explain()));
   }
 
+  @Test void testRemapCorrelHandlesNestedSubQueries() {
+    // Equivalent SQL:
+    // select
+    //    (
+    //        select count(*)
+    //          from emp as middle_emp
+    //         where exists (
+    //                   select true
+    //                     from emp as innermost_emp
+    //                    where outermost_emp.deptno = innermost_emp.deptno
+    //                      and middle_emp.sal < innermost_emp.sal
+    //               )
+    //    ) as c
+    // from emp as outermost_emp
+
+    int deptNoIdx = empRow.getFieldNames().indexOf("DEPTNO");
+    int salIdx = empRow.getFieldNames().indexOf("SAL");
+
+    RelOptCluster cluster = relBuilder.getCluster();
+    RexBuilder rexBuilder = relBuilder.getRexBuilder();
+
+    CorrelationId outermostCorrelationId = cluster.createCorrel();
+    RexNode outermostCorrelate =
+        rexBuilder.makeFieldAccess(
+            rexBuilder.makeCorrel(empRow, outermostCorrelationId), deptNoIdx);
+    CorrelationId middleCorrelationId = cluster.createCorrel();
+    RexNode middleCorrelate =
+        rexBuilder.makeFieldAccess(rexBuilder.makeCorrel(empRow, 
middleCorrelationId), salIdx);
+
+    RelNode innermostQuery = relBuilder
+        .push(empScan)
+        .filter(
+            rexBuilder.makeCall(
+                SqlStdOperatorTable.AND,
+                rexBuilder.makeCall(
+                    SqlStdOperatorTable.EQUALS,
+                    outermostCorrelate,
+                    rexBuilder.makeInputRef(empRow, deptNoIdx)),
+                rexBuilder.makeCall(
+                    SqlStdOperatorTable.LESS_THAN,
+                    middleCorrelate,
+                    rexBuilder.makeInputRef(empRow, salIdx)
+                )
+            )
+        )
+        .project(rexBuilder.makeLiteral(true))
+        .build();
+
+    RelNode middleQuery = relBuilder
+        .push(empScan)
+        .filter(relBuilder.exists(ignored -> innermostQuery))
+        .aggregate(
+            relBuilder.groupKey(),
+            relBuilder.countStar("COUNT_ALL")
+        )
+        .build();
+
+    RelNode outermostQuery = relBuilder
+        .push(empScan)
+        .project(relBuilder.scalarQuery(ignored -> middleQuery))
+        .build();
+
+    // Wrap the outermost query in a RexSubQuery, because
+    // RelOptUtil.remapCorrelatesInSuqQuery accepts a RexSubQuery as input.
+    RexSubQuery subQuery = relBuilder.exists(ignored -> outermostQuery);
+
+    RelDataType newType = cluster.getTypeFactory().builder()
+            .add(empRow.getFieldList().get(deptNoIdx))
+            .build();
+
+    Mapping mapping =
+            Mappings.create(MappingType.INVERSE_SURJECTION,
+            empRow.getFieldCount(),
+            1);
+    mapping.set(deptNoIdx, 0);
+
+    RexSubQuery newSubQuery =
+        RelOptUtil.remapCorrelatesInSuqQuery(
+            rexBuilder, subQuery, outermostCorrelationId, newType, mapping);
+
+    List<RexFieldAccess> variablesUsed = new ArrayList<>();
+
+    newSubQuery.accept(new RexShuttle() {
+      @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+        if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+          variablesUsed.add(fieldAccess);
+        }
+
+        return super.visitFieldAccess(fieldAccess);
+      }
+
+      @Override public RexNode visitSubQuery(RexSubQuery subQuery) {
+        subQuery.rel.accept(new RexRewritingRelShuttle(this));
+
+        return super.visitSubQuery(subQuery);
+      }
+    });
+
+    assertThat(variablesUsed, hasSize(2));
+
+    variablesUsed.sort(
+        Comparator.comparingInt(v ->
+            ((RexCorrelVariable) v.getReferenceExpr()).id.getId()));
+
+    RexFieldAccess firstFieldAccess = variablesUsed.get(0);
+    assertThat(firstFieldAccess.getField().getIndex(), is(0));
+
+    RexCorrelVariable firstVar = (RexCorrelVariable) 
firstFieldAccess.getReferenceExpr();
+    assertThat(firstVar.id, is(outermostCorrelationId));
+    assertThat(firstVar.getType(), is(newType));
+
+    RexFieldAccess secondFieldAccess = variablesUsed.get(1);
+    assertThat(secondFieldAccess.getField().getIndex(), is(salIdx));
+
+    RexCorrelVariable secondVar = (RexCorrelVariable) 
secondFieldAccess.getReferenceExpr();
+    assertThat(secondVar.id, is(middleCorrelationId));
+    assertThat(secondVar.getType(), is(empRow));
+  }
+
   /** Dummy sub-class of ConverterRule, to check whether generated descriptions
    * are OK. */
   private static class MyConverterRule extends ConverterRule {
diff --git 
a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java 
b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
index 0318552bc3..77e6d89091 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
@@ -708,7 +708,7 @@ public static Frameworks.ConfigBuilder config() {
     final RelFieldTrimmer fieldTrimmer = new RelFieldTrimmer(null, builder);
     final RelNode trimmed = fieldTrimmer.trim(root);
     final String expected = ""
-        + "LogicalProject(EMPNO=[$0], $f1=[$SCALAR_QUERY({\n"
+        + "LogicalProject(variablesSet=[[$cor0]], EMPNO=[$0], 
$f1=[$SCALAR_QUERY({\n"
         + "LogicalAggregate(group=[{}], c=[COUNT()])\n"
         + "  LogicalFilter(condition=[<($3, $cor0.MGR)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n"
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 057c44c5df..5ff9ee3f65 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -4958,26 +4958,28 @@ from sales.empnullables as e]]>
     </Resource>
     <Resource name="planBefore">
       <![CDATA[
-LogicalProject(EMPNO=[$0], EXPR$1=[OR(IN($7, {
+LogicalProject(variablesSet=[[$cor0]], EMPNO=[$0], EXPR$1=[OR(IN($2, {
 LogicalProject(DEPTNO=[$0])
   LogicalFilter(condition=[AND(=($cor0.ENAME, CAST($1):VARCHAR(20)), >($0, 
10))])
     LogicalTableScan(table=[[CATALOG, SALES, DEPTNULLABLES]])
-}), IN($7, {
+}), IN($2, {
 LogicalProject(DEPTNO=[$0])
   LogicalFilter(condition=[AND(=($cor0.ENAME, CAST($1):VARCHAR(20)), <($0, 
20))])
     LogicalTableScan(table=[[CATALOG, SALES, DEPTNULLABLES]])
 }))])
-  LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+  LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7])
+    LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
 ]]>
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-LogicalProject(EMPNO=[$0], EXPR$1=[OR(CASE(=($9, 0), false, IS NULL($7), 
null:BOOLEAN, IS NOT NULL($12), true, <($10, $9), null:BOOLEAN, false), 
CASE(=($13, 0), false, IS NULL($7), null:BOOLEAN, IS NOT NULL($16), true, 
<($14, $13), null:BOOLEAN, false))])
-  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1, 
7}])
+LogicalProject(EMPNO=[$0], EXPR$1=[OR(CASE(=($3, 0), false, IS NULL($2), 
null:BOOLEAN, IS NOT NULL($6), true, <($4, $3), null:BOOLEAN, false), 
CASE(=($7, 0), false, IS NULL($2), null:BOOLEAN, IS NOT NULL($10), true, <($8, 
$7), null:BOOLEAN, false))])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1, 
2}])
     LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{1}])
-      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{1, 7}])
+      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{1, 2}])
         LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{1}])
-          LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+          LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7])
+            LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
           LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)])
             LogicalProject(DEPTNO=[$0])
               LogicalFilter(condition=[AND(=($cor0.ENAME, 
CAST($1):VARCHAR(20)), >($0, 10))])
diff --git a/core/src/test/resources/sql/scalar.iq 
b/core/src/test/resources/sql/scalar.iq
index 43aaa84d72..05a5277137 100644
--- a/core/src/test/resources/sql/scalar.iq
+++ b/core/src/test/resources/sql/scalar.iq
@@ -302,7 +302,6 @@ from (values (1), (3)) t1(id);
 !ok
 
 # Several scalar sub-queries reference different tables in FROM list
-!if (false) {
 select
     (select ename from emp where empno = empnos.empno) as emp_name,
     (select dname from dept where deptno = deptnos.deptno) as dept_name
@@ -318,6 +317,5 @@ select
 (4 rows)
 
 !ok
-!}
 
 # End scalar.iq

Reply via email to