This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c4f3bb  [CALCITE-3673] ListTransientTable should not leave tables in the schema [CALCITE-4054] RepeatUnion containing a Correlate with a transientScan on its RHS causes NPE
9c4f3bb is described below

commit 9c4f3bb540dd67a0ffefc09f4ebd98d2be65bb14
Author: rubenada <rube...@gmail.com>
AuthorDate: Thu Jan 13 14:12:29 2022 +0000

    [CALCITE-3673] ListTransientTable should not leave tables in the schema
    [CALCITE-4054] RepeatUnion containing a Correlate with a transientScan on its RHS causes NPE
---
 .../adapter/enumerable/EnumerableRepeatUnion.java  | 42 +++++++++-
 .../enumerable/EnumerableRepeatUnionRule.java      |  3 +-
 .../org/apache/calcite/jdbc/CalciteSchema.java     |  4 +
 .../apache/calcite/prepare/RelOptTableImpl.java    |  4 +
 .../org/apache/calcite/rel/core/RelFactories.java  |  6 +-
 .../org/apache/calcite/rel/core/RepeatUnion.java   | 23 +++++-
 .../calcite/rel/logical/LogicalRepeatUnion.java    | 19 +++--
 .../java/org/apache/calcite/schema/SchemaPlus.java |  7 ++
 .../calcite/schema/impl/ListTransientTable.java    | 11 ++-
 .../java/org/apache/calcite/tools/RelBuilder.java  |  2 +-
 .../org/apache/calcite/util/BuiltInMethod.java     |  5 +-
 .../test/enumerable/EnumerableRepeatUnionTest.java | 91 ++++++++++++++++++++++
 .../apache/calcite/linq4j/EnumerableDefaults.java  |  7 +-
 13 files changed, 198 insertions(+), 26 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java
 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java
index 9fc2764..c6e3fda 100644
--- 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java
+++ 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java
@@ -17,17 +17,23 @@
 package org.apache.calcite.adapter.enumerable;
 
 import org.apache.calcite.linq4j.function.Experimental;
+import org.apache.calcite.linq4j.function.Function0;
 import org.apache.calcite.linq4j.tree.BlockBuilder;
 import org.apache.calcite.linq4j.tree.Expression;
 import org.apache.calcite.linq4j.tree.Expressions;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.RepeatUnion;
+import org.apache.calcite.schema.TransientTable;
 import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.Util;
 
+import org.checkerframework.checker.nullness.qual.Nullable;
+
 import java.util.List;
