This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new c49792f9c7 [CALCITE-6032] Multilevel correlated query is failing in
RelDecorrelator code path
c49792f9c7 is described below
commit c49792f9c72159571f898c5fca1e26cba9870b07
Author: Hanumath Maduri <[email protected]>
AuthorDate: Fri Jan 19 11:14:46 2024 -0800
[CALCITE-6032] Multilevel correlated query is failing in RelDecorrelator
code path
---
.../java/org/apache/calcite/plan/RelOptUtil.java | 20 ++++++-
.../main/java/org/apache/calcite/rex/RexUtil.java | 23 ++++++++
.../apache/calcite/sql2rel/RelFieldTrimmer.java | 20 ++++++-
.../apache/calcite/sql2rel/SqlToRelConverter.java | 29 +++++++--
.../java/org/apache/calcite/test/JdbcTest.java | 69 ++++++++++++++++++++++
5 files changed, 154 insertions(+), 7 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
index e8c9a48809..5387acd8cf 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
@@ -271,7 +271,7 @@ public abstract class RelOptUtil {
}
/**
- * Returns a set of variables used by a relational expression or its
+ * Returns the set of variables used by a relational expression or its
* descendants.
*
* <p>The set may contain "duplicates" (variables with different ids that,
@@ -286,6 +286,24 @@ public abstract class RelOptUtil {
return visitor.vuv.variables;
}
+  /**
+   * Returns the set of variables used by the given sub-queries or their
+   * descendants.
+   *
+   * <p>The result is the union of {@link #getVariablesUsed(RelNode)} applied
+   * to the relational expression of each sub-query, so it may contain the
+   * same kind of "duplicates" as that method.
+   *
+   * @param subQueries Sub-queries in which to look for correlation variables
+   * @return Set of correlation identifiers used within the sub-queries (empty
+   * if none are used); each element corresponds to
+   * {@link org.apache.calcite.rex.RexCorrelVariable#id}
+   */
+  public static Set<CorrelationId> getVariablesUsed(
+      List<RexSubQuery> subQueries) {
+    final Set<CorrelationId> correlationIds = new HashSet<>();
+    for (RexSubQuery subQuery : subQueries) {
+      // The RelNode overload walks the sub-query's plan and its descendants.
+      correlationIds.addAll(getVariablesUsed(subQuery.rel));
+    }
+    return correlationIds;
+  }
+
/** Finds which columns of a correlation variable are used within a
* relational expression. */
public static ImmutableBitSet correlationColumns(CorrelationId id,
diff --git a/core/src/main/java/org/apache/calcite/rex/RexUtil.java b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
index f9390a5b3c..5d7f473a20 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexUtil.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
@@ -2854,6 +2854,29 @@ public class RexUtil {
}
}
+  /** Visitor that collects the {@link RexSubQuery} expressions found in a
+   * list of expressions, such as the projection list of a {@link Project}.
+   *
+   * <p>The visitor descends into the operands of calls, but it does not
+   * descend into the operands of a sub-query once that sub-query has been
+   * collected. */
+  public static class SubQueryCollector extends RexVisitorImpl<Void> {
+    private final List<RexSubQuery> subQueries = new ArrayList<>();
+
+    private SubQueryCollector() {
+      // Deep visitor, so that sub-queries nested inside calls are reached.
+      super(true);
+    }
+
+    @Override public Void visitSubQuery(RexSubQuery subQuery) {
+      // Record the sub-query and deliberately do not visit its operands.
+      subQueries.add(subQuery);
+      return null;
+    }
+
+    /** Returns the sub-queries that occur in the given expressions, in the
+     * order in which they are encountered. */
+    public static List<RexSubQuery> collect(Iterable<? extends RexNode> nodes) {
+      final SubQueryCollector collector = new SubQueryCollector();
+      for (RexNode node : nodes) {
+        node.accept(collector);
+      }
+      return collector.subQueries;
+    }
+
+    /** Returns the sub-queries that occur in the projection list of
+     * {@code project}, in the order in which they are encountered. */
+    public static List<RexSubQuery> collect(Project project) {
+      return collect(project.getProjects());
+    }
+  }
+
/** Visitor that throws {@link org.apache.calcite.util.Util.FoundOne} if
* applied to an expression that contains a {@link RexSubQuery}. */
public static class SubQueryFinder extends RexVisitorImpl<Void> {
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index bfa69d0a4d..a8d99126ea 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -55,6 +55,7 @@ import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlExplainFormat;
@@ -489,7 +490,24 @@ public class RelFieldTrimmer implements ReflectiveVisitor {
ord.e.accept(inputFinder);
}
}
- ImmutableBitSet inputFieldsUsed = inputFinder.build();
+
+ // Collect all the SubQueries in the projection list.
+ List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
+ // Get all the correlationIds present in the SubQueries
+ Set<CorrelationId> correlationIds =
RelOptUtil.getVariablesUsed(subQueries);
+ ImmutableBitSet requiredColumns = ImmutableBitSet.of();
+ if (correlationIds.size() > 0) {
+ assert correlationIds.size() == 1;
+ // Correlation columns are also needed by SubQueries, so add them to
inputFieldsUsed.
+ requiredColumns =
RelOptUtil.correlationColumns(correlationIds.iterator().next(), project);
+ }
+
+ ImmutableBitSet finderFields = inputFinder.build();
+
+ ImmutableBitSet inputFieldsUsed = ImmutableBitSet.builder()
+ .addAll(requiredColumns)
+ .addAll(finderFields)
+ .build();
// Create input with trimmed columns.
TrimResult trimResult =
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index 0998c05403..2b79213838 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -5491,18 +5491,25 @@ public class SqlToRelConverter {
builder.add(convertExpression(node));
}
final ImmutableList<RexNode> list = builder.build();
+ RelNode rel = root.rel;
+ // Fix the correlation namespaces and de-duplicate the correlation
variables.
+ CorrelationUse correlationUse = getCorrelationUse(this, root.rel);
+ if (correlationUse != null) {
+ rel = correlationUse.r;
+ }
+
switch (kind) {
case IN:
- return RexSubQuery.in(root.rel, list);
+ return RexSubQuery.in(rel, list);
case NOT_IN:
return rexBuilder.makeCall(SqlStdOperatorTable.NOT,
- RexSubQuery.in(root.rel, list));
+ RexSubQuery.in(rel, list));
case SOME:
- return RexSubQuery.some(root.rel, list,
+ return RexSubQuery.some(rel, list,
(SqlQuantifyOperator) call.getOperator());
case ALL:
return rexBuilder.makeCall(SqlStdOperatorTable.NOT,
- RexSubQuery.some(root.rel, list,
+ RexSubQuery.some(rel, list,
negate((SqlQuantifyOperator) call.getOperator())));
default:
throw new AssertionError(kind);
@@ -5515,6 +5522,12 @@ public class SqlToRelConverter {
query = Iterables.getOnlyElement(call.getOperandList());
root = convertQueryRecursive(query, false, null);
RelNode rel = root.rel;
+ // Fix the correlation namespaces and de-duplicate the correlation
variables.
+ CorrelationUse correlationUse = getCorrelationUse(this, root.rel);
+ if (correlationUse != null) {
+ rel = correlationUse.r;
+ }
+
while (rel instanceof Project
|| rel instanceof Sort
&& ((Sort) rel).fetch == null
@@ -5533,7 +5546,13 @@ public class SqlToRelConverter {
call = (SqlCall) expr;
query = Iterables.getOnlyElement(call.getOperandList());
root = convertQueryRecursive(query, false, null);
- return RexSubQuery.scalar(root.rel);
+ rel = root.rel;
+ // Fix the correlation namespaces and de-duplicate the correlation
variables.
+ correlationUse = getCorrelationUse(this, root.rel);
+ if (correlationUse != null) {
+ rel = correlationUse.r;
+ }
+ return RexSubQuery.scalar(rel);
case ARRAY_QUERY_CONSTRUCTOR:
call = (SqlCall) expr;
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index 5cf1177f7f..af4bf13ae8 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -8270,6 +8270,75 @@ public class JdbcTest {
.returns("EXPR$0=[1, 1.1]\n");
}
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6032">[CALCITE-6032]
+   * Multilevel correlated query is failing in RelDecorrelator code path</a>. */
+  @Test void testMultiLevelDecorrelation() throws Exception {
+    // Use a database dedicated to this test rather than "jdbc:hsqldb:mem:.",
+    // so that the INVOICE and ITEM tables do not leak into a database that
+    // other tests may also use.
+    final String hsqldbMemUrl = "jdbc:hsqldb:mem:testMultiLevelDecorrelation";
+    // The base connection stays open until the Calcite query has finished,
+    // and is closed (like every other resource here) even if an assertion
+    // fails.
+    try (Connection baseConnection = DriverManager.getConnection(hsqldbMemUrl)) {
+      try (Statement baseStmt = baseConnection.createStatement()) {
+        baseStmt.execute("create table invoice (inv_id integer, col1\n"
+            + "integer, inv_amt integer)");
+        baseStmt.execute("create table item(item_id integer, item_amt\n"
+            + "integer, item_col1 integer, item_col2 integer, item_col3\n"
+            + "integer,item_col4 integer )");
+        baseStmt.execute("INSERT INTO invoice VALUES (1, 1, 1)");
+        baseStmt.execute("INSERT INTO invoice VALUES (2, 2, 2)");
+        baseStmt.execute("INSERT INTO invoice VALUES (3, 3, 3)");
+        baseStmt.execute("INSERT INTO item values (1, 1, 1, 1, 1, 1)");
+        baseStmt.execute("INSERT INTO item values (2, 2, 2, 2, 2, 2)");
+      }
+      baseConnection.commit();
+
+      final Properties info = new Properties();
+      info.put("model",
+          "inline:"
+              + "{\n"
+              + "  version: '1.0',\n"
+              + "  defaultSchema: 'BASEJDBC',\n"
+              + "  schemas: [\n"
+              + "     {\n"
+              + "       type: 'jdbc',\n"
+              + "       name: 'BASEJDBC',\n"
+              + "       jdbcDriver: '" + jdbcDriver.class.getName() + "',\n"
+              + "       jdbcUrl: '" + hsqldbMemUrl + "',\n"
+              + "       jdbcCatalog: null,\n"
+              + "       jdbcSchema: null\n"
+              + "     }\n"
+              + "  ]\n"
+              + "}");
+
+      // Invoice 1 contributes 1 * (1 + 1) = 2 and invoice 2 contributes
+      // 2 * (2 + 2) = 8; invoice 3 has no matching item, so its scalar
+      // sub-query is NULL and SUM ignores it. Hence SUM = 10 and COUNT = 3.
+      final String sql = "SELECT Sum(invoice.inv_amt * (\n"
+          + "            SELECT max(mainrate.item_id + mainrate.item_amt)\n"
+          + "            FROM item AS mainrate\n"
+          + "            WHERE mainrate.item_col1 is not null\n"
+          + "            AND mainrate.item_col2 is not null\n"
+          + "            AND mainrate.item_col3 = invoice.col1\n"
+          + "            AND mainrate.item_col4 = (\n"
+          + "                SELECT max(cr.item_col4)\n"
+          + "                FROM item AS cr\n"
+          + "                WHERE cr.item_col3 = mainrate.item_col3\n"
+          + "                AND cr.item_col1 =\n"
+          + "mainrate.item_col1\n"
+          + "                AND cr.item_col2 =\n"
+          + "mainrate.item_col2 \n"
+          + "                AND cr.item_col4 <=\n"
+          + "invoice.inv_id))) AS invamount,\n"
+          + "count(*) AS invcount\n"
+          + "FROM invoice\n"
+          + "WHERE invoice.inv_amt < 10 AND invoice.inv_amt > 0";
+
+      try (Connection calciteConnection =
+               DriverManager.getConnection("jdbc:calcite:", info);
+           ResultSet rs =
+               calciteConnection.prepareStatement(sql).executeQuery()) {
+        // Not Java "assert": that would skip rs.next() when -ea is off.
+        assertThat(rs.next(), is(true));
+        assertThat(rs.getInt(1), is(10));
+        assertThat(rs.getInt(2), is(3));
+        assertThat(rs.next(), is(false));
+      }
+    }
+  }
+
/** Test case for
* <a
href="https://issues.apache.org/jira/browse/CALCITE-5414">[CALCITE-5414]</a>
* Convert between standard Gregorian and proleptic Gregorian calendars for