This is an automated email from the ASF dual-hosted git repository.

xiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 45e41e8db6 [CALCITE-7348] Remove redundant extraction correlation variables when Trim Project Fields
45e41e8db6 is described below

commit 45e41e8db6f7da147b9981aba856243c6c4736c4
Author: Xiong Duan <[email protected]>
AuthorDate: Tue Dec 30 16:33:25 2025 +0800

    [CALCITE-7348] Remove redundant extraction correlation variables when Trim Project Fields
---
 .../apache/calcite/sql2rel/RelFieldTrimmer.java    | 26 ++----
 .../calcite/sql2rel/RelFieldTrimmerTest.java       | 31 +++++---
 core/src/test/resources/sql/scalar.iq              | 93 ++++++++++++++++++++++
 3 files changed, 116 insertions(+), 34 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java 
b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index 0c9c761b33..864ee9d6a0 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -548,27 +548,11 @@ public TrimResult trimFields(
       }
     }
 
-    // Collect all the SubQueries in the projection list.
-    List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
-    // Get all the correlationIds present in the SubQueries
-    Set<CorrelationId> correlationIds = 
RelOptUtil.getVariablesUsed(subQueries);
-    ImmutableBitSet requiredColumns = ImmutableBitSet.of();
-    if (!correlationIds.isEmpty()) {
-      assert correlationIds.size() == 1;
-      // Correlation columns are also needed by SubQueries, so add them to 
inputFieldsUsed.
-      requiredColumns = 
RelOptUtil.correlationColumns(correlationIds.iterator().next(), project);
-    }
-
     ImmutableBitSet finderFields = inputFinder.build();
 
-    ImmutableBitSet inputFieldsUsed = ImmutableBitSet.builder()
-        .addAll(requiredColumns)
-        .addAll(finderFields)
-        .build();
-
     // Create input with trimmed columns.
     TrimResult trimResult =
-        trimChild(project, input, inputFieldsUsed, inputExtraFields);
+        trimChild(project, input, finderFields, inputExtraFields);
     RelNode newInput = trimResult.left;
     final Mapping inputMapping = trimResult.right;
 
