This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 47fdf61a02 [CALCITE-7394] Nested sub-query with multiple levels of 
correlation returns incorrect results
47fdf61a02 is described below

commit 47fdf61a0243e4efcfa4a318550844e21d51a996
Author: iwanttobepowerful <[email protected]>
AuthorDate: Mon Jan 26 19:37:15 2026 +0800

    [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
---
 .../apache/calcite/sql2rel/RelDecorrelator.java    |  35 +++--
 .../calcite/sql2rel/RelDecorrelatorTest.java       | 132 ++++++++++++++++
 core/src/test/resources/sql/sub-query.iq           | 175 +++++++++++++++++++++
 3 files changed, 330 insertions(+), 12 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java 
b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index f0b7ad4098..47c0040fc3 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -954,16 +954,12 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel,
       RelNode newRel,
       Map<Integer, Integer> outputMap,
       NavigableMap<CorDef, Integer> corDefOutputs) {
-    final CorelMap localCorelMap = new CorelMapBuilder().build(oldRel);
-    final List<CorRef> corVarList = new 
ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
-    Collections.sort(corVarList);
-
+    final List<CorRef> corVarList = collectExternalCorVars(oldRel);
     final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new 
TreeMap<>();
     final RelNode valueGen =
         requireNonNull(createValueGenerator(corVarList, 0, 
valueGenCorDefOutputs));
     final int valueGenFieldCount = valueGen.getRowType().getFieldCount();
 
-    // Build join conditions
     final Map<Integer, RexNode> newProjectMap = new HashMap<>();
     for (Map.Entry<CorDef, Integer> corDefOutput : corDefOutputs.entrySet()) {
       final CorDef corDef = corDefOutput.getKey();
@@ -974,6 +970,7 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel,
       newProjectMap.put(valueGenFieldCount + rightPos, leftRef);
     }
 
+    // Build join conditions
     final List<RexNode> conditions =
         buildCorDefJoinConditions(valueGenCorDefOutputs, corDefOutputs,
             valueGen, newRel, relBuilder);
@@ -1260,10 +1257,7 @@ private static void shiftMapping(Map<Integer, Integer> 
mapping, int startIndex,
       return decorrelateRel((RelNode) rel, false, parentPropagatesNullValues);
     }
 
-    final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
-    final List<CorRef> corVarList = new 
ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
-    Collections.sort(corVarList);
-
+    final List<CorRef> corVarList = collectExternalCorVars(rel);
     final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new 
TreeMap<>();
     final RelNode valueGen =
         requireNonNull(createValueGenerator(corVarList, 0, 
valueGenCorDefOutputs));
@@ -1958,9 +1952,7 @@ private static boolean isWidening(RelDataType type, 
RelDataType type1) {
     }
 
     // 1. Collect all CorRefs involved
-    final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
-    final List<CorRef> corVarList = new 
ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
-    Collections.sort(corVarList);
+    final List<CorRef> corVarList = collectExternalCorVars(rel);
 
     // 2. Ensure CorVars are present in inputs (adding ValueGenerators if 
needed)
     Frame newLeftFrame = leftFrame;
@@ -3849,6 +3841,25 @@ private static boolean isFieldNotNullRecursive(RelNode 
rel, int index) {
     }
   }
 
