This is an automated email from the ASF dual-hosted git repository.
zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 96e318b6fa [CALCITE-7403] Missing ENUMERABLE Convention for
LogicalConditionalCorrelate
96e318b6fa is described below
commit 96e318b6fab2722677f6264fb0bb10f2ffb32163
Author: Zhen Chen <[email protected]>
AuthorDate: Thu Jan 29 14:27:30 2026 +0800
[CALCITE-7403] Missing ENUMERABLE Convention for LogicalConditionalCorrelate
---
.../enumerable/EnumerableConditionalCorrelate.java | 214 +++++++++++++++++++++
.../EnumerableConditionalCorrelateRule.java | 57 ++++++
.../adapter/enumerable/EnumerableRules.java | 8 +
.../calcite/rel/core/ConditionalCorrelate.java | 2 +-
.../java/org/apache/calcite/tools/Programs.java | 1 +
.../org/apache/calcite/util/BuiltInMethod.java | 4 +
.../test/enumerable/EnumerableCorrelateTest.java | 172 +++++++++++++++++
core/src/test/resources/sql/new-decorr.iq | 24 +++
.../apache/calcite/linq4j/DefaultEnumerable.java | 7 +
.../apache/calcite/linq4j/EnumerableDefaults.java | 59 ++++--
.../apache/calcite/linq4j/ExtendedEnumerable.java | 13 ++
11 files changed, 548 insertions(+), 13 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableConditionalCorrelate.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableConditionalCorrelate.java
new file mode 100644
index 0000000000..02b51a6249
--- /dev/null
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableConditionalCorrelate.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.enumerable;
+
+import org.apache.calcite.linq4j.tree.BlockBuilder;
+import org.apache.calcite.linq4j.tree.Expression;
+import org.apache.calcite.linq4j.tree.Expressions;
+import org.apache.calcite.linq4j.tree.ParameterExpression;
+import org.apache.calcite.linq4j.tree.Primitive;
+import org.apache.calcite.plan.DeriveMode;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelCollationTraitDef;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.ConditionalCorrelate;
+import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.metadata.RelMdCollation;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.util.BuiltInMethod;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Pair;
+
+import com.google.common.collect.ImmutableList;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.lang.reflect.Modifier;
+import java.lang.reflect.Type;
+import java.util.List;
+
+/** Implementation of {@link org.apache.calcite.rel.core.ConditionalCorrelate}
+ * in {@link org.apache.calcite.adapter.enumerable.EnumerableConvention
+ * enumerable calling convention}.
+ *
+ * <p>The left input drives a nested loop: for each left row, the right input
+ * is evaluated with the correlation variable bound to that row. Only
+ * {@link JoinRelType#LEFT_MARK} is currently supported. */
+public class EnumerableConditionalCorrelate extends ConditionalCorrelate
+    implements EnumerableRel {
+
+  /** Creates an EnumerableConditionalCorrelate.
+   * Use {@link #create} unless you know what you're doing. */
+  protected EnumerableConditionalCorrelate(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode left,
+      RelNode right,
+      CorrelationId correlationId,
+      ImmutableBitSet requiredColumns,
+      JoinRelType joinType,
+      RexNode condition) {
+    super(cluster, traits, ImmutableList.of(), left, right, correlationId,
+        requiredColumns, joinType, condition);
+  }
+
+  /** Creates an EnumerableConditionalCorrelate in enumerable convention,
+   * deriving its collation from its inputs. */
+  public static EnumerableConditionalCorrelate create(
+      RelNode left,
+      RelNode right,
+      CorrelationId correlationId,
+      ImmutableBitSet requiredColumns,
+      JoinRelType joinType,
+      RexNode condition) {
+    final RelOptCluster cluster = left.getCluster();
+    final RelMetadataQuery mq = cluster.getMetadataQuery();
+    final RelTraitSet traitSet =
+        cluster.traitSetOf(EnumerableConvention.INSTANCE)
+            .replaceIfs(RelCollationTraitDef.INSTANCE,
+                () -> RelMdCollation.enumerableCorrelate(mq, left, right, joinType));
+    return new EnumerableConditionalCorrelate(
+        cluster,
+        traitSet,
+        left,
+        right,
+        correlationId,
+        requiredColumns,
+        joinType,
+        condition);
+  }
+
+  @Override public EnumerableConditionalCorrelate copy(
+      RelTraitSet traitSet,
+      RelNode left,
+      RelNode right,
+      CorrelationId correlationId,
+      ImmutableBitSet requiredColumns,
+      JoinRelType joinType,
+      RexNode condition) {
+    return new EnumerableConditionalCorrelate(
+        getCluster(),
+        traitSet,
+        left,
+        right,
+        correlationId,
+        requiredColumns,
+        joinType,
+        condition);
+  }
+
+  /** {@inheritDoc}
+   *
+   * <p>This variant has no {@code condition} argument, so the copy keeps the
+   * condition of this relational expression. Previously this method threw,
+   * which broke any caller that copies a {@code Correlate} through the
+   * condition-less {@code copy} signature. */
+  @Override public EnumerableConditionalCorrelate copy(
+      RelTraitSet traitSet,
+      RelNode left,
+      RelNode right,
+      CorrelationId correlationId,
+      ImmutableBitSet requiredColumns,
+      JoinRelType joinType) {
+    return copy(traitSet, left, right, correlationId, requiredColumns,
+        joinType, getCondition());
+  }
+
+  @Override public @Nullable Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
+      final RelTraitSet required) {
+    // Only collation is passed down, and only to the left input: the left
+    // input always drives the outer loop, so only its ordering can be
+    // preserved in the output.
+    return EnumerableTraitsUtils.passThroughTraitsForJoin(
+        required, joinType, left.getRowType().getFieldCount(), getTraitSet());
+  }
+
+  @Override public @Nullable Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
+      final RelTraitSet childTraits, final int childId) {
+    // Traits (currently limited to collation) are derived from the left input
+    // only.
+    return EnumerableTraitsUtils.deriveTraitsForJoin(
+        childTraits, childId, joinType, traitSet, right.getTraitSet());
+  }
+
+  @Override public DeriveMode getDeriveMode() {
+    return DeriveMode.LEFT_FIRST;
+  }
+
+  /** Generates code that evaluates the right input once per left row (with the
+   * correlation variable bound to that row) and emits each left row together
+   * with its mark.
+   *
+   * @throws UnsupportedOperationException if the join type is not
+   *     {@link JoinRelType#LEFT_MARK} */
+  @Override public Result implement(EnumerableRelImplementor implementor,
+      Prefer pref) {
+    // Reject unsupported join types before generating any code, so that we do
+    // not visit the inputs or register the correlation variable for nothing.
+    // TODO: Support other join types. Currently, ConditionalCorrelate is only
+    // created when rewriting correlated IN/SOME/EXISTS subqueries, and its
+    // type is always LEFT_MARK.
+    if (joinType != JoinRelType.LEFT_MARK) {
+      throw new UnsupportedOperationException(
+          "EnumerableConditionalCorrelate does not support join type: " + joinType);
+    }
+
+    final BlockBuilder builder = new BlockBuilder();
+    final Result leftResult =
+        implementor.visitChild(this, 0, (EnumerableRel) left, pref);
+    final Expression leftExpression =
+        builder.append("left", leftResult.block);
+
+    // The right input is generated inside a lambda whose argument is the
+    // current left row. A lambda argument must be boxed, so a primitive left
+    // row type is passed boxed and unboxed into a local for the inner block.
+    final BlockBuilder corrBlock = new BlockBuilder();
+    final Type corrVarType = leftResult.physType.getJavaRowType();
+    final ParameterExpression corrRef; // correlation variable used in inner block
+    final ParameterExpression corrArg; // argument to the correlate lambda
+    if (!Primitive.is(corrVarType)) {
+      corrArg =
+          Expressions.parameter(Modifier.FINAL,
+              corrVarType, getCorrelVariable());
+      corrRef = corrArg;
+    } else {
+      corrArg =
+          Expressions.parameter(Modifier.FINAL,
+              Primitive.box(corrVarType), "$box" + getCorrelVariable());
+      corrRef =
+          (ParameterExpression) corrBlock.append(getCorrelVariable(),
+              Expressions.unbox(corrArg));
+    }
+
+    implementor.registerCorrelVariable(getCorrelVariable(), corrRef,
+        corrBlock, leftResult.physType);
+
+    final Result rightResult =
+        implementor.visitChild(this, 1, (EnumerableRel) right, pref);
+
+    implementor.clearCorrelVariable(getCorrelVariable());
+
+    // Predicate evaluated for each (left row, right row) pair. It may return
+    // NULL; the runtime then makes the mark NULL unless some other right row
+    // matches.
+    final Expression predicate =
+        EnumUtils.generatePredicate(
+            implementor,
+            getCluster().getRexBuilder(),
+            left,
+            right,
+            leftResult.physType,
+            rightResult.physType,
+            getCondition(),
+            true);
+
+    corrBlock.add(rightResult.block);
+
+    final PhysType physType =
+        PhysTypeImpl.of(
+            implementor.getTypeFactory(),
+            getRowType(),
+            pref.prefer(JavaRowFormat.CUSTOM));
+
+    // Combines a left row with its mark to form an output row.
+    final Expression selector =
+        EnumUtils.markJoinSelector(physType, leftResult.physType);
+
+    builder.append(
+        Expressions.call(leftExpression,
+            BuiltInMethod.CORRELATE_LEFT_MARK_JOIN.method,
+            Expressions.lambda(corrBlock.toBlock(), corrArg),
+            predicate,
+            selector));
+
+    return implementor.result(physType, builder.toBlock());
+  }
+}
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableConditionalCorrelateRule.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableConditionalCorrelateRule.java
new file mode 100644
index 0000000000..4d1402c342
--- /dev/null
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableConditionalCorrelateRule.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.enumerable;
+
+import org.apache.calcite.plan.Convention;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.convert.ConverterRule;
+import org.apache.calcite.rel.logical.LogicalConditionalCorrelate;
+
+import org.immutables.value.Value;
+
+/**
+ * Rule that converts a {@link LogicalConditionalCorrelate} into an
+ * {@link EnumerableConditionalCorrelate}. Both inputs are converted to
+ * {@link EnumerableConvention enumerable calling convention}, so the
+ * correlate can be implemented as a nested loop over enumerable inputs.
+ *
+ * @see EnumerableRules#ENUMERABLE_CONDITIONAL_CORRELATE_RULE
+ */
+@Value.Enclosing
+public class EnumerableConditionalCorrelateRule extends ConverterRule {
+  /** Default configuration: converts every {@link LogicalConditionalCorrelate}
+   * in {@link Convention#NONE}. */
+  public static final Config DEFAULT_CONFIG = Config.INSTANCE
+      .withConversion(LogicalConditionalCorrelate.class, r -> true,
+          Convention.NONE, EnumerableConvention.INSTANCE,
+          "EnumerableConditionalCorrelateRule")
+      .withRuleFactory(EnumerableConditionalCorrelateRule::new);
+
+  /** Creates an EnumerableConditionalCorrelateRule. */
+  protected EnumerableConditionalCorrelateRule(Config config) {
+    super(config);
+  }
+
+  @Override public RelNode convert(RelNode rel) {
+    final LogicalConditionalCorrelate correlate =
+        (LogicalConditionalCorrelate) rel;
+    final RelNode left = correlate.getLeft();
+    final RelNode right = correlate.getRight();
+    // Only the convention trait is replaced; other input traits are kept.
+    final RelNode enumerableLeft =
+        convert(left, left.getTraitSet().replace(EnumerableConvention.INSTANCE));
+    final RelNode enumerableRight =
+        convert(right, right.getTraitSet().replace(EnumerableConvention.INSTANCE));
+    return EnumerableConditionalCorrelate.create(
+        enumerableLeft,
+        enumerableRight,
+        correlate.getCorrelationId(),
+        correlate.getRequiredColumns(),
+        correlate.getJoinType(),
+        correlate.getCondition());
+  }
+}
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
index f33997450c..a470d3b7ce 100644
---
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
@@ -65,6 +65,13 @@ private EnumerableRules() {
EnumerableCorrelateRule.DEFAULT_CONFIG
.toRule(EnumerableCorrelateRule.class);
+ /** Rule that converts a
+ * {@link org.apache.calcite.rel.logical.LogicalConditionalCorrelate} into an
+ * {@link EnumerableConditionalCorrelate} in
+ * {@link EnumerableConvention enumerable calling convention}. */
+ public static final RelOptRule ENUMERABLE_CONDITIONAL_CORRELATE_RULE =
+ EnumerableConditionalCorrelateRule.DEFAULT_CONFIG
+ .toRule(EnumerableConditionalCorrelateRule.class);
+
/** Rule that converts a
* {@link org.apache.calcite.rel.logical.LogicalJoin} into an
* {@link
org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin}. */
@@ -219,6 +226,7 @@ private EnumerableRules() {
EnumerableRules.ENUMERABLE_ASOFJOIN_RULE,
EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE,
EnumerableRules.ENUMERABLE_CORRELATE_RULE,
+ EnumerableRules.ENUMERABLE_CONDITIONAL_CORRELATE_RULE,
EnumerableRules.ENUMERABLE_PROJECT_RULE,
EnumerableRules.ENUMERABLE_FILTER_RULE,
EnumerableRules.ENUMERABLE_CALC_RULE,
diff --git
a/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
b/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
index af203429ef..f5e1d38377 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
@@ -42,7 +42,7 @@
*/
public abstract class ConditionalCorrelate extends Correlate {
- private final RexNode condition;
+ protected final RexNode condition;
protected ConditionalCorrelate(
RelOptCluster cluster,
diff --git a/core/src/main/java/org/apache/calcite/tools/Programs.java
b/core/src/main/java/org/apache/calcite/tools/Programs.java
index 4974db3c20..83f6834d28 100644
--- a/core/src/main/java/org/apache/calcite/tools/Programs.java
+++ b/core/src/main/java/org/apache/calcite/tools/Programs.java
@@ -82,6 +82,7 @@ public class Programs {
EnumerableRules.ENUMERABLE_JOIN_RULE,
EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE,
EnumerableRules.ENUMERABLE_CORRELATE_RULE,
+ EnumerableRules.ENUMERABLE_CONDITIONAL_CORRELATE_RULE,
EnumerableRules.ENUMERABLE_PROJECT_RULE,
EnumerableRules.ENUMERABLE_FILTER_RULE,
EnumerableRules.ENUMERABLE_AGGREGATE_RULE,
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 6adf1aa394..465a89347e 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -263,6 +263,10 @@ public enum BuiltInMethod {
Enumerable.class, // inner enumerable
NullablePredicate2.class, // non-equi predicate that can return NULL
Function2.class), // result selector
+ CORRELATE_LEFT_MARK_JOIN(ExtendedEnumerable.class, "correlateLeftMarkJoin",
+ Function1.class, // function to generate inner enumerable from
correlate variable
+ NullablePredicate2.class, // non-equi predicate that can return NULL
+ Function2.class), // result selector
CORRELATE_JOIN(ExtendedEnumerable.class, "correlateJoin",
JoinType.class, Function1.class, Function2.class),
CORRELATE_BATCH_JOIN(EnumerableDefaults.class, "correlateBatchJoin",
diff --git
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableCorrelateTest.java
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableCorrelateTest.java
index 9ea76b84e8..72455d5abb 100644
---
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableCorrelateTest.java
+++
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableCorrelateTest.java
@@ -21,15 +21,23 @@
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.config.Lex;
import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.test.ReflectiveSchemaWithoutRowCount;
import org.apache.calcite.test.schemata.hr.HrSchema;
+import org.apache.calcite.tools.Program;
+import org.apache.calcite.tools.Programs;
+import org.apache.calcite.util.Holder;
+
+import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
+import java.util.List;
import java.util.function.Consumer;
/**
@@ -296,6 +304,170 @@ class EnumerableCorrelateTest {
.returnsUnordered("empid=200; name=Eric");
}
+  /** Returns a program that runs three phases in sequence:
+   * 1. rewrites sub-queries in projects and filters into mark correlates;
+   * 2. converts projects and filters into calcs and merges them;
+   * 3. implements the plan using only the values, calc, uncollect and
+   *    conditional-correlate enumerable rules. */
+  private static Program getConditionalCorrelateProgram() {
+    final Program subQueryToMarkCorrelate =
+        Programs.hep(
+            ImmutableList.of(
+                CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
+                CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE),
+            true,
+            DefaultRelMetadataProvider.INSTANCE);
+    final Program projectAndFilterToCalc =
+        Programs.hep(
+            ImmutableList.of(
+                CoreRules.PROJECT_TO_CALC,
+                CoreRules.FILTER_TO_CALC,
+                CoreRules.CALC_MERGE),
+            true,
+            DefaultRelMetadataProvider.INSTANCE);
+    final List<RelOptRule> implementationRules =
+        ImmutableList.of(
+            EnumerableRules.ENUMERABLE_VALUES_RULE,
+            EnumerableRules.ENUMERABLE_CALC_RULE,
+            EnumerableRules.ENUMERABLE_UNCOLLECT_RULE,
+            EnumerableRules.ENUMERABLE_CONDITIONAL_CORRELATE_RULE);
+    return Programs.sequence(subQueryToMarkCorrelate, projectAndFilterToCalc,
+        Programs.ofRules(implementationRules));
+  }
+
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-7403">[CALCITE-7403]
+ * Missing ENUMERABLE Convention for LogicalConditionalCorrelate</a>.
+ *
+ * <p>Plans EXISTS and NOT EXISTS sub-queries into an
+ * {@code EnumerableConditionalCorrelate} without a join condition (the
+ * correlated filter stays inside the right input), then checks both the plan
+ * and the results. */
+ @Test void testConditionalCorrelateForExists() {
+ // EXISTS: only id=2 has a t2 row with t2.id = t1.id AND t2.val > 10;
+ // for the NULL id the comparison t2.id = t1.id is never true.
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, val) AS (\n"
+ + " VALUES (1, 10), (2, 20), (NULL, 30)\n"
+ + "),\n"
+ + "t2(id, val) AS (\n"
+ + " VALUES (2, 15), (3, 25)\n"
+ + ")\n"
+ + "SELECT\n"
+ + " t1.id,\n"
+ + " EXISTS (\n"
+ + " SELECT 1\n"
+ + " FROM t2\n"
+ + " WHERE t2.id = t1.id\n"
+ + " AND t2.val > 10\n"
+ + " ) AS marker\n"
+ + "FROM t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(getConditionalCorrelateProgram());
+ })
+ .explainHookMatches(""
+ + "EnumerableCalc(expr#0..2=[{inputs}], id=[$t0], marker=[$t2])\n"
+ + " EnumerableConditionalCorrelate(correlation=[$cor0],
joinType=[left_mark], requiredColumns=[{0}])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, { null, 30
}]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0],
expr#3=[$t2.id], expr#4=[=($t0, $t3)], expr#5=[10], expr#6=[>($t1, $t5)],
expr#7=[AND($t4, $t6)], proj#0..1=[{exprs}], $condition=[$t7])\n"
+ + " EnumerableValues(tuples=[[{ 2, 15 }, { 3, 25 }]])\n")
+ .returnsUnordered(
+ "id=1; marker=false",
+ "id=2; marker=true",
+ "id=null; marker=false");
+
+ // NOT EXISTS: same correlate; the top EnumerableCalc negates the mark
+ // (NOT($t2)), so every result flips.
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, val) AS (\n"
+ + " VALUES (1, 10), (2, 20), (NULL, 30)\n"
+ + "),\n"
+ + "t2(id, val) AS (\n"
+ + " VALUES (2, 15), (3, 25)\n"
+ + ")\n"
+ + "SELECT\n"
+ + " t1.id,\n"
+ + " NOT EXISTS (\n"
+ + " SELECT 1\n"
+ + " FROM t2\n"
+ + " WHERE t2.id = t1.id\n"
+ + " AND t2.val > 10\n"
+ + " ) AS marker\n"
+ + "FROM t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(getConditionalCorrelateProgram());
+ })
+ .explainHookMatches(""
+ + "EnumerableCalc(expr#0..2=[{inputs}], expr#3=[NOT($t2)],
id=[$t0], marker=[$t3])\n"
+ + " EnumerableConditionalCorrelate(correlation=[$cor0],
joinType=[left_mark], requiredColumns=[{0}])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, { null, 30
}]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0],
expr#3=[$t2.id], expr#4=[=($t0, $t3)], expr#5=[10], expr#6=[>($t1, $t5)],
expr#7=[AND($t4, $t6)], proj#0..1=[{exprs}], $condition=[$t7])\n"
+ + " EnumerableValues(tuples=[[{ 2, 15 }, { 3, 25 }]])\n")
+ .returnsUnordered(
+ "id=1; marker=true",
+ "id=2; marker=false",
+ "id=null; marker=true");
+ }
+
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-7403">[CALCITE-7403]
+ * Missing ENUMERABLE Convention for LogicalConditionalCorrelate</a>.
+ *
+ * <p>Plans IN and NOT IN sub-queries into an
+ * {@code EnumerableConditionalCorrelate} that carries a join condition
+ * ({@code =($0, $2)}, i.e. t1.id = t2.id), then checks both the plan and the
+ * results. */
+ @Test void testConditionalCorrelateForIn() {
+ // IN: only id=2 finds an equal t2.id among the correlated rows; for the
+ // NULL id the correlated sub-query returns no rows, so the mark is FALSE.
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, val) AS (\n"
+ + " VALUES (1, 10), (2, 20), (NULL, 30)\n"
+ + "),\n"
+ + "t2(id, val) AS (\n"
+ + " VALUES (2, 15), (3, 25)\n"
+ + ")\n"
+ + "SELECT\n"
+ + " t1.id,\n"
+ + " t1.id IN (\n"
+ + " SELECT t2.id\n"
+ + " FROM t2\n"
+ + " WHERE t2.id = t1.id\n"
+ + " AND t2.val > 10\n"
+ + " ) AS marker\n"
+ + "FROM t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(getConditionalCorrelateProgram());
+ })
+ .explainHookMatches(""
+ + "EnumerableCalc(expr#0..2=[{inputs}], id=[$t0], marker=[$t2])\n"
+ + " EnumerableConditionalCorrelate(correlation=[$cor0],
joinType=[left_mark], requiredColumns=[{0}], condition=[=($0, $2)])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, { null, 30
}]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0],
expr#3=[$t2.id], expr#4=[=($t0, $t3)], expr#5=[10], expr#6=[>($t1, $t5)],
expr#7=[AND($t4, $t6)], id=[$t0], $condition=[$t7])\n"
+ + " EnumerableValues(tuples=[[{ 2, 15 }, { 3, 25 }]])\n")
+ .returnsUnordered(
+ "id=1; marker=false",
+ "id=2; marker=true",
+ "id=null; marker=false");
+
+ // NOT IN: same correlate; the top EnumerableCalc negates the mark
+ // (NOT($t2)), so NOT IN over an empty correlated set yields TRUE.
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, val) AS (\n"
+ + " VALUES (1, 10), (2, 20), (NULL, 30)\n"
+ + "),\n"
+ + "t2(id, val) AS (\n"
+ + " VALUES (2, 15), (3, 25)\n"
+ + ")\n"
+ + "SELECT\n"
+ + " t1.id,\n"
+ + " t1.id NOT IN (\n"
+ + " SELECT t2.id\n"
+ + " FROM t2\n"
+ + " WHERE t2.id = t1.id\n"
+ + " AND t2.val > 10\n"
+ + " ) AS marker\n"
+ + "FROM t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(getConditionalCorrelateProgram());
+ })
+ .explainHookMatches(""
+ + "EnumerableCalc(expr#0..2=[{inputs}], expr#3=[NOT($t2)],
id=[$t0], marker=[$t3])\n"
+ + " EnumerableConditionalCorrelate(correlation=[$cor0],
joinType=[left_mark], requiredColumns=[{0}], condition=[=($0, $2)])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, { null, 30
}]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0],
expr#3=[$t2.id], expr#4=[=($t0, $t3)], expr#5=[10], expr#6=[>($t1, $t5)],
expr#7=[AND($t4, $t6)], id=[$t0], $condition=[$t7])\n"
+ + " EnumerableValues(tuples=[[{ 2, 15 }, { 3, 25 }]])\n")
+ .returnsUnordered(
+ "id=1; marker=true",
+ "id=2; marker=false",
+ "id=null; marker=true");
+ }
+
private CalciteAssert.AssertThat tester(boolean forceDecorrelate,
Object schema) {
return CalciteAssert.that()
diff --git a/core/src/test/resources/sql/new-decorr.iq
b/core/src/test/resources/sql/new-decorr.iq
index 9677fae713..8f05bbac2a 100644
--- a/core/src/test/resources/sql/new-decorr.iq
+++ b/core/src/test/resources/sql/new-decorr.iq
@@ -205,6 +205,30 @@ EnumerableCalc(expr#0..3=[{inputs}], DEPTNO=[$t0],
EXPR$0=[$t2])
!plan
!}
+# [CALCITE-7403] Missing ENUMERABLE Convention for LogicalConditionalCorrelate
+# This case comes from some.iq [CALCITE-6786]
+WITH tb as (select array(SELECT * FROM (VALUES (TRUE), (NULL)) as x(a)) as a)
+SELECT TRUE IN (SELECT b FROM UNNEST(a) AS x1(b)) AS test FROM tb;
++------+
+| TEST |
++------+
+| true |
++------+
+(1 row)
+
+!ok
+
+!if (use_new_decorr) {
+EnumerableCalc(expr#0..1=[{inputs}], TEST=[$t1])
+ EnumerableConditionalCorrelate(correlation=[$cor0], joinType=[left_mark],
requiredColumns=[{0}], condition=[$1])
+ EnumerableCollect(field=[x])
+ EnumerableValues(tuples=[[{ true }, { null }]])
+ EnumerableUncollect
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.A],
A=[$t2])
+ EnumerableValues(tuples=[[{ 0 }]])
+!plan
+!}
+
# [CALCITE-7396] PruneEmptyRules does not support LEFT_MARK JOIN
# This case comes from sub-query.iq
!use post
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
index 8a45548d31..d859519b41 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
@@ -455,6 +455,13 @@ protected OrderedQueryable<T> asOrderedQueryable() {
return EnumerableDefaults.leftMarkNestedLoopJoin(getThis(), inner,
predicate, resultSelector);
}
+  @Override public <TInner, TResult> Enumerable<TResult> correlateLeftMarkJoin(
+      Function1<T, Enumerable<TInner>> inner,
+      NullablePredicate2<T, TInner> predicate,
+      Function2<T, @Nullable Boolean, TResult> resultSelector) {
+    // This enumerable is the outer (left) side; the shared implementation
+    // lives in EnumerableDefaults.
+    final Enumerable<T> outer = getThis();
+    return EnumerableDefaults.correlateLeftMarkJoin(outer, inner, predicate,
+        resultSelector);
+  }
+
@Override public <TInner, TResult> Enumerable<TResult> correlateJoin(
JoinType joinType, Function1<T, Enumerable<TInner>> inner,
Function2<T, TInner, TResult> resultSelector) {
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index 28f4a1185e..988450df36 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -1967,14 +1967,18 @@ static <TSource, TInner, TKey, TNsKey, TResult>
Enumerable<TResult> leftMarkHash
/**
* The implementation of left mark join based on nested loop.
+ * This is a unified implementation supporting both correlated and
non-correlated cases.
*
* @param outer Left input
- * @param inner Right input
+ * @param innerProvider Function that provides inner enumerable for each
outer row
+ * (for correlated join, pass outerRow ->
correlatedInner)
+ * (for non-correlated join, pass outerRow ->
staticInner)
* @param predicate Non-equi predicate that can return NULL
* @param resultSelector Function that concats the row of left input and
marker
*/
- public static <TSource, TInner, TResult> Enumerable<TResult>
leftMarkNestedLoopJoin(
- final Enumerable<TSource> outer, final Enumerable<TInner> inner,
+ private static <TSource, TInner, TResult> Enumerable<TResult>
leftMarkJoinInternal(
+ final Enumerable<TSource> outer,
+ final Function1<TSource, Enumerable<TInner>> innerProvider,
final NullablePredicate2<TSource, TInner> predicate,
final Function2<TSource, @Nullable Boolean, TResult> resultSelector) {
return new AbstractEnumerable<TResult>() {
@@ -1993,15 +1997,18 @@ public static <TSource, TInner, TResult>
Enumerable<TResult> leftMarkNestedLoopJ
}
marker = false;
final TSource outerRow = outers.current();
- try (Enumerator<TInner> inners = inner.enumerator()) {
- while (inners.moveNext()) {
- final TInner innerRow = inners.current();
- Boolean predicateMatched = predicate.apply(outerRow, innerRow);
- if (predicateMatched == null) {
- marker = null;
- } else if (predicateMatched) {
- marker = true;
- break;
+ Enumerable<TInner> innerEnumerable = innerProvider.apply(outerRow);
+ if (innerEnumerable != null) {
+ try (Enumerator<TInner> inners = innerEnumerable.enumerator()) {
+ while (inners.moveNext()) {
+ final TInner innerRow = inners.current();
+ Boolean predicateMatched = predicate.apply(outerRow,
innerRow);
+ if (predicateMatched == null) {
+ marker = null;
+ } else if (predicateMatched) {
+ marker = true;
+ break;
+ }
}
}
}
@@ -2020,6 +2027,34 @@ public static <TSource, TInner, TResult>
Enumerable<TResult> leftMarkNestedLoopJ
};
}
+  /**
+   * Correlated LEFT_MARK join based on nested loops.
+   *
+   * <p>For each row of {@code outer}, scans the inner enumerable that
+   * {@code inner} returns for that row, and passes the outer row and a marker
+   * to {@code resultSelector}. The marker is:
+   * <ul>
+   *   <li>{@code true} if {@code predicate} returns {@code true} for some
+   *   inner row;
+   *   <li>otherwise {@code null} if {@code predicate} returns {@code null}
+   *   for some inner row;
+   *   <li>otherwise {@code false}, including when the inner enumerable is
+   *   empty or {@code null}.
+   * </ul>
+   *
+   * @param outer Left input
+   * @param inner Function that returns the inner enumerable for an outer row
+   * @param predicate Predicate on an (outer row, inner row) pair that can
+   *     return NULL
+   * @param resultSelector Function that combines an outer row and its marker
+   */
+  public static <TSource, TInner, TResult> Enumerable<TResult> correlateLeftMarkJoin(
+      final Enumerable<TSource> outer,
+      final Function1<TSource, Enumerable<TInner>> inner,
+      final NullablePredicate2<TSource, TInner> predicate,
+      final Function2<TSource, @Nullable Boolean, TResult> resultSelector) {
+    return leftMarkJoinInternal(outer, inner, predicate, resultSelector);
+  }
+
+  /**
+   * Non-correlated LEFT_MARK join based on nested loops.
+   *
+   * <p>The same {@code inner} enumerable is scanned for every row of
+   * {@code outer}. Each outer row is passed to {@code resultSelector} together
+   * with a marker that is:
+   * <ul>
+   *   <li>{@code true} if {@code predicate} returns {@code true} for some
+   *   inner row;
+   *   <li>otherwise {@code null} if it returns {@code null} for some inner row;
+   *   <li>otherwise {@code false}.
+   * </ul>
+   *
+   * @param outer Left input
+   * @param inner Right input
+   * @param predicate Non-equi predicate that can return NULL
+   * @param resultSelector Function that concats the row of left input and marker
+   */
+  public static <TSource, TInner, TResult> Enumerable<TResult> leftMarkNestedLoopJoin(
+      final Enumerable<TSource> outer, final Enumerable<TInner> inner,
+      final NullablePredicate2<TSource, TInner> predicate,
+      final Function2<TSource, @Nullable Boolean, TResult> resultSelector) {
+    // The inner input does not depend on the outer row, so the provider
+    // ignores its argument and always hands back the same enumerable.
+    final Function1<TSource, Enumerable<TInner>> sameInnerForEveryRow =
+        outerRow -> inner;
+    return leftMarkJoinInternal(outer, sameInnerForEveryRow, predicate,
+        resultSelector);
+  }
+
/**
* For each row of the {@code outer} enumerable returns the correlated rows
* from the {@code inner} enumerable.
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
index 982ab0ca85..160f2afa0b 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
@@ -699,6 +699,19 @@ <TInner, TResult> Enumerable<TResult>
leftMarkNestedLoopJoin(Enumerable<TInner>
NullablePredicate2<TSource, TInner> predicate,
Function2<TSource, @Nullable Boolean, TResult> resultSelector);
+ /**
+ * Correlated LEFT_MARK join (nested loops).
+ *
+ * <p>For each row of the current enumerable, scans the inner enumerable that
+ * {@code inner} produces for that row, and returns a row built by
+ * {@code resultSelector} from the outer row and a marker. The marker is
+ * {@code true} if {@code predicate} returns {@code true} for some inner row;
+ * otherwise {@code null} if it returns {@code null} for some inner row;
+ * otherwise {@code false}.
+ *
+ * @param inner function that generates the inner enumerable for an outer row
+ * @param predicate predicate on an (outer row, inner row) pair that can
+ * return NULL
+ * @param resultSelector selector of the result, receives outer row and
+ * nullable boolean marker
+ */
+ <TInner, TResult> Enumerable<TResult> correlateLeftMarkJoin(
+ Function1<TSource, Enumerable<TInner>> inner,
+ NullablePredicate2<TSource, TInner> predicate,
+ Function2<TSource, @Nullable Boolean, TResult> resultSelector);
+
/**
* For each row of the current enumerable returns the correlated rows
* from the {@code inner} enumerable (nested loops join).