This is an automated email from the ASF dual-hosted git repository.
xiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 45e41e8db6 [CALCITE-7348] Remove redundant extraction correlation variables when Trim Project Fields
45e41e8db6 is described below
commit 45e41e8db6f7da147b9981aba856243c6c4736c4
Author: Xiong Duan <[email protected]>
AuthorDate: Tue Dec 30 16:33:25 2025 +0800
[CALCITE-7348] Remove redundant extraction correlation variables when Trim Project Fields
---
.../apache/calcite/sql2rel/RelFieldTrimmer.java | 26 ++----
.../calcite/sql2rel/RelFieldTrimmerTest.java | 31 +++++---
core/src/test/resources/sql/scalar.iq | 93 ++++++++++++++++++++++
3 files changed, 116 insertions(+), 34 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index 0c9c761b33..864ee9d6a0 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -548,27 +548,11 @@ public TrimResult trimFields(
}
}
- // Collect all the SubQueries in the projection list.
- List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
- // Get all the correlationIds present in the SubQueries
- Set<CorrelationId> correlationIds =
RelOptUtil.getVariablesUsed(subQueries);
- ImmutableBitSet requiredColumns = ImmutableBitSet.of();
- if (!correlationIds.isEmpty()) {
- assert correlationIds.size() == 1;
- // Correlation columns are also needed by SubQueries, so add them to
inputFieldsUsed.
- requiredColumns =
RelOptUtil.correlationColumns(correlationIds.iterator().next(), project);
- }
-
ImmutableBitSet finderFields = inputFinder.build();
- ImmutableBitSet inputFieldsUsed = ImmutableBitSet.builder()
- .addAll(requiredColumns)
- .addAll(finderFields)
- .build();
-
// Create input with trimmed columns.
TrimResult trimResult =
- trimChild(project, input, inputFieldsUsed, inputExtraFields);
+ trimChild(project, input, finderFields, inputExtraFields);
RelNode newInput = trimResult.left;
final Mapping inputMapping = trimResult.right;
@@ -589,14 +573,14 @@ public TrimResult trimFields(
final List<RexNode> newProjects = new ArrayList<>();
final RexVisitor<RexNode> shuttle;
- if (!correlationIds.isEmpty()) {
- assert correlationIds.size() == 1;
+ if (!project.getVariablesSet().isEmpty()) {
shuttle = new RexPermuteInputsShuttle(inputMapping, newInput) {
@Override public RexNode visitSubQuery(RexSubQuery subQuery) {
subQuery = (RexSubQuery) super.visitSubQuery(subQuery);
return
RelOptUtil.remapCorrelatesInSuqQuery(relBuilder.getRexBuilder(),
- subQuery, correlationIds.iterator().next(), newInput.getRowType(),
inputMapping);
+ subQuery, project.getVariablesSet().iterator().next(),
+ newInput.getRowType(), inputMapping);
}
};
} else {
@@ -621,7 +605,7 @@ public TrimResult trimFields(
mapping);
relBuilder.push(newInput);
- relBuilder.project(newProjects, newRowType.getFieldNames(), false,
correlationIds);
+ relBuilder.project(newProjects, newRowType.getFieldNames(), false,
project.getVariablesSet());
final RelNode newProject = relBuilder.build();
return result(newProject, mapping, project);
}
diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
index 77e6d89091..c23185c8c6 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java
@@ -33,6 +33,7 @@
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParser;
@@ -42,6 +43,7 @@
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Holder;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -680,24 +682,27 @@ public static Frameworks.ConfigBuilder config() {
@Test void testTrimCorrelatedSubquery() {
final RelBuilder builder = RelBuilder.create(config().build());
final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
- RelNode root = builder.scan("EMP")
+ builder.scan("EMP")
.variable(v::set)
.filter(
builder.call(SqlStdOperatorTable.GREATER_THAN, builder.field(5),
- builder.literal(10)))
- .project(
- builder.field(0),
- builder.scalarQuery(
- b2 -> builder.scan("EMP").filter(
- builder.call(SqlStdOperatorTable.LESS_THAN,
- builder.field(3), builder.field(v.get(), "MGR")))
- .project(builder.field(0))
- .aggregate(builder.groupKey(), builder.countStar("c"))
- .build()))
- .build();
+ builder.literal(10)));
+ final ImmutableList.Builder<RexNode> projectsNode =
ImmutableList.builder();
+ projectsNode.add(builder.field(0));
+ projectsNode.add(
+ builder.scalarQuery(
+ b2 -> builder.scan("EMP").filter(
+ builder.call(SqlStdOperatorTable.LESS_THAN,
+ builder.field(3), builder.field(v.get(), "MGR")))
+ .project(builder.field(0))
+ .aggregate(builder.groupKey(), builder.countStar("c"))
+ .build()));
+ RelNode root =
+ builder.project(projectsNode.build(),
+ ImmutableList.of(), false,
ImmutableList.of(v.get().id)).build();
String origTree = ""
- + "LogicalProject(EMPNO=[$0], $f1=[$SCALAR_QUERY({\n"
+ + "LogicalProject(variablesSet=[[$cor0]], EMPNO=[$0],
$f1=[$SCALAR_QUERY({\n"
+ "LogicalAggregate(group=[{}], c=[COUNT()])\n"
+ " LogicalFilter(condition=[<($3, $cor0.MGR)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n})])\n"
diff --git a/core/src/test/resources/sql/scalar.iq b/core/src/test/resources/sql/scalar.iq
index 82a1eb9223..4b5e186cc9 100644
--- a/core/src/test/resources/sql/scalar.iq
+++ b/core/src/test/resources/sql/scalar.iq
@@ -310,4 +310,97 @@ select
!ok
+# [CALCITE-7348] Remove redundant extraction correlation variables when Trim Project Fields
+
+!set trimfields true
+
+SELECT empno, (SELECT COUNT(*) AS c
+FROM "scott".emp
+WHERE mgr < "t".mgr) AS "$f1"
+FROM "scott".emp as "t"
+WHERE sal > 10;
++-------+-----+
+| EMPNO | $f1 |
++-------+-----+
+| 7369 | 12 |
+| 7499 | 2 |
+| 7521 | 2 |
+| 7566 | 9 |
+| 7654 | 2 |
+| 7698 | 9 |
+| 7782 | 9 |
+| 7788 | 0 |
+| 7839 | 0 |
+| 7844 | 2 |
+| 7876 | 8 |
+| 7900 | 2 |
+| 7902 | 0 |
+| 7934 | 7 |
++-------+-----+
+(14 rows)
+
+!ok
+EnumerableCalc(expr#0..3=[{inputs}], expr#4=[IS NULL($t3)], expr#5=[0:BIGINT],
expr#6=[CASE($t4, $t5, $t3)], EMPNO=[$t0], $f1=[$t6])
+ EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($1, $2)], joinType=[left])
+ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)],
expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], EMPNO=[$t0], MGR=[$t3],
$condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t2)],
expr#4=[0], expr#5=[CASE($t3, $t2, $t4)], MGR0=[$t0], C=[$t5])
+ EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($0, $1)],
joinType=[left])
+ EnumerableAggregate(group=[{3}])
+ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12,
2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}],
$condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableAggregate(group=[{2}], C=[COUNT()])
+ EnumerableNestedLoopJoin(condition=[<($1, $2)], joinType=[inner])
+ EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], MGR=[$t3])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableAggregate(group=[{3}])
+ EnumerableCalc(expr#0..7=[{inputs}],
expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)],
expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+!set trimfields false
+
+SELECT empno, (SELECT COUNT(*) AS c
+FROM "scott".emp
+WHERE mgr < "t".mgr) AS "$f1"
+FROM "scott".emp as "t"
+WHERE sal > 10;
++-------+-----+
+| EMPNO | $f1 |
++-------+-----+
+| 7369 | 12 |
+| 7499 | 2 |
+| 7521 | 2 |
+| 7566 | 9 |
+| 7654 | 2 |
+| 7698 | 9 |
+| 7782 | 9 |
+| 7788 | 0 |
+| 7839 | 0 |
+| 7844 | 2 |
+| 7876 | 8 |
+| 7900 | 2 |
+| 7902 | 0 |
+| 7934 | 7 |
++-------+-----+
+(14 rows)
+
+!ok
+EnumerableCalc(expr#0..9=[{inputs}], expr#10=[IS NULL($t9)],
expr#11=[0:BIGINT], expr#12=[CASE($t10, $t11, $t9)], EMPNO=[$t0], $f1=[$t12])
+ EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($3, $8)], joinType=[left])
+ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)],
expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}],
$condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t2)],
expr#4=[0], expr#5=[CASE($t3, $t2, $t4)], MGR0=[$t0], C=[$t5])
+ EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($0, $1)],
joinType=[left])
+ EnumerableAggregate(group=[{3}])
+ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12,
2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}],
$condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableAggregate(group=[{8}], C=[COUNT()])
+ EnumerableNestedLoopJoin(condition=[<($3, $8)], joinType=[inner])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableAggregate(group=[{3}])
+ EnumerableCalc(expr#0..7=[{inputs}],
expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)],
expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
# End scalar.iq