+  /**
+   * Collects all correlated variables used in the given relational expression
+   * that are not defined within the expression itself.
+   *
+   * @param rel The relational expression to inspect
+   * @return A sorted list of external correlated variables
+   */
+  private static List<CorRef> collectExternalCorVars(RelNode rel) {
+    final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
+    final List<CorRef> corVarList = new ArrayList<>();
+    for (CorRef corVar : localCorelMap.mapRefRelToCorRef.values()) {
+      if (!localCorelMap.mapCorToCorRel.containsKey(corVar.corr)) {
+        corVarList.add(corVar);
+      }
+    }
+    Collections.sort(corVarList);
+    return corVarList;
+  }
+
   /**
    * Ensures that the correlated variables in {@code allCorDefs} are present
    * in the output of the frame.
diff --git 
a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java 
b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
index 6ac711e406..7ddaeaf420 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
@@ -356,6 +356,138 @@ public static Frameworks.ConfigBuilder config() {
     assertThat(after, hasTree(planAfter));
   }
 
+  /** Test case for <a 
href="https://issues.apache.org/jira/browse/CALCITE-7394">[CALCITE-7394]
+   * Nested sub-query with multiple levels of correlation returns incorrect 
results</a>. */
+  @Test void testNestedSubQueryWithMultiLevelCorrelation() {
+    final FrameworkConfig frameworkConfig = config().build();
+    final RelBuilder builder = RelBuilder.create(frameworkConfig);
+    final RelOptCluster cluster = builder.getCluster();
+    final Planner planner = Frameworks.getPlanner(frameworkConfig);
+    final String sql = ""
+        + "select d.dname,\n"
+        + "  (select count(*)\n"
+        + "   from emp e\n"
+        + "   where e.deptno = d.deptno\n"
+        + "   and exists (\n"
+        + "     select 1\n"
+        + "     from (values (1000), (2000), (3000)) as v(sal)\n"
+        + "     where e.sal > v.sal\n"
+        + "     and d.deptno * 100 < v.sal\n"
+        + "   )\n"
+        + "  ) as c\n"
+        + "from dept d\n"
+        + "order by d.dname";
+    final RelNode originalRel;
+    try {
+      final SqlNode parse = planner.parse(sql);
+      final SqlNode validate = planner.validate(parse);
+      originalRel = planner.rel(validate).rel;
+    } catch (Exception e) {
+      throw TestUtil.rethrow(e);
+    }
+
+    final HepProgram hepProgram = HepProgram.builder()
+        .addRuleCollection(
+            ImmutableList.of(
+                // SubQuery program rules
+                CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
+                CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
+                CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
+        .build();
+    final Program program =
+        Programs.of(hepProgram, true,
+            requireNonNull(cluster.getMetadataProvider()));
+    final RelNode before =
+        program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
+            Collections.emptyList(), Collections.emptyList());
+    final String planBefore = ""
+        + "LogicalSort(sort0=[$0], dir0=[ASC])\n"
+        + "  LogicalProject(DNAME=[$1], C=[$3])\n"
+        + "    LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{0}])\n"
+        + "      LogicalTableScan(table=[[scott, DEPT]])\n"
+        + "      LogicalAggregate(group=[{}], EXPR$0=[COUNT()])\n"
+        + "        LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], 
HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+        + "          LogicalFilter(condition=[=($7, $cor0.DEPTNO)])\n"
+        + "            LogicalCorrelate(correlation=[$cor1], joinType=[inner], 
requiredColumns=[{5}])\n"
+        + "              LogicalTableScan(table=[[scott, EMP]])\n"
+        + "              LogicalAggregate(group=[{0}])\n"
+        + "                LogicalProject(i=[true])\n"
+        + "                  
LogicalFilter(condition=[AND(>(CAST($cor1.SAL):DECIMAL(12, 2), 
CAST($0):DECIMAL(12, 2) NOT NULL), <(*($cor0.DEPTNO, 100), $0))])\n"
+        + "                    LogicalValues(tuples=[[{ 1000 }, { 2000 }, { 
3000 }]])\n";
+    assertThat(before, hasTree(planBefore));
+
+    // Decorrelate without any rules, just "purely" decorrelation algorithm on 
RelDecorrelator
+    final RelNode after =
+        RelDecorrelator.decorrelateQuery(before, builder, 
RuleSets.ofList(Collections.emptyList()),
+            RuleSets.ofList(Collections.emptyList()));
+    // before fix:
+    //
+    // LogicalSort(sort0=[$0], dir0=[ASC])
+    //  LogicalProject(DNAME=[$1], C=[$7])
+    //    LogicalJoin(condition=[AND(=($0, $5), =($4, $6))], joinType=[left])
+    //      LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], DEPTNO0=[$0], 
$f4=[*($0, 100)])
+    //        LogicalTableScan(table=[[scott, DEPT]])
+    //      LogicalProject(DEPTNO8=[$0], $f4=[$1], EXPR$0=[CASE(IS NOT 
NULL($5), $5, 0)])
+    //        LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $3),
+    //                                   IS NOT DISTINCT FROM($1, $4))], 
joinType=[left])
+    //          LogicalJoin(condition=[true], joinType=[inner])       // <---- 
error part
+    //            LogicalProject(DEPTNO=[$0], $f4=[*($0, 100)])
+    //              LogicalTableScan(table=[[scott, DEPT]])
+    //            LogicalAggregate(group=[{0}])                       // <---- 
error part
+    //              LogicalProject(SAL0=[CAST($5):DECIMAL(12, 2)])    // <---- 
error part
+    //                LogicalTableScan(table=[[scott, EMP]])          // <---- 
error part
+    //          LogicalAggregate(group=[{0, 1}], EXPR$0=[COUNT()])
+    //            LogicalProject(DEPTNO8=[$7], $f4=[$9])
+    //              LogicalFilter(condition=[IS NOT NULL($7)])
+    //                LogicalProject(..., DEPTNO=[$7], i=[$11], $f4=[$9])
+    //                  LogicalJoin(condition=[=($8, $10)], joinType=[inner])
+    //                    LogicalProject(..., SAL0=[CAST($5):DECIMAL(12, 2)])
+    //                      LogicalTableScan(table=[[scott, EMP]])
+    //                    LogicalProject($f4=[$0], SAL0=[$1], $f2=[true])
+    //                      LogicalAggregate(group=[{0, 1}])
+    //                        LogicalProject($f4=[$1], SAL0=[$2])
+    //                          LogicalJoin(condition=[AND(>($2, 
CAST($0):DECIMAL(12, 2) NOT NULL),
+    //                                                    <($1, $0))], 
joinType=[inner])
+    //                            LogicalValues(tuples=[[{ 1000 }, { 2000 }, { 
3000 }]])
+    //                            LogicalJoin(condition=[true], 
joinType=[inner])
+    //                              LogicalAggregate(group=[{0}])
+    //                                LogicalProject($f4=[*($0, 100)])
+    //                                  LogicalTableScan(table=[[scott, DEPT]])
+    //                              LogicalAggregate(group=[{0}])
+    //                                
LogicalProject(SAL0=[CAST($5):DECIMAL(12, 2)])
+    //                                  LogicalTableScan(table=[[scott, EMP]])
+    final String planAfter = ""
+        + "LogicalSort(sort0=[$0], dir0=[ASC])\n"
+        + "  LogicalProject(DNAME=[$1], C=[$7])\n"
+        + "    LogicalJoin(condition=[AND(=($0, $5), =($4, $6))], 
joinType=[left])\n"
+        + "      LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], 
DEPTNO0=[$0], $f4=[*($0, 100)])\n"
+        + "        LogicalTableScan(table=[[scott, DEPT]])\n"
+        + "      LogicalProject(DEPTNO8=[$0], $f4=[$1], EXPR$0=[CASE(IS NOT 
NULL($4), $4, 0)])\n"
+        + "        LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), IS 
NOT DISTINCT FROM($1, $3))], joinType=[left])\n"
+        + "          LogicalProject(DEPTNO=[$0], $f4=[*($0, 100)])\n"
+        + "            LogicalTableScan(table=[[scott, DEPT]])\n"
+        + "          LogicalAggregate(group=[{0, 1}], EXPR$0=[COUNT()])\n"
+        + "            LogicalProject(DEPTNO8=[$7], $f4=[$9])\n"
+        + "              LogicalFilter(condition=[IS NOT NULL($7)])\n"
+        + "                LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], 
MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], i=[$11], $f4=[$9])\n"
+        + "                  LogicalJoin(condition=[=($8, $10)], 
joinType=[inner])\n"
+        + "                    LogicalProject(EMPNO=[$0], ENAME=[$1], 
JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], 
SAL0=[CAST($5):DECIMAL(12, 2)])\n"
+        + "                      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "                    LogicalProject($f4=[$0], SAL0=[$1], 
$f2=[true])\n"
+        + "                      LogicalAggregate(group=[{0, 1}])\n"
+        + "                        LogicalProject($f4=[$1], SAL0=[$2])\n"
+        + "                          LogicalJoin(condition=[AND(>($2, 
CAST($0):DECIMAL(12, 2) NOT NULL), <($1, $0))], joinType=[inner])\n"
+        + "                            LogicalValues(tuples=[[{ 1000 }, { 2000 
}, { 3000 }]])\n"
+        + "                            LogicalJoin(condition=[true], 
joinType=[inner])\n"
+        + "                              LogicalAggregate(group=[{0}])\n"
+        + "                                LogicalProject($f4=[*($0, 100)])\n"
+        + "                                  LogicalTableScan(table=[[scott, 
DEPT]])\n"
+        + "                              LogicalAggregate(group=[{0}])\n"
+        + "                                
LogicalProject(SAL0=[CAST($5):DECIMAL(12, 2)])\n"
+        + "                                  LogicalTableScan(table=[[scott, 
EMP]])\n";
+    assertThat(after, hasTree(planAfter));
+  }
+
   /** Test case for <a 
href="https://issues.apache.org/jira/browse/CALCITE-7297">[CALCITE-7297]
    * The result is incorrect when the GROUP BY key in a subquery is a 
RexFieldAccess</a>. */
   @Test void testSkipsRedundantValueGenerator() {
diff --git a/core/src/test/resources/sql/sub-query.iq 
b/core/src/test/resources/sql/sub-query.iq
index 0d649f8558..a9440f6f43 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -5617,6 +5617,181 @@ ORDER BY deptno;
 
 !ok
 
+# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
+select d.dname,
+  (select count(*)
+   from emp e
+   where e.deptno = d.deptno
+   and e.sal > (
+     select min(s.losal)
+     from (VALUES (1, 700, 1200), (2, 1201, 1400), (3, 1401, 2000), (4, 2001, 
3000), (5, 3001, 9999)) AS s(grade, losal, hisal)
+     where e.sal BETWEEN s.losal AND s.hisal
+     and s.hisal > d.deptno * 10
+   )
+  ) as high_paid_count
+from dept d
+order by d.dname;
++------------+-----------------+
+| DNAME      | HIGH_PAID_COUNT |
++------------+-----------------+
+| ACCOUNTING |               3 |
+| OPERATIONS |               0 |
+| RESEARCH   |               5 |
+| SALES      |               6 |
++------------+-----------------+
+(4 rows)
+
+!ok
+
+# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
+select e.ename
+from emp e
+where e.sal > (
+  select avg(e2.sal)
+  from emp e2
+  where e2.deptno = e.deptno
+  and exists (
+     select 1
+     from (values (7369, 20)) as b(empno, deptno)
+     where b.empno = e2.empno
+     and b.deptno = e.deptno
+  )
+)
+and e.sal < 2000
+order by e.ename;
++-------+
+| ENAME |
++-------+
+| ADAMS |
++-------+
+(1 row)
+
+!ok
+
+# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
+select d.deptno
+from dept d
+where exists (
+  select 1
+  from emp e
+  where e.deptno = d.deptno
+  and exists (
+    select 1
+    from (VALUES (1, 700, 1200), (2, 1201, 1400), (3, 1401, 2000), (4, 2001, 
3000), (5, 3001, 9999)) AS s(grade, losal, hisal)
+    where s.grade = 1
+    and s.hisal >= e.sal
+    and s.losal <= d.deptno * 20
+  )
+)
+order by d.deptno;
++--------+
+| DEPTNO |
++--------+
++--------+
+(0 rows)
+
+!ok
+
+# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
+select e.ename
+from emp e
+where e.deptno in (
+  select d.deptno
+  from dept d
+  where d.deptno = e.deptno and d.deptno = 10
+  union
+  select d.deptno
+  from dept d
+  where d.deptno = e.deptno
+  and exists (
+    select 1
+    from emp e2
+    where e2.deptno = d.deptno
+    and e2.empno = e.empno
+    and e2.sal > 2000
+  )
+)
+order by e.ename;
++--------+
+| ENAME  |
++--------+
+| BLAKE  |
+| CLARK  |
+| FORD   |
+| JONES  |
+| KING   |
+| MILLER |
+| SCOTT  |
++--------+
+(7 rows)
+
+!ok
+
+# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
+select e.ename
+from emp e
+where exists (
+  select 1
+  from dept d
+  join emp e2 on d.deptno = e2.deptno
+  where d.deptno = e.deptno
+  and exists (
+    select 1
+    from (values (10), (20), (30)) as v(deptno)
+    where v.deptno = e2.deptno
+    and v.deptno = e.deptno
+  )
+  and e2.empno = e.empno
+)
+order by e.ename;
++--------+
+| ENAME  |
++--------+
+| ADAMS  |
+| ALLEN  |
+| BLAKE  |
+| CLARK  |
+| FORD   |
+| JAMES  |
+| JONES  |
+| KING   |
+| MARTIN |
+| MILLER |
+| SCOTT  |
+| SMITH  |
+| TURNER |
+| WARD   |
++--------+
+(14 rows)
+
+!ok
+
+# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns 
incorrect results
+select d.dname,
+  (select count(*)
+   from emp e
+   where e.deptno = d.deptno
+   and exists (
+     select 1
+     from (values (1000), (2000), (3000)) as v(sal)
+     where e.sal > v.sal
+     and d.deptno * 100 < v.sal
+   )
+  ) as c
+from dept d
+order by d.dname;
++------------+---+
+| DNAME      | C |
++------------+---+
+| ACCOUNTING | 2 |
+| OPERATIONS | 0 |
+| RESEARCH   | 0 |
+| SALES      | 0 |
++------------+---+
+(4 rows)
+
+!ok
+
 # [CALCITE-7303] Subqueries cannot be decorrelated if filter condition have 
multi CorrelationId
 SELECT deptno
 FROM emp e

Reply via email to