@@ -589,14 +573,14 @@ public TrimResult trimFields(
     final List<RexNode> newProjects = new ArrayList<>();
     final RexVisitor<RexNode> shuttle;
 
-    if (!correlationIds.isEmpty()) {
-      assert correlationIds.size() == 1;
+    if (!project.getVariablesSet().isEmpty()) {
       shuttle = new RexPermuteInputsShuttle(inputMapping, newInput) {
         @Override public RexNode visitSubQuery(RexSubQuery subQuery) {
           subQuery = (RexSubQuery) super.visitSubQuery(subQuery);
 
           return 
RelOptUtil.remapCorrelatesInSuqQuery(relBuilder.getRexBuilder(),
-            subQuery, correlationIds.iterator().next(), newInput.getRowType(), 
inputMapping);
+            subQuery, project.getVariablesSet().iterator().next(),
+              newInput.getRowType(), inputMapping);
         }
       };
     } else {
@@ -621,7 +605,7 @@ public TrimResult trimFields(
             mapping);
 
     relBuilder.push(newInput);
-    relBuilder.project(newProjects, newRowType.getFieldNames(), false, 
correlationIds);
+    relBuilder.project(newProjects, newRowType.getFieldNames(), false, 
project.getVariablesSet());
     final RelNode newProject = relBuilder.build();
     return result(newProject, mapping, project);
   }
diff --git 
a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java 
b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
index 77e6d89091..c23185c8c6 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
@@ -33,6 +33,7 @@
 import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.rel.rules.CoreRules;
 import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.schema.SchemaPlus;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParser;
@@ -42,6 +43,7 @@
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.Holder;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -680,24 +682,27 @@ public static Frameworks.ConfigBuilder config() {
   @Test void testTrimCorrelatedSubquery() {
     final RelBuilder builder = RelBuilder.create(config().build());
     final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
-    RelNode root = builder.scan("EMP")
+    builder.scan("EMP")
         .variable(v::set)
         .filter(
             builder.call(SqlStdOperatorTable.GREATER_THAN, builder.field(5),
-            builder.literal(10)))
-        .project(
-            builder.field(0),
-            builder.scalarQuery(
-                b2 -> builder.scan("EMP").filter(
-                    builder.call(SqlStdOperatorTable.LESS_THAN,
-                        builder.field(3), builder.field(v.get(), "MGR")))
-                    .project(builder.field(0))
-                    .aggregate(builder.groupKey(), builder.countStar("c"))
-                    .build()))
-        .build();
+            builder.literal(10)));
+    final ImmutableList.Builder<RexNode> projectsNode = 
ImmutableList.builder();
+    projectsNode.add(builder.field(0));
+    projectsNode.add(
+        builder.scalarQuery(
+            b2 -> builder.scan("EMP").filter(
+                builder.call(SqlStdOperatorTable.LESS_THAN,
+                    builder.field(3), builder.field(v.get(), "MGR")))
+                .project(builder.field(0))
+                .aggregate(builder.groupKey(), builder.countStar("c"))
+                .build()));
+    RelNode root =
+        builder.project(projectsNode.build(),
+                ImmutableList.of(), false, 
ImmutableList.of(v.get().id)).build();
 
     String origTree = ""
-        + "LogicalProject(EMPNO=[$0], $f1=[$SCALAR_QUERY({\n"
+        + "LogicalProject(variablesSet=[[$cor0]], EMPNO=[$0], 
$f1=[$SCALAR_QUERY({\n"
         + "LogicalAggregate(group=[{}], c=[COUNT()])\n"
         + "  LogicalFilter(condition=[<($3, $cor0.MGR)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n})])\n"
diff --git a/core/src/test/resources/sql/scalar.iq 
b/core/src/test/resources/sql/scalar.iq
index 82a1eb9223..4b5e186cc9 100644
--- a/core/src/test/resources/sql/scalar.iq
+++ b/core/src/test/resources/sql/scalar.iq
@@ -310,4 +310,97 @@ select
 
 !ok
 
+# [CALCITE-7348] Remove redundant extraction correlation variables when Trim 
Project Fields
+
+!set trimfields true
+
+SELECT empno, (SELECT COUNT(*) AS c
+FROM "scott".emp
+WHERE mgr < "t".mgr) AS "$f1"
+FROM "scott".emp as "t"
+WHERE sal > 10;
++-------+-----+
+| EMPNO | $f1 |
++-------+-----+
+|  7369 |  12 |
+|  7499 |   2 |
+|  7521 |   2 |
+|  7566 |   9 |
+|  7654 |   2 |
+|  7698 |   9 |
+|  7782 |   9 |
+|  7788 |   0 |
+|  7839 |   0 |
+|  7844 |   2 |
+|  7876 |   8 |
+|  7900 |   2 |
+|  7902 |   0 |
+|  7934 |   7 |
++-------+-----+
+(14 rows)
+
+!ok
+EnumerableCalc(expr#0..3=[{inputs}], expr#4=[IS NULL($t3)], expr#5=[0:BIGINT], 
expr#6=[CASE($t4, $t5, $t3)], EMPNO=[$t0], $f1=[$t6])
+  EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($1, $2)], joinType=[left])
+    EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], 
expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], EMPNO=[$t0], MGR=[$t3], 
$condition=[$t10])
+      EnumerableTableScan(table=[[scott, EMP]])
+    EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t2)], 
expr#4=[0], expr#5=[CASE($t3, $t2, $t4)], MGR0=[$t0], C=[$t5])
+      EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])
+        EnumerableAggregate(group=[{3}])
+          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 
2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], 
$condition=[$t10])
+            EnumerableTableScan(table=[[scott, EMP]])
+        EnumerableAggregate(group=[{2}], C=[COUNT()])
+          EnumerableNestedLoopJoin(condition=[<($1, $2)], joinType=[inner])
+            EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], MGR=[$t3])
+              EnumerableTableScan(table=[[scott, EMP]])
+            EnumerableAggregate(group=[{3}])
+              EnumerableCalc(expr#0..7=[{inputs}], 
expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)], 
expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10])
+                EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+!set trimfields false
+
+SELECT empno, (SELECT COUNT(*) AS c
+FROM "scott".emp
+WHERE mgr < "t".mgr) AS "$f1"
+FROM "scott".emp as "t"
+WHERE sal > 10;
++-------+-----+
+| EMPNO | $f1 |
++-------+-----+
+|  7369 |  12 |
+|  7499 |   2 |
+|  7521 |   2 |
+|  7566 |   9 |
+|  7654 |   2 |
+|  7698 |   9 |
+|  7782 |   9 |
+|  7788 |   0 |
+|  7839 |   0 |
+|  7844 |   2 |
+|  7876 |   8 |
+|  7900 |   2 |
+|  7902 |   0 |
+|  7934 |   7 |
++-------+-----+
+(14 rows)
+
+!ok
+EnumerableCalc(expr#0..9=[{inputs}], expr#10=[IS NULL($t9)], 
expr#11=[0:BIGINT], expr#12=[CASE($t10, $t11, $t9)], EMPNO=[$t0], $f1=[$t12])
+  EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($3, $8)], joinType=[left])
+    EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], 
expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], 
$condition=[$t10])
+      EnumerableTableScan(table=[[scott, EMP]])
+    EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t2)], 
expr#4=[0], expr#5=[CASE($t3, $t2, $t4)], MGR0=[$t0], C=[$t5])
+      EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])
+        EnumerableAggregate(group=[{3}])
+          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 
2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], 
$condition=[$t10])
+            EnumerableTableScan(table=[[scott, EMP]])
+        EnumerableAggregate(group=[{8}], C=[COUNT()])
+          EnumerableNestedLoopJoin(condition=[<($3, $8)], joinType=[inner])
+            EnumerableTableScan(table=[[scott, EMP]])
+            EnumerableAggregate(group=[{3}])
+              EnumerableCalc(expr#0..7=[{inputs}], 
expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)], 
expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10])
+                EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
 # End scalar.iq

Reply via email to