+import java.util.Objects;
 
 /**
  * Implementation of {@link RepeatUnion} in
@@ -43,14 +49,15 @@ public class EnumerableRepeatUnion extends RepeatUnion 
implements EnumerableRel
    * Creates an EnumerableRepeatUnion.
    */
   EnumerableRepeatUnion(RelOptCluster cluster, RelTraitSet traitSet,
-      RelNode seed, RelNode iterative, boolean all, int iterationLimit) {
-    super(cluster, traitSet, seed, iterative, all, iterationLimit);
+      RelNode seed, RelNode iterative, boolean all, int iterationLimit,
+      @Nullable RelOptTable transientTable) {
+    super(cluster, traitSet, seed, iterative, all, iterationLimit, 
transientTable);
   }
 
   @Override public EnumerableRepeatUnion copy(RelTraitSet traitSet, 
List<RelNode> inputs) {
     assert inputs.size() == 2;
     return new EnumerableRepeatUnion(getCluster(), traitSet,
-        inputs.get(0), inputs.get(1), all, iterationLimit);
+        inputs.get(0), inputs.get(1), all, iterationLimit, transientTable);
   }
 
   @Override public Result implement(EnumerableRelImplementor implementor, 
Prefer pref) {
@@ -61,6 +68,32 @@ public class EnumerableRepeatUnion extends RepeatUnion 
implements EnumerableRel
     RelNode seed = getSeedRel();
     RelNode iteration = getIterativeRel();
 
+    Expression cleanUpFunctionExp = Expressions.constant(null);
+    if (transientTable != null) {
+      // root.getRootSchema().add(tableName, table);
+      Expression tableExp = implementor.stash(
+          Objects.requireNonNull(transientTable.unwrap(TransientTable.class)),
+          TransientTable.class);
+      String tableName =
+          
transientTable.getQualifiedName().get(transientTable.getQualifiedName().size() 
- 1);
+      Expression tableNameExp = Expressions.constant(tableName, String.class);
+      builder.append(
+          Expressions.call(
+              Expressions.call(
+                  implementor.getRootExpression(),
+                  BuiltInMethod.DATA_CONTEXT_GET_ROOT_SCHEMA.method),
+              BuiltInMethod.SCHEMA_PLUS_ADD_TABLE.method,
+              tableNameExp,
+              tableExp));
+      // root.getRootSchema().removeTable(tableName);
+      cleanUpFunctionExp = Expressions.lambda(Function0.class,
+          Expressions.call(
+              Expressions.call(
+                  implementor.getRootExpression(),
+                  BuiltInMethod.DATA_CONTEXT_GET_ROOT_SCHEMA.method),
+              BuiltInMethod.SCHEMA_PLUS_REMOVE_TABLE.method, tableNameExp));
+    }
+
     Result seedResult = implementor.visitChild(this, 0, (EnumerableRel) seed, 
pref);
     Result iterationResult = implementor.visitChild(this, 1, (EnumerableRel) 
iteration, pref);
 
@@ -78,7 +111,8 @@ public class EnumerableRepeatUnion extends RepeatUnion 
implements EnumerableRel
         iterativeExp,
         Expressions.constant(iterationLimit, int.class),
         Expressions.constant(all, boolean.class),
-        Util.first(physType.comparer(), 
Expressions.call(BuiltInMethod.IDENTITY_COMPARER.method)));
+        Util.first(physType.comparer(), 
Expressions.call(BuiltInMethod.IDENTITY_COMPARER.method)),
+        cleanUpFunctionExp);
     builder.add(unionExp);
 
     return implementor.result(physType, builder.toBlock());
diff --git 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnionRule.java
 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnionRule.java
index cb3cdf3..308b2aa 100644
--- 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnionRule.java
+++ 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnionRule.java
@@ -54,6 +54,7 @@ public class EnumerableRepeatUnionRule extends ConverterRule {
         convert(seedRel, seedRel.getTraitSet().replace(out)),
         convert(iterativeRel, iterativeRel.getTraitSet().replace(out)),
         union.all,
-        union.iterationLimit);
+        union.iterationLimit,
+        union.getTransientTable());
   }
 }
diff --git a/core/src/main/java/org/apache/calcite/jdbc/CalciteSchema.java 
b/core/src/main/java/org/apache/calcite/jdbc/CalciteSchema.java
index e0afc8e..05eee7f 100644
--- a/core/src/main/java/org/apache/calcite/jdbc/CalciteSchema.java
+++ b/core/src/main/java/org/apache/calcite/jdbc/CalciteSchema.java
@@ -730,6 +730,10 @@ public abstract class CalciteSchema {
       CalciteSchema.this.add(name, table);
     }
 
+    @Override public boolean removeTable(String name) {
+      return CalciteSchema.this.removeTable(name);
+    }
+
     @Override public void add(String name, Function function) {
       CalciteSchema.this.add(name, function);
     }
diff --git a/core/src/main/java/org/apache/calcite/prepare/RelOptTableImpl.java 
b/core/src/main/java/org/apache/calcite/prepare/RelOptTableImpl.java
index fac9293..0878292 100644
--- a/core/src/main/java/org/apache/calcite/prepare/RelOptTableImpl.java
+++ b/core/src/main/java/org/apache/calcite/prepare/RelOptTableImpl.java
@@ -473,6 +473,10 @@ public class RelOptTableImpl extends 
Prepare.AbstractPreparingTable {
       throw new UnsupportedOperationException();
     }
 
+    @Override public boolean removeTable(String name) {
+      throw new UnsupportedOperationException();
+    }
+
     @Override public void add(String name,
         org.apache.calcite.schema.Function function) {
       throw new UnsupportedOperationException();
diff --git a/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java 
b/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java
index 0998c99..6d5e2c0 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java
@@ -610,7 +610,7 @@ public class RelFactories {
   public interface RepeatUnionFactory {
     /** Creates a {@link RepeatUnion}. */
     RelNode createRepeatUnion(RelNode seed, RelNode iterative, boolean all,
-        int iterationLimit);
+        int iterationLimit, RelOptTable table);
   }
 
   /**
@@ -619,8 +619,8 @@ public class RelFactories {
    */
   private static class RepeatUnionFactoryImpl implements RepeatUnionFactory {
     @Override public RelNode createRepeatUnion(RelNode seed, RelNode iterative,
-        boolean all, int iterationLimit) {
-      return LogicalRepeatUnion.create(seed, iterative, all, iterationLimit);
+        boolean all, int iterationLimit, RelOptTable table) {
+      return LogicalRepeatUnion.create(seed, iterative, all, iterationLimit, 
table);
     }
   }
 
diff --git a/core/src/main/java/org/apache/calcite/rel/core/RepeatUnion.java 
b/core/src/main/java/org/apache/calcite/rel/core/RepeatUnion.java
index f1419d7..073861e 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/RepeatUnion.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/RepeatUnion.java
@@ -18,15 +18,20 @@ package org.apache.calcite.rel.core;
 
 import org.apache.calcite.linq4j.function.Experimental;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.BiRel;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelWriter;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.schema.TransientTable;
 import org.apache.calcite.util.Util;
 
+import org.checkerframework.checker.nullness.qual.Nullable;
+
 import java.util.List;
+import java.util.Objects;
 
 /**
  * Relational expression that computes a repeat union (recursive union in SQL
@@ -40,7 +45,7 @@ import java.util.List;
  *
  * <li>Evaluate the right input (i.e., iterative relational expression) over 
and
  *   over until it produces no more results (or until an optional maximum 
number
- *   of iterations is reached).  For UNION (but not UNION ALL), discard
+ *   of iterations is reached). For UNION (but not UNION ALL), discard
  *   duplicated results.
  * </ul>
  *
@@ -61,12 +66,22 @@ public abstract class RepeatUnion extends BiRel {
    */
   public final int iterationLimit;
 
+  /**
+   * Transient table where repeat union's intermediate results will be stored 
(optional).
+   */
+  protected final @Nullable RelOptTable transientTable;
+
   //~ Constructors -----------------------------------------------------------
   protected RepeatUnion(RelOptCluster cluster, RelTraitSet traitSet,
-      RelNode seed, RelNode iterative, boolean all, int iterationLimit) {
+      RelNode seed, RelNode iterative, boolean all, int iterationLimit,
+      @Nullable RelOptTable transientTable) {
     super(cluster, traitSet, seed, iterative);
     this.iterationLimit = iterationLimit;
     this.all = all;
+    this.transientTable = transientTable;
+    if (transientTable != null) {
+      Objects.requireNonNull(transientTable.unwrap(TransientTable.class));
+    }
   }
 
   @Override public double estimateRowCount(RelMetadataQuery mq) {
@@ -95,6 +110,10 @@ public abstract class RepeatUnion extends BiRel {
     return right;
   }
 
+  public @Nullable RelOptTable getTransientTable() {
+    return transientTable;
+  }
+
   @Override protected RelDataType deriveRowType() {
     final List<RelDataType> inputRowTypes =
         Util.transform(getInputs(), RelNode::getRowType);
diff --git 
a/core/src/main/java/org/apache/calcite/rel/logical/LogicalRepeatUnion.java 
b/core/src/main/java/org/apache/calcite/rel/logical/LogicalRepeatUnion.java
index 5084176..f3bface 100644
--- a/core/src/main/java/org/apache/calcite/rel/logical/LogicalRepeatUnion.java
+++ b/core/src/main/java/org/apache/calcite/rel/logical/LogicalRepeatUnion.java
@@ -19,10 +19,13 @@ package org.apache.calcite.rel.logical;
 import org.apache.calcite.linq4j.function.Experimental;
 import org.apache.calcite.plan.Convention;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.RepeatUnion;
 
+import org.checkerframework.checker.nullness.qual.Nullable;
+
 import java.util.List;
 
 /**
@@ -37,22 +40,24 @@ public class LogicalRepeatUnion extends RepeatUnion {
 
   //~ Constructors -----------------------------------------------------------
   private LogicalRepeatUnion(RelOptCluster cluster, RelTraitSet traitSet,
-      RelNode seed, RelNode iterative, boolean all, int iterationLimit) {
-    super(cluster, traitSet, seed, iterative, all, iterationLimit);
+      RelNode seed, RelNode iterative, boolean all, int iterationLimit,
+      @Nullable RelOptTable transientTable) {
+    super(cluster, traitSet, seed, iterative, all, iterationLimit, 
transientTable);
   }
 
   /** Creates a LogicalRepeatUnion. */
   public static LogicalRepeatUnion create(RelNode seed, RelNode iterative,
-      boolean all) {
-    return create(seed, iterative, all, -1);
+      boolean all, @Nullable RelOptTable transientTable) {
+    return create(seed, iterative, all, -1, transientTable);
   }
 
   /** Creates a LogicalRepeatUnion. */
   public static LogicalRepeatUnion create(RelNode seed, RelNode iterative,
-      boolean all, int iterationLimit) {
+      boolean all, int iterationLimit, @Nullable RelOptTable transientTable) {
     RelOptCluster cluster = seed.getCluster();
     RelTraitSet traitSet = cluster.traitSetOf(Convention.NONE);
-    return new LogicalRepeatUnion(cluster, traitSet, seed, iterative, all, 
iterationLimit);
+    return new LogicalRepeatUnion(cluster, traitSet, seed, iterative, all, 
iterationLimit,
+        transientTable);
   }
 
   //~ Methods ----------------------------------------------------------------
@@ -62,6 +67,6 @@ public class LogicalRepeatUnion extends RepeatUnion {
     assert traitSet.containsIfApplicable(Convention.NONE);
     assert inputs.size() == 2;
     return new LogicalRepeatUnion(getCluster(), traitSet,
-        inputs.get(0), inputs.get(1), all, iterationLimit);
+        inputs.get(0), inputs.get(1), all, iterationLimit, transientTable);
   }
 }
diff --git a/core/src/main/java/org/apache/calcite/schema/SchemaPlus.java 
b/core/src/main/java/org/apache/calcite/schema/SchemaPlus.java
index 7052603..5f41788 100644
--- a/core/src/main/java/org/apache/calcite/schema/SchemaPlus.java
+++ b/core/src/main/java/org/apache/calcite/schema/SchemaPlus.java
@@ -68,6 +68,13 @@ public interface SchemaPlus extends Schema {
   /** Adds a table to this schema. */
   void add(String name, Table table);
 
+  /** Removes a table from this schema, used e.g. to clean-up temporary 
tables. */
+  default boolean removeTable(String name) {
+    // Default implementation provided for backwards compatibility, to be 
removed before 2.0
+    return false;
+  }
+
+
   /** Adds a function to this schema. */
   void add(String name, Function function);
 
diff --git 
a/core/src/main/java/org/apache/calcite/schema/impl/ListTransientTable.java 
b/core/src/main/java/org/apache/calcite/schema/impl/ListTransientTable.java
index 0ee6ec4..653a66f 100644
--- a/core/src/main/java/org/apache/calcite/schema/impl/ListTransientTable.java
+++ b/core/src/main/java/org/apache/calcite/schema/impl/ListTransientTable.java
@@ -49,8 +49,6 @@ import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
 
-import static java.util.Objects.requireNonNull;
-
 /**
  * {@link TransientTable} backed by a Java list. It will be automatically 
added to the
  * current schema when {@link #scan(DataContext)} method gets called.
@@ -61,7 +59,9 @@ import static java.util.Objects.requireNonNull;
 public class ListTransientTable extends AbstractQueryableTable
     implements TransientTable, ModifiableTable, ScannableTable {
   private static final Type TYPE = Object[].class;
+  @SuppressWarnings("rawtypes")
   private final List rows = new ArrayList();
+  @SuppressWarnings({"unused", "FieldCanBeLocal"})
   private final String name;
   private final RelDataType protoRowType;
 
@@ -84,20 +84,19 @@ public class ListTransientTable extends 
AbstractQueryableTable
         updateColumnList, sourceExpressionList, flattened);
   }
 
+  @SuppressWarnings("rawtypes")
   @Override public Collection getModifiableCollection() {
     return rows;
   }
 
   @Override public Enumerable<@Nullable Object[]> scan(DataContext root) {
-    // add the table into the schema, so that it is accessible by any 
potential operator
-    requireNonNull(root.getRootSchema(), "root.getRootSchema()")
-        .add(name, this);
 
     final AtomicBoolean cancelFlag = 
DataContext.Variable.CANCEL_FLAG.get(root);
 
     return new AbstractEnumerable<@Nullable Object[]>() {
       @Override public Enumerator<@Nullable Object[]> enumerator() {
         return new Enumerator<@Nullable Object[]>() {
+          @SuppressWarnings({"rawtypes", "unchecked"})
           private final List list = new ArrayList(rows);
           private int i = -1;
 
@@ -129,7 +128,7 @@ public class ListTransientTable extends 
AbstractQueryableTable
   }
 
   @Override public Expression getExpression(SchemaPlus schema, String 
tableName,
-                                  Class clazz) {
+      @SuppressWarnings("rawtypes") Class clazz) {
     return Schemas.tableExpression(schema, elementType, tableName, clazz);
   }
 
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java 
b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 9c178c0..1bcacf3 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -2709,7 +2709,7 @@ public class RelBuilder {
     RelNode seed = tableSpool(Spool.Type.LAZY, Spool.Type.LAZY, 
finder.relOptTable).build();
     RelNode repeatUnion =
         struct.repeatUnionFactory.createRepeatUnion(seed, iterative, all,
-            iterationLimit);
+            iterationLimit, finder.relOptTable);
     return push(repeatUnion);
   }
 
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java 
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 510baf1..e1089a8 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -106,6 +106,7 @@ import org.apache.calcite.schema.ScannableTable;
 import org.apache.calcite.schema.Schema;
 import org.apache.calcite.schema.SchemaPlus;
 import org.apache.calcite.schema.Schemas;
+import org.apache.calcite.schema.Table;
 import org.apache.calcite.sql.SqlExplainLevel;
 import org.apache.calcite.sql.SqlJsonConstructorNullClause;
 import org.apache.calcite.sql.SqlJsonQueryEmptyOrErrorBehavior;
@@ -156,6 +157,8 @@ public enum BuiltInMethod {
   REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class),
   SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class),
   SCHEMA_GET_TABLE(Schema.class, "getTable", String.class),
+  SCHEMA_PLUS_ADD_TABLE(SchemaPlus.class, "add", String.class, Table.class),
+  SCHEMA_PLUS_REMOVE_TABLE(SchemaPlus.class, "removeTable", String.class),
   SCHEMA_PLUS_UNWRAP(SchemaPlus.class, "unwrap", Class.class),
   SCHEMAS_ENUMERABLE_SCANNABLE(Schemas.class, "enumerable",
       ScannableTable.class, DataContext.class),
@@ -240,7 +243,7 @@ public enum BuiltInMethod {
   UNION(ExtendedEnumerable.class, "union", Enumerable.class),
   CONCAT(ExtendedEnumerable.class, "concat", Enumerable.class),
   REPEAT_UNION(EnumerableDefaults.class, "repeatUnion", Enumerable.class,
-      Enumerable.class, int.class, boolean.class, EqualityComparer.class),
+      Enumerable.class, int.class, boolean.class, EqualityComparer.class, 
Function0.class),
   MERGE_UNION(EnumerableDefaults.class, "mergeUnion", List.class, 
Function1.class,
       Comparator.class, boolean.class, EqualityComparer.class),
   LAZY_COLLECTION_SPOOL(EnumerableDefaults.class, "lazyCollectionSpool", 
Collection.class,
diff --git 
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java
 
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java
index 733ba65..f4a3878 100644
--- 
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java
+++ 
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java
@@ -17,12 +17,24 @@
 package org.apache.calcite.test.enumerable;
 
 import org.apache.calcite.adapter.enumerable.EnumerableRepeatUnion;
+import org.apache.calcite.adapter.enumerable.EnumerableRules;
+import org.apache.calcite.adapter.java.ReflectiveSchema;
+import org.apache.calcite.config.CalciteConnectionProperty;
+import org.apache.calcite.config.Lex;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.rules.JoinCommuteRule;
+import org.apache.calcite.rel.rules.JoinToCorrelateRule;
+import org.apache.calcite.runtime.Hook;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.test.schemata.hr.HierarchySchema;
 
 import org.junit.jupiter.api.Test;
 
 import java.util.Arrays;
+import java.util.function.Consumer;
 
 /**
  * Unit tests for {@link EnumerableRepeatUnion}.
@@ -227,4 +239,83 @@ class EnumerableRepeatUnionTest {
         .returnsOrdered("i=1", "i=2", "i=null", "i=3", "i=2", "i=3", "i=3");
   }
 
+  /** Test case for
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-4054";>[CALCITE-4054]
+   * RepeatUnion containing a Correlate with a transientScan on its RHS causes 
NPE</a>. */
+  @Test void testRepeatUnionWithCorrelateWithTransientScanOnItsRight() {
+    CalciteAssert.that()
+        .with(CalciteConnectionProperty.LEX, Lex.JAVA)
+        .with(CalciteConnectionProperty.FORCE_DECORRELATE, false)
+        .withSchema("s", new ReflectiveSchema(new HierarchySchema()))
+        .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> {
+          planner.addRule(JoinToCorrelateRule.Config.DEFAULT.toRule());
+          planner.removeRule(JoinCommuteRule.Config.DEFAULT.toRule());
+          planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE);
+          planner.removeRule(EnumerableRules.ENUMERABLE_JOIN_RULE);
+        })
+        .withRel(builder -> {
+          builder
+              //   WITH RECURSIVE delta(empid, name) as (
+              //     SELECT empid, name FROM emps WHERE empid = 2
+              //     UNION ALL
+              //     SELECT e.empid, e.name FROM delta d
+              //                            JOIN hierarchies h ON d.empid = 
h.managerid
+              //                            JOIN emps e        ON 
h.subordinateid = e.empid
+              //   )
+              //   SELECT empid, name FROM delta
+              .scan("s", "emps")
+              .filter(
+                  builder.equals(
+                      builder.field("empid"),
+                      builder.literal(2)))
+              .project(
+                  builder.field("emps", "empid"),
+                  builder.field("emps", "name"))
+
+              .transientScan("#DELTA#");
+          RelNode transientScan = builder.build(); // pop the transientScan to 
use it later
+
+          builder
+              .scan("s", "hierarchies")
+              .push(transientScan) // use the transientScan as right input of 
the join
+              .join(
+                  JoinRelType.INNER,
+                  builder.equals(
+                      builder.field(2, "#DELTA#", "empid"),
+                      builder.field(2, "hierarchies", "managerid")))
+
+              .scan("s", "emps")
+              .join(
+                  JoinRelType.INNER,
+                  builder.equals(
+                      builder.field(2, "hierarchies", "subordinateid"),
+                      builder.field(2, "emps", "empid")))
+              .project(
+                  builder.field("emps", "empid"),
+                  builder.field("emps", "name"))
+              .repeatUnion("#DELTA#", true);
+          return builder.build();
+        })
+        .explainHookMatches(""
+            + "EnumerableRepeatUnion(all=[true])\n"
+            + "  EnumerableTableSpool(readType=[LAZY], writeType=[LAZY], 
table=[[#DELTA#]])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], expr#5=[2], 
expr#6=[=($t0, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "  EnumerableTableSpool(readType=[LAZY], writeType=[LAZY], 
table=[[#DELTA#]])\n"
+            + "    EnumerableCalc(expr#0..8=[{inputs}], empid=[$t4], 
name=[$t6])\n"
+            + "      EnumerableCorrelate(correlation=[$cor1], 
joinType=[inner], requiredColumns=[{1}])\n"
+            // It is important to have EnumerableCorrelate + #DELTA# table 
scan on its right
+            // to reproduce the issue CALCITE-4054
+            + "        EnumerableCorrelate(correlation=[$cor0], 
joinType=[inner], requiredColumns=[{0}])\n"
+            + "          EnumerableTableScan(table=[[s, hierarchies]])\n"
+            + "          EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0], 
expr#3=[$t2.managerid], expr#4=[=($t0, $t3)], proj#0..1=[{exprs}], 
$condition=[$t4])\n"
+            + "            EnumerableInterpreter\n"
+            + "              BindableTableScan(table=[[#DELTA#]])\n"
+            + "        EnumerableCalc(expr#0..4=[{inputs}], expr#5=[$cor1], 
expr#6=[$t5.subordinateid], expr#7=[=($t6, $t0)], proj#0..4=[{exprs}], 
$condition=[$t7])\n"
+            + "          EnumerableTableScan(table=[[s, emps]])\n")
+        .returnsUnordered(""
+            + "empid=2; name=Emp2\n"
+            + "empid=3; name=Emp3\n"
+            + "empid=5; name=Emp5");
+  }
 }
diff --git 
a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java 
b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index 4f8271d..7febd11 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -4519,6 +4519,7 @@ public abstract class EnumerableDefaults {
    * @param all whether duplicates will be considered or not
    * @param comparer {@link EqualityComparer} to control duplicates,
    *                 only used if {@code all} is {@code false}
+   * @param cleanUpFunction optional clean-up actions (e.g. delete temporary 
table)
    * @param <TSource> record type
    */
   @SuppressWarnings("unchecked")
@@ -4527,7 +4528,8 @@ public abstract class EnumerableDefaults {
       Enumerable<TSource> iteration,
       int iterationLimit,
       boolean all,
-      EqualityComparer<TSource> comparer) {
+      EqualityComparer<TSource> comparer,
+      @Nullable Function0<Boolean> cleanUpFunction) {
     return new AbstractEnumerable<TSource>() {
       @Override public Enumerator<TSource> enumerator() {
         return new Enumerator<TSource>() {
@@ -4623,6 +4625,9 @@ public abstract class EnumerableDefaults {
           }
 
           @Override public void close() {
+            if (cleanUpFunction != null) {
+              cleanUpFunction.apply();
+            }
             seedEnumerator.close();
             if (iterativeEnumerator != null) {
               iterativeEnumerator.close();

Reply via email to