This is an automated email from the ASF dual-hosted git repository.
silun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 77d561fe5e [CALCITE-7315] Support LEFT_MARK type for hash join in
enumerable convention
77d561fe5e is described below
commit 77d561fe5e18794e1c4ae271ec7fa7eb3fcbe65e
Author: Silun Dong <[email protected]>
AuthorDate: Tue Jan 6 09:13:38 2026 +0800
[CALCITE-7315] Support LEFT_MARK type for hash join in enumerable convention
---
.../calcite/adapter/enumerable/EnumUtils.java | 69 +++-
.../adapter/enumerable/EnumerableHashJoin.java | 100 ++++++
.../adapter/enumerable/EnumerableJoinRule.java | 8 +-
.../adapter/enumerable/RexToLixTranslator.java | 9 +-
.../calcite/rel/core/ConditionalCorrelate.java | 10 +
.../org/apache/calcite/rel/core/Correlate.java | 1 -
.../java/org/apache/calcite/rel/core/Join.java | 11 +
.../calcite/sql/validate/SqlValidatorUtil.java | 45 ++-
.../org/apache/calcite/util/BuiltInMethod.java | 12 +
.../test/enumerable/EnumerableHashJoinTest.java | 141 ++++++++
core/src/test/resources/sql/blank.iq | 16 +-
.../apache/calcite/linq4j/DefaultEnumerable.java | 19 ++
.../apache/calcite/linq4j/EnumerableDefaults.java | 372 ++++++++++++++++++++-
.../apache/calcite/linq4j/ExtendedEnumerable.java | 40 +++
.../java/org/apache/calcite/linq4j/JoinType.java | 27 +-
.../linq4j/function/NullablePredicate2.java | 27 ++
16 files changed, 857 insertions(+), 50 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
index 93db65e228..8787e682cc 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
@@ -26,6 +26,7 @@
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.function.Function2;
+import org.apache.calcite.linq4j.function.NullablePredicate2;
import org.apache.calcite.linq4j.function.Predicate2;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.BlockStatement;
@@ -875,6 +876,8 @@ static JoinType toLinq4jJoinType(JoinRelType joinRelType) {
return JoinType.ASOF;
case LEFT_ASOF:
return JoinType.LEFT_ASOF;
+ case LEFT_MARK:
+ return JoinType.LEFT_MARK;
default:
break;
}
@@ -882,7 +885,47 @@ static JoinType toLinq4jJoinType(JoinRelType joinRelType) {
"Unable to convert " + joinRelType + " to Linq4j JoinType");
}
- /** Returns a predicate expression based on a join condition. */
+ /**
+ * Returns the result selector of a mark join. It is an Expression that will
generate a Function2 at
+ * runtime; the Function2 concatenates the left/right side row and the marker.
+ *
+ * <p> For example:
+ *
+ * <blockquote><pre>
+ * new Function2<Object[], Boolean, Object[]>() {
+ * public Object[] apply(Object[] input, Boolean marker) {
+ * return new Object[] {
+ * input[0], input[1], ..., input[n], marker
+ * };
+ * }
+ * }</pre></blockquote>
+ *
+ * @param resultPhysType Physical type of result
+ * @param inputPhysType Physical type of lhs/rhs
+ * @return the result selector of a mark join
+ */
+ static Expression markJoinSelector(PhysType resultPhysType, PhysType
inputPhysType) {
+ final List<ParameterExpression> parameters = new ArrayList<>();
+ final ParameterExpression inputParameter =
+ Expressions.parameter(Primitive.box(inputPhysType.getJavaRowType()),
"input");
+ final ParameterExpression markerParameter
+ = Expressions.parameter(Boolean.class, "marker");
+ parameters.add(inputParameter);
+ parameters.add(markerParameter);
+
+ final List<Expression> expressions = new ArrayList<>();
+ final int inputFieldCount = inputPhysType.getRowType().getFieldCount();
+ for (int i = 0; i < inputFieldCount; i++) {
+ Expression expression = inputPhysType.fieldReference(inputParameter, i);
+ expressions.add(expression);
+ }
+ expressions.add(markerParameter);
+ return Expressions.lambda(
+ Function2.class,
+ resultPhysType.record(expressions),
+ parameters);
+ }
+
static Expression generatePredicate(
EnumerableRelImplementor implementor,
RexBuilder rexBuilder,
@@ -891,6 +934,24 @@ static Expression generatePredicate(
PhysType leftPhysType,
PhysType rightPhysType,
RexNode condition) {
+ return generatePredicate(implementor, rexBuilder, left, right,
+ leftPhysType, rightPhysType, condition, false);
+ }
+
+ /**
+ * Returns a predicate expression based on a join condition. If the condition
evaluates to
+ * NULL, then when <code>nullable</code> is TRUE the generated predicate
returns NULL;
+ * when <code>nullable</code> is FALSE, it returns FALSE.
+ */
+ static Expression generatePredicate(
+ EnumerableRelImplementor implementor,
+ RexBuilder rexBuilder,
+ RelNode left,
+ RelNode right,
+ PhysType leftPhysType,
+ PhysType rightPhysType,
+ RexNode condition,
+ boolean nullable) {
final BlockBuilder builder = new BlockBuilder();
final ParameterExpression left_ =
Expressions.parameter(leftPhysType.getJavaRowType(), "left");
@@ -913,8 +974,10 @@ static Expression generatePredicate(
ImmutableMap.of(left_, leftPhysType,
right_, rightPhysType)),
implementor.allCorrelateVariables,
- implementor.getConformance())));
- return Expressions.lambda(Predicate2.class, builder.toBlock(), left_,
right_);
+ implementor.getConformance(),
+ nullable)));
+ Class clazz = nullable ? NullablePredicate2.class : Predicate2.class;
+ return Expressions.lambda(clazz, builder.toBlock(), left_, right_);
}
/**
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
index e37173137a..6fc9a9ba3d 100644
---
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
@@ -45,6 +45,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;
import java.lang.reflect.Method;
+import java.util.ArrayList;
import java.util.List;
import java.util.Set;
@@ -173,11 +174,110 @@ public static EnumerableHashJoin create(
case SEMI:
case ANTI:
return implementHashSemiJoin(implementor, pref);
+ case LEFT_MARK:
+ return implementHashMarkJoin(implementor, pref);
default:
return implementHashJoin(implementor, pref);
}
}
+ private Result implementHashMarkJoin(EnumerableRelImplementor implementor,
Prefer pref) {
+ assert joinType == JoinRelType.LEFT_MARK;
+ BlockBuilder builder = new BlockBuilder();
+ final Result leftResult =
+ implementor.visitChild(this, 0, (EnumerableRel) left, pref);
+ Expression leftExpression =
+ builder.append(
+ "left", leftResult.block);
+ final Result rightResult =
+ implementor.visitChild(this, 1, (EnumerableRel) right, pref);
+ Expression rightExpression =
+ builder.append(
+ "right", rightResult.block);
+ final PhysType physType =
+ PhysTypeImpl.of(
+ implementor.getTypeFactory(), getRowType(), pref.preferArray());
+
+ // convert equi and non-equi conditions to Expression
+ Expression nonEquiPredicate = Expressions.constant(null);
+ if (!joinInfo.nonEquiConditions.isEmpty()) {
+ RexNode nonEquiCondition =
+ RexUtil.composeConjunction(getCluster().getRexBuilder(),
+ joinInfo.nonEquiConditions, true);
+ if (nonEquiCondition != null) {
+ // need three-valued boolean logic
+ nonEquiPredicate =
+ EnumUtils.generatePredicate(implementor,
+ getCluster().getRexBuilder(), left, right, leftResult.physType,
+ rightResult.physType, nonEquiCondition, true);
+ }
+ }
+ RexNode equiCondition = joinInfo.getEquiCondition(left, right,
getCluster().getRexBuilder());
+ // need three-valued boolean logic
+ final Expression equiPredicate =
+ EnumUtils.generatePredicate(implementor,
+ getCluster().getRexBuilder(), left, right, leftResult.physType,
+ rightResult.physType, equiCondition, true);
+
+ // create key selector and null-safe key selector
+ final Expression leftKeySelector =
+ leftResult.physType.generateNullAwareAccessor(
+ joinInfo.leftKeys, joinInfo.nullExclusionFlags);
+ final Expression rightKeySelector =
+ rightResult.physType.generateNullAwareAccessor(
+ joinInfo.rightKeys, joinInfo.nullExclusionFlags);
+
+ int notNullSafeKeyCount = 0;
+ List<Integer> leftNullSafeKeys = new ArrayList<>();
+ List<Integer> rightNullSafeKeys = new ArrayList<>();
+ for (int i = 0; i < joinInfo.nullExclusionFlags.size(); i++) {
+ if (joinInfo.nullExclusionFlags.get(i)) {
+ notNullSafeKeyCount++;
+ } else {
+ leftNullSafeKeys.add(joinInfo.leftKeys.get(i));
+ rightNullSafeKeys.add(joinInfo.rightKeys.get(i));
+ }
+ }
+ final Expression leftNullSafeKeySelector =
+ leftNullSafeKeys.isEmpty()
+ ? Expressions.constant(null)
+ :
leftResult.physType.generateAccessor(ImmutableIntList.copyOf(leftNullSafeKeys));
+ final Expression rightNullSafeKeySelector =
+ rightNullSafeKeys.isEmpty()
+ ? Expressions.constant(null)
+ :
rightResult.physType.generateAccessor(ImmutableIntList.copyOf(rightNullSafeKeys));
+ final boolean atMostOneNotNullSafeKey = notNullSafeKeyCount <= 1;
+
+ // create key comparator and null-safe key comparator
+ final PhysType nullSafeKeyPhysType =
+ leftResult.physType.project(leftNullSafeKeys, JavaRowFormat.LIST);
+ final Expression nullSafeKeyComparator =
+ Util.first(nullSafeKeyPhysType.comparer(), Expressions.constant(null));
+ final PhysType keyPhysType =
+ leftResult.physType.project(joinInfo.leftKeys, JavaRowFormat.LIST);
+ final Expression keyComparator =
+ Util.first(keyPhysType.comparer(), Expressions.constant(null));
+
+ return implementor.result(physType,
+ builder.append(
+ Expressions.call(
+ leftExpression,
+ BuiltInMethod.LEFT_MARK_HASH_JOIN.method,
+ Expressions.list(
+ rightExpression,
+ leftKeySelector,
+ rightKeySelector,
+ leftNullSafeKeySelector,
+ rightNullSafeKeySelector,
+ Expressions.constant(atMostOneNotNullSafeKey),
+ EnumUtils.markJoinSelector(physType,
leftResult.physType),
+ keyComparator,
+ nullSafeKeyComparator,
+ nonEquiPredicate,
+ equiPredicate)))
+ .toBlock());
+ }
+
private Result implementHashSemiJoin(EnumerableRelImplementor implementor,
Prefer pref) {
assert joinType == JoinRelType.SEMI || joinType == JoinRelType.ANTI;
final Method method = joinType == JoinRelType.SEMI
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableJoinRule.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableJoinRule.java
index fe40d17dc4..08a301da5a 100644
---
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableJoinRule.java
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableJoinRule.java
@@ -54,10 +54,6 @@ protected EnumerableJoinRule(Config config) {
@Override public @Nullable RelNode convert(RelNode rel) {
Join join = (Join) rel;
- if (!Bug.TODO_FIXED && join.getJoinType() == JoinRelType.LEFT_MARK) {
- // TODO implement LEFT MARK join
- return null;
- }
List<RelNode> newInputs = new ArrayList<>();
for (RelNode input : join.getInputs()) {
if (!(input.getConvention() instanceof EnumerableConvention)) {
@@ -100,6 +96,10 @@ protected EnumerableJoinRule(Config config) {
join.getVariablesSet(),
join.getJoinType());
}
+ if (!Bug.TODO_FIXED && join.getJoinType() == JoinRelType.LEFT_MARK) {
+ // TODO Support LEFT MARK type for nested loop join
+ return null;
+ }
return EnumerableNestedLoopJoin.create(
left,
right,
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index 7e5f1948ab..e7b8657bf5 100644
---
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -1211,6 +1211,13 @@ private Expression translateTableFunction(RexCall
rexCall, Expression inputEnume
public static Expression translateCondition(RexProgram program,
JavaTypeFactory typeFactory, BlockBuilder list, InputGetter inputGetter,
Function1<String, InputGetter> correlates, SqlConformance conformance) {
+ return translateCondition(program, typeFactory, list, inputGetter,
+ correlates, conformance, false);
+ }
+
+ public static Expression translateCondition(RexProgram program,
+ JavaTypeFactory typeFactory, BlockBuilder list, InputGetter inputGetter,
+ Function1<String, InputGetter> correlates, SqlConformance conformance,
boolean nullable) {
RexLocalRef condition = program.getCondition();
if (condition == null) {
return RexImpTable.TRUE_EXPR;
@@ -1222,7 +1229,7 @@ public static Expression translateCondition(RexProgram
program,
translator = translator.setCorrelates(correlates);
return translator.translate(
condition,
- RexImpTable.NullAs.FALSE);
+ nullable ? RexImpTable.NullAs.NULL : RexImpTable.NullAs.FALSE);
}
/** Returns whether an expression is nullable.
diff --git
a/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
b/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
index a6527f23fd..af203429ef 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/ConditionalCorrelate.java
@@ -22,9 +22,13 @@
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.rules.CoreRules;
+import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.ImmutableBitSet;
+import com.google.common.collect.ImmutableList;
+
import java.util.List;
/**
@@ -70,6 +74,12 @@ public abstract ConditionalCorrelate copy(RelTraitSet
traitSet, RelNode left, Re
.itemIf("condition", condition, !condition.isAlwaysTrue());
}
+ @Override protected RelDataType deriveRowType() {
+ assert joinType == JoinRelType.LEFT_MARK;
+ return SqlValidatorUtil.createMarkJoinType(getCluster().getTypeFactory(),
left.getRowType(),
+ condition.getType(), ImmutableList.of());
+ }
+
@Override public RexNode getCondition() {
return condition;
}
diff --git a/core/src/main/java/org/apache/calcite/rel/core/Correlate.java
b/core/src/main/java/org/apache/calcite/rel/core/Correlate.java
index e9d2adbccd..c6a994e922 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/Correlate.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/Correlate.java
@@ -170,7 +170,6 @@ public JoinRelType getJoinType() {
switch (joinType) {
case LEFT:
case INNER:
- case LEFT_MARK:
return SqlValidatorUtil.deriveJoinRowType(left.getRowType(),
right.getRowType(), joinType,
getCluster().getTypeFactory(), null,
diff --git a/core/src/main/java/org/apache/calcite/rel/core/Join.java
b/core/src/main/java/org/apache/calcite/rel/core/Join.java
index 1b81717aed..999c4639f9 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/Join.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/Join.java
@@ -190,6 +190,13 @@ public JoinRelType getJoinType() {
+ " failures in condition " + condition);
}
}
+ if (joinType == JoinRelType.LEFT_MARK
+ && joinInfo.nullExclusionFlags.contains(true)
+ && !joinInfo.nonEquiConditions.isEmpty()) {
+ return litmus.fail("Left mark join is produced by rewriting
IN/SOME/EXISTS "
+ + "subqueries, it will never contain both not null-safe join keys
and non-equi "
+ + "predicates.");
+ }
return litmus.succeed();
}
@@ -260,6 +267,10 @@ protected int deepHashCode0() {
}
@Override protected RelDataType deriveRowType() {
+ if (joinType == JoinRelType.LEFT_MARK) {
+ return
SqlValidatorUtil.createMarkJoinType(getCluster().getTypeFactory(),
left.getRowType(),
+ condition.getType(), getSystemFieldList());
+ }
return SqlValidatorUtil.deriveJoinRowType(left.getRowType(),
right.getRowType(), joinType, getCluster().getTypeFactory(), null,
getSystemFieldList());
diff --git
a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
index a01367057a..f05c58a8c7 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
@@ -57,7 +57,6 @@
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
-import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
@@ -557,15 +556,6 @@ public static RelDataType deriveJoinRowType(
case ANTI:
rightType = null;
break;
- case LEFT_MARK:
- final String markColName =
- SqlValidatorUtil.uniquify("markCol",
Sets.newHashSet(leftType.getFieldNames()),
- SqlValidatorUtil.EXPR_SUGGESTER);
- rightType =
- typeFactory.createStructType(
- ImmutableList.of(typeFactory.createSqlType(SqlTypeName.BOOLEAN)),
- ImmutableList.of(markColName));
- break;
default:
break;
}
@@ -627,6 +617,41 @@ public static RelDataType createJoinType(
return typeFactory.createStructType(typeList, nameList);
}
+ /**
+ * Returns the type of the result collection produced by a mark join. Taking
LEFT_MARK join as an
+ * example, its output is all rows from the left side plus a new
attribute that marks whether a tuple
+ * has join partners on the right side or not.
+ *
+ * @param typeFactory Type factory
+ * @param inputType Type of lhs/rhs of the mark join
+ * @param joinConditionType Type of the join condition
+ * @param systemFieldList List of system fields that will be prefixed to
output row type;
+ * typically empty but must not be null
+ * @return mark join type
+ */
+ public static RelDataType createMarkJoinType(
+ RelDataTypeFactory typeFactory,
+ RelDataType inputType,
+ RelDataType joinConditionType,
+ List<RelDataTypeField> systemFieldList) {
+ final String markerName =
+ SqlValidatorUtil.uniquify("markCol",
Sets.newHashSet(inputType.getFieldNames()),
+ SqlValidatorUtil.EXPR_SUGGESTER);
+ // conceptually the type of marker is a three-valued boolean, but it can
be simplified to a
+ // two-valued boolean in specific cases (e.g., rewriting from an EXISTS
subquery). Simply
+ // defining the marker type as nullable boolean might cause type mismatch
errors after rewriting
+ // some subqueries (such as EXISTS subquery).
+ // When deriving the type of LEFT_MARK join, we no longer know which
subquery it was
+ // rewritten from, but that information is implicit in the join condition.
For example, after
+ // rewriting and decorrelating an EXISTS (correlated) subquery, the
condition will only contain
+ // IS NOT DISTINCT FROM. Therefore, we derive the marker type from the
condition.
+ final RelDataType markerType =
+ typeFactory.createStructType(
+ ImmutableList.of(joinConditionType),
+ ImmutableList.of(markerName));
+ return createJoinType(typeFactory, inputType, markerType, null,
systemFieldList);
+ }
+
private static void addFields(List<RelDataTypeField> fieldList,
List<RelDataType> typeList, List<String> nameList,
Set<String> uniqueNames) {
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 221ff5f713..04a8162450 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -47,6 +47,7 @@
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.function.Function2;
import org.apache.calcite.linq4j.function.Functions;
+import org.apache.calcite.linq4j.function.NullablePredicate2;
import org.apache.calcite.linq4j.function.Predicate1;
import org.apache.calcite.linq4j.function.Predicate2;
import org.apache.calcite.linq4j.tree.FunctionExpression;
@@ -212,6 +213,17 @@ public enum BuiltInMethod {
Function1.class,
Function1.class, Function2.class, EqualityComparer.class,
boolean.class, boolean.class, Predicate2.class),
+ LEFT_MARK_HASH_JOIN(ExtendedEnumerable.class, "leftMarkHashJoin",
Enumerable.class,
+ Function1.class, // outer key null aware selector
+ Function1.class, // inner key null aware selector
+ Function1.class, // outer null-safe key selector
+ Function1.class, // inner null-safe key selector
+ boolean.class, // whether there is at most one not
null-safe key
+ Function2.class, // result selector
+ EqualityComparer.class, // join keys comparator
+ EqualityComparer.class, // null-safe join keys comparator
+ NullablePredicate2.class, // non-equi predicate that can return NULL
+ NullablePredicate2.class), // equi predicate that can return NULL
ASOF_JOIN(ExtendedEnumerable.class, "asofJoin", Enumerable.class,
Function1.class, // outer key selector
Function1.class, // inner key selector
diff --git
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
index c589f6eea5..fce70ff8f5 100644
---
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
+++
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
@@ -19,15 +19,30 @@
import org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.config.Lex;
+import org.apache.calcite.plan.RelOptLattice;
+import org.apache.calcite.plan.RelOptMaterialization;
import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
+import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql2rel.TopDownGeneralDecorrelator;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.test.ReflectiveSchemaWithoutRowCount;
import org.apache.calcite.test.schemata.hr.HrSchema;
+import org.apache.calcite.tools.Program;
+import org.apache.calcite.tools.Programs;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.Holder;
+
+import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
+import java.util.List;
import java.util.function.Consumer;
/**
@@ -451,6 +466,132 @@ class EnumerableHashJoinTest {
"id1=2; sal1=null");
}
+ /** Test case for
+ * <a
href="https://issues.apache.org/jira/browse/CALCITE-7315">[CALCITE-7315]
+ * Support LEFT_MARK type for hash join in enumerable convention</a>. */
+ @Test void testLeftMarkJoin() {
+ Program subQuery =
+ Programs.hep(
+ ImmutableList.of(CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE),
+ true,
+ DefaultRelMetadataProvider.INSTANCE);
+ Program toCalc =
+ Programs.hep(
+ ImmutableList.of(CoreRules.PROJECT_TO_CALC,
CoreRules.FILTER_TO_CALC,
+ CoreRules.CALC_MERGE),
+ true,
+ DefaultRelMetadataProvider.INSTANCE);
+ Program topDownDecorrelator = new Program() {
+ @Override public RelNode run(RelOptPlanner planner, RelNode rel,
+ RelTraitSet requiredOutputTraits, List<RelOptMaterialization>
materializations,
+ List<RelOptLattice> lattices) {
+ final RelBuilder relBuilder =
+ RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null);
+ return TopDownGeneralDecorrelator.decorrelateQuery(rel, relBuilder);
+ }
+ };
+ Program enumerableImpl =
Programs.ofRules(EnumerableRules.ENUMERABLE_RULES);
+
+ // case1: left mark join from uncorrelated IN subquery (0 null-safe key, 1
not null-safe key)
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id) as (VALUES (1), (2), (NULL)), t2(id) as (VALUES (2),
(3)) "
+ + "select id, id in (select id from t2) as marker from t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(Programs.sequence(subQuery, toCalc, enumerableImpl));
+ })
+ .explainHookMatches(
+ "EnumerableHashJoin(condition=[=($0, $1)], joinType=[left_mark])\n"
+ + " EnumerableValues(tuples=[[{ 1 }, { 2 }, { null }]])\n"
+ + " EnumerableCalc(expr#0=[{inputs}], id=[$t0])\n"
+ + " EnumerableValues(tuples=[[{ 2 }, { 3 }]])\n")
+ .returnsUnordered(
+ "id=1; marker=false",
+ "id=2; marker=true",
+ "id=null; marker=null");
+
+ // case2: left mark join from uncorrelated IN subquery (0 null-safe key, 1
not null-safe key)
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id) as (VALUES (1), (2), (3)), t2(id) as (VALUES (2),
(NULL)) "
+ + "select id, id in (select id from t2) as marker from t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(Programs.sequence(subQuery, toCalc, enumerableImpl));
+ })
+ .explainHookMatches(
+ "EnumerableHashJoin(condition=[=($0, $1)], joinType=[left_mark])\n"
+ + " EnumerableValues(tuples=[[{ 1 }, { 2 }, { 3 }]])\n"
+ + " EnumerableCalc(expr#0=[{inputs}], id=[$t0])\n"
+ + " EnumerableValues(tuples=[[{ 2 }, { null }]])\n")
+ .returnsUnordered(
+ "id=1; marker=null",
+ "id=2; marker=true",
+ "id=3; marker=null");
+
+ // case3: left mark join from uncorrelated IN subquery (0 null-safe key, 2
not null-safe keys)
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, sal) as (VALUES (1, 10), (2, NULL), (3, NULL)), "
+ + "t2(id, sal) as (VALUES (1, 10), (2, NULL)) "
+ + "select id, sal, (id, sal) in (select id, sal from t2) as
marker from t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(Programs.sequence(subQuery, toCalc, enumerableImpl));
+ })
+ .explainHookMatches(
+ "EnumerableHashJoin(condition=[AND(=($0, $2), =($1, $3))],
joinType=[left_mark])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, null }, { 3,
null }]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}],
proj#0..1=[{exprs}])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, null }]])\n")
+ .returnsUnordered(
+ "id=1; sal=10; marker=true",
+ "id=2; sal=null; marker=null",
+ "id=3; sal=null; marker=false");
+
+ // case4: left mark join from correlated IN subquery (1 null-safe key, 1
not null-safe key)
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, sal) as (VALUES (1, 10), (2, 20), (3, NULL)), "
+ + "t2(id, sal) as (VALUES (1, 10), (2, NULL)) "
+ + "select id, sal in (select sal from t2 where t1.id = t2.id)
as marker from t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(Programs.sequence(subQuery, topDownDecorrelator, toCalc,
enumerableImpl));
+ })
+ .explainHookMatches(
+ "EnumerableCalc(expr#0..2=[{inputs}], id=[$t0], marker=[$t2])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $2), IS NOT
DISTINCT FROM($0, $3))], joinType=[left_mark])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, { 3,
null }]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}], EXPR$1=[$t1],
EXPR$0=[$t0])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, null
}]])\n")
+ .returnsUnordered(
+ "id=1; marker=true",
+ "id=2; marker=null",
+ "id=3; marker=false");
+
+ // case5: left mark join from correlated SOME subquery (1 null-safe key,
and non-equi predicate)
+ tester(false, new HrSchema())
+ .query(
+ "WITH t1(id, sal) as (VALUES (1, 10), (2, 20), (NULL, 30)), "
+ + "t2(id, sal) as (VALUES (1, 9), (2, NULL), (NULL, 31)) "
+ + "select id, sal < SOME(select sal from t2 where t1.id =
t2.id or t1.id is null) "
+ + "as marker from t1")
+ .withHook(Hook.PROGRAM, (Consumer<Holder<Program>>) program -> {
+ program.set(Programs.sequence(subQuery, topDownDecorrelator, toCalc,
enumerableImpl));
+ })
+ .explainHookMatches(
+ "EnumerableCalc(expr#0..2=[{inputs}], id=[$t0], marker=[$t2])\n"
+ + " EnumerableHashJoin(condition=[AND(IS NOT DISTINCT
FROM($0, $3), <($1, $2))], joinType=[left_mark])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, { null,
30 }]])\n"
+ + " EnumerableCalc(expr#0..2=[{inputs}], EXPR$1=[$t1],
EXPR$00=[$t2])\n"
+ + " EnumerableNestedLoopJoin(condition=[OR(=($2, $0), IS
NULL($2))], joinType=[inner])\n"
+ + " EnumerableValues(tuples=[[{ 1, 9 }, { 2, null }, {
null, 31 }]])\n"
+ + " EnumerableCalc(expr#0..1=[{inputs}],
EXPR$0=[$t0])\n"
+ + " EnumerableValues(tuples=[[{ 1, 10 }, { 2, 20 }, {
null, 30 }]])\n")
+ .returnsUnordered(
+ "id=1; marker=false",
+ "id=2; marker=null",
+ "id=null; marker=true");
+ }
+
private CalciteAssert.AssertThat tester(boolean forceDecorrelate,
Object schema) {
return CalciteAssert.that()
diff --git a/core/src/test/resources/sql/blank.iq
b/core/src/test/resources/sql/blank.iq
index 4053cbb26b..206707287e 100644
--- a/core/src/test/resources/sql/blank.iq
+++ b/core/src/test/resources/sql/blank.iq
@@ -107,6 +107,7 @@ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[0],
expr#9=[=($t3, $t8)], expr#10=[
EnumerableCalc(expr#0..1=[{inputs}], expr#2=[IS NOT NULL($t1)],
expr#3=[IS NOT NULL($t0)], expr#4=[AND($t2, $t3)], proj#0..1=[{exprs}],
$condition=[$t4])
EnumerableTableScan(table=[[BLANK, TABLE2]])
!plan
+!}
+---+---+
| I | J |
+---+---+
@@ -114,15 +115,7 @@ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[0],
expr#9=[=($t3, $t8)], expr#10=[
(0 rows)
!ok
-!}
-
-# TODO: This error needs to be fixed
-!if (use_new_decorr) {
-Unable to convert LEFT_MARK to Linq4j JoinType
-!error
-!}
-!if (use_old_decorr) {
select * from table1 where j not in (select i from table2);
+---+---+
| I | J |
@@ -162,13 +155,6 @@ select * from table1 where j not in (select i from table2)
or j = 3;
(1 row)
!ok
-!}
-
-# TODO: This error needs to be fixed
-!if (use_new_decorr) {
-Unable to convert LEFT_MARK to Linq4j JoinType
-!error
-!}
# [CALCITE-4813] ANY_VALUE assumes that arguments should be comparable
select any_value(r) over(), s from(select array[f, s] r, s from (select 1 as
f, 2 as s) t) t;
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
index 4ea1b4e7e7..f90c9f112d 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
@@ -30,6 +30,7 @@
import org.apache.calcite.linq4j.function.NullableFloatFunction1;
import org.apache.calcite.linq4j.function.NullableIntegerFunction1;
import org.apache.calcite.linq4j.function.NullableLongFunction1;
+import org.apache.calcite.linq4j.function.NullablePredicate2;
import org.apache.calcite.linq4j.function.Predicate1;
import org.apache.calcite.linq4j.function.Predicate2;
@@ -429,6 +430,24 @@ protected OrderedQueryable<T> asOrderedQueryable() {
generateNullsOnRight, predicate);
}
+ @Override public <TInner, TKey, TNsKey, TResult> Enumerable<TResult>
leftMarkHashJoin(
+ Enumerable<TInner> inner,
+ Function1<T, TKey> outerKeyNullAwareSelector,
+ Function1<TInner, TKey> innerKeyNullAwareSelector,
+ @Nullable Function1<T, TNsKey> outerNullSafeKeySelector,
+ @Nullable Function1<TInner, TNsKey> innerNullSafeKeySelector,
+ boolean atMostOneNotNullSafeKey,
+ Function2<T, @Nullable Boolean, TResult> resultSelector,
+ @Nullable EqualityComparer<TKey> comparer,
+ @Nullable EqualityComparer<TNsKey> nullSafeComparer,
+ @Nullable NullablePredicate2<T, TInner> nonEquiPredicate,
+ NullablePredicate2<T, TInner> equiPredicate) {
+ return EnumerableDefaults.leftMarkHashJoin(getThis(), inner,
outerKeyNullAwareSelector,
+ innerKeyNullAwareSelector, outerNullSafeKeySelector,
innerNullSafeKeySelector,
+ atMostOneNotNullSafeKey, resultSelector, comparer, nullSafeComparer,
+ nonEquiPredicate, equiPredicate);
+ }
+
@Override public <TInner, TResult> Enumerable<TResult> correlateJoin(
JoinType joinType, Function1<T, Enumerable<TInner>> inner,
Function2<T, TInner, TResult> resultSelector) {
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index 14309bdfd6..96828708f7 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -31,6 +31,7 @@
import org.apache.calcite.linq4j.function.NullableFloatFunction1;
import org.apache.calcite.linq4j.function.NullableIntegerFunction1;
import org.apache.calcite.linq4j.function.NullableLongFunction1;
+import org.apache.calcite.linq4j.function.NullablePredicate2;
import org.apache.calcite.linq4j.function.Predicate1;
import org.apache.calcite.linq4j.function.Predicate2;
@@ -1685,6 +1686,285 @@ private static <TSource, TInner, TKey, TResult>
Enumerable<TResult> hashJoinWith
};
}
+ /**
+  * Hash-based implementation of the left mark join. Every row of the left input is kept, and a
+  * new attribute (the marker) records whether that row has a join partner in the right input.
+  * Refer to <a href="https://dl.gi.de/items/c5f7c49f-1572-490e-976a-cc4292519bdd">
+  * The Complete Story of Joins (in HyPer)</a>.
+  *
+  * <p>The implementation in the paper targets a single join key that is not null-safe, which is
+  * a representative scenario. This method covers the more general case of multiple join keys,
+  * mixing null-safe and not null-safe ones. The key point is whether the marker should be FALSE
+  * or NULL when the hash table match fails.
+  *
+  * <p>Left mark join is produced by rewriting IN/SOME/EXISTS subqueries, so a left mark join
+  * never contains both not null-safe join keys and non-equi predicates.
+  *
+  * <p>Dispatches to {@link #leftMarkHashJoinOptimized} when {@code atMostOneNotNullSafeKey} is
+  * true, otherwise to {@link #leftMarkHashJoinGeneral}.
+  *
+  * @param outer                     Left input
+  * @param inner                     Right input
+  * @param outerKeyNullAwareSelector Function that extracts keys from a row of the left input
+  *                                  (returns NULL when a not null-safe key has a NULL value)
+  * @param innerKeyNullAwareSelector Function that extracts keys from a row of the right input
+  *                                  (returns NULL when a not null-safe key has a NULL value)
+  * @param outerNullSafeKeySelector  Function that extracts the null-safe keys from a row of
+  *                                  the left input, or null
+  * @param innerNullSafeKeySelector  Function that extracts the null-safe keys from a row of
+  *                                  the right input, or null
+  * @param atMostOneNotNullSafeKey   True when there is at most one not null-safe key among the
+  *                                  join keys
+  * @param resultSelector            Function that concatenates a left row and its marker
+  * @param comparer                  Function that compares the keys, or null
+  * @param nullSafeComparer          Function that compares the null-safe keys, or null
+  * @param nonEquiPredicate          Non-equi predicate that can return NULL, or null
+  * @param equiPredicate             Equi predicate that can return NULL
+  */
+ public static <TSource, TInner, TKey, TNsKey, TResult> Enumerable<TResult> leftMarkHashJoin(
+     final Enumerable<TSource> outer, final Enumerable<TInner> inner,
+     final Function1<TSource, TKey> outerKeyNullAwareSelector,
+     final Function1<TInner, TKey> innerKeyNullAwareSelector,
+     final @Nullable Function1<TSource, TNsKey> outerNullSafeKeySelector,
+     final @Nullable Function1<TInner, TNsKey> innerNullSafeKeySelector,
+     final boolean atMostOneNotNullSafeKey,
+     final Function2<TSource, @Nullable Boolean, TResult> resultSelector,
+     final @Nullable EqualityComparer<TKey> comparer,
+     final @Nullable EqualityComparer<TNsKey> nullSafeComparer,
+     final @Nullable NullablePredicate2<TSource, TInner> nonEquiPredicate,
+     final NullablePredicate2<TSource, TInner> equiPredicate) {
+   if (!atMostOneNotNullSafeKey) {
+     // Several not null-safe keys: a hash miss alone cannot tell FALSE from NULL, so the
+     // general variant additionally needs the equi-predicate.
+     return leftMarkHashJoinGeneral(outer, inner,
+         outerKeyNullAwareSelector, innerKeyNullAwareSelector,
+         outerNullSafeKeySelector, innerNullSafeKeySelector,
+         resultSelector, comparer, nullSafeComparer,
+         nonEquiPredicate, equiPredicate);
+   }
+   return leftMarkHashJoinOptimized(outer, inner,
+       outerKeyNullAwareSelector, innerKeyNullAwareSelector,
+       outerNullSafeKeySelector, innerNullSafeKeySelector,
+       resultSelector, comparer, nullSafeComparer,
+       nonEquiPredicate);
+ }
+
+ /**
+  * Left mark hash join for the case where the join keys contain at most one not null-safe key.
+  *
+  * <p>For other join types (especially INNER join), the hash table can be used to quickly
+  * determine which right-side rows should be joined with a given left-side row. Left mark join
+  * is more complicated: the <code>marker</code> indicates whether a left row has a join partner
+  * on the right side and is a three-valued boolean, so we need to know the actual result (TRUE,
+  * FALSE, or NULL) of the join condition.
+  *
+  * <p>Join keys come in two categories:
+  * <ul>
+  * <li>null-safe key (IS NOT DISTINCT FROM): it only produces TRUE/FALSE.</li>
+  * <li>not null-safe key (EQUALS): it produces a three-valued boolean.</li>
+  * </ul>
+  *
+  * <p>If all join keys are null-safe, the comparison result (TRUE or FALSE) follows directly
+  * from hash table matching.
+  *
+  * <p>If there are multiple not null-safe join keys, hash table matching alone is insufficient
+  * to determine the comparison result. However, if there is only one not null-safe key, we can
+  * record at build time which right rows have a NULL value in that key. During probing, if no
+  * match is found in the hash table, the <code>marker</code> is NULL only when one of those
+  * rows also agrees with the left row on all null-safe keys (for such a row the condition is
+  * TRUE AND NULL = NULL). A NULL-keyed right row whose null-safe keys differ from the left row
+  * evaluates to FALSE AND NULL = FALSE and therefore must not turn the marker into NULL.
+  *
+  * @param outer                     Left input
+  * @param inner                     Right input (build side of the hash table)
+  * @param outerKeyNullAwareSelector Extracts the keys of a left row (NULL when the not
+  *                                  null-safe key is NULL)
+  * @param innerKeyNullAwareSelector Extracts the keys of a right row (NULL when the not
+  *                                  null-safe key is NULL)
+  * @param outerNullSafeKeySelector  Extracts the null-safe keys of a left row, or null
+  * @param innerNullSafeKeySelector  Extracts the null-safe keys of a right row, or null
+  * @param resultSelector            Concatenates a left row and its marker
+  * @param comparer                  Compares the keys, or null
+  * @param nullSafeComparer          Compares the null-safe keys, or null
+  * @param nonEquiPredicate          Non-equi predicate that can return NULL, or null
+  */
+ static <TSource, TInner, TKey, TNsKey, TResult> Enumerable<TResult> leftMarkHashJoinOptimized(
+     final Enumerable<TSource> outer, final Enumerable<TInner> inner,
+     final Function1<TSource, TKey> outerKeyNullAwareSelector,
+     final Function1<TInner, TKey> innerKeyNullAwareSelector,
+     final @Nullable Function1<TSource, TNsKey> outerNullSafeKeySelector,
+     final @Nullable Function1<TInner, TNsKey> innerNullSafeKeySelector,
+     final Function2<TSource, @Nullable Boolean, TResult> resultSelector,
+     final @Nullable EqualityComparer<TKey> comparer,
+     final @Nullable EqualityComparer<TNsKey> nullSafeComparer,
+     final @Nullable NullablePredicate2<TSource, TInner> nonEquiPredicate) {
+   return new AbstractEnumerable<TResult>() {
+     @Override public Enumerator<TResult> enumerator() {
+       final HashTableWithNullSafeKeySet<TKey, TNsKey, TInner> ht =
+           HashTableWithNullSafeKeySet.build(inner, innerKeyNullAwareSelector,
+               innerNullSafeKeySelector, comparer, nullSafeComparer);
+
+       // Whether any build row has a NULL value on the unique not null-safe key.
+       final boolean buildHasNullKey = ht.lookup.containsKey(null);
+       // Null-safe keys of those NULL-keyed build rows. Only such a row can make the
+       // condition evaluate to NULL for an outer row that misses in the hash table, and only
+       // if its null-safe keys equal the outer row's null-safe keys.
+       final Set<Wrapped<TNsKey>> nullSafeKeysOfNullKeyRows = new HashSet<>();
+       if (buildHasNullKey && innerNullSafeKeySelector != null) {
+         final Enumerable<TInner> nullKeyRows = ht.lookup.get(null);
+         if (nullKeyRows != null) {
+           try (Enumerator<TInner> rows = nullKeyRows.enumerator()) {
+             while (rows.moveNext()) {
+               nullSafeKeysOfNullKeyRows.add(
+                   Wrapped.upAs(ht.nullSafeComparer,
+                       innerNullSafeKeySelector.apply(rows.current())));
+             }
+           }
+         }
+       }
+
+       return new Enumerator<TResult>() {
+         final Enumerator<TSource> outers = outer.enumerator();
+         @Nullable Boolean marker = false;
+
+         @Override public TResult current() {
+           return resultSelector.apply(outers.current(), marker);
+         }
+
+         @Override public boolean moveNext() {
+           if (!outers.moveNext()) {
+             return false;
+           }
+           marker = false;
+           final TSource outerRow = outers.current();
+           final TKey outerKey = outerKeyNullAwareSelector.apply(outerRow);
+
+           @Nullable Wrapped<TNsKey> outerNullSafeKey = null;
+           if (outerNullSafeKeySelector != null) {
+             outerNullSafeKey =
+                 Wrapped.upAs(ht.nullSafeComparer, outerNullSafeKeySelector.apply(outerRow));
+             if (!ht.nullSafeKeySet.contains(outerNullSafeKey)) {
+               // there are null-safe keys, but there is no match in the hash table of
+               // null-safe keys. The marker is FALSE
+               return true;
+             }
+           }
+
+           if (outerKey == null) {
+             // outerRow has a NULL value on the unique not null-safe key. The marker is NULL
+             marker = null;
+           } else {
+             final Enumerable<TInner> innerEnumerable = ht.lookup.get(outerKey);
+             if (innerEnumerable == null) {
+               // no match found in the hash table. The marker is NULL only if some build row
+               // has NULL on the unique not null-safe key AND matches outerRow on the
+               // null-safe keys (if any); otherwise every build row evaluates to FALSE.
+               if (!buildHasNullKey) {
+                 marker = false;
+               } else if (outerNullSafeKey == null || innerNullSafeKeySelector == null) {
+                 // no null-safe keys involved: any NULL-keyed build row yields NULL
+                 marker = null;
+               } else if (nullSafeKeysOfNullKeyRows.contains(outerNullSafeKey)) {
+                 marker = null;
+               } else {
+                 marker = false;
+               }
+             } else {
+               if (nonEquiPredicate == null) {
+                 marker = true;
+               } else {
+                 try (Enumerator<TInner> innerEnumerator = innerEnumerable.enumerator()) {
+                   while (innerEnumerator.moveNext()) {
+                     final TInner innerRow = innerEnumerator.current();
+                     Boolean predicateMatched = nonEquiPredicate.apply(outerRow, innerRow);
+                     if (predicateMatched == null) {
+                       // remember the unknown result, but keep looking for a TRUE partner
+                       marker = null;
+                     } else if (predicateMatched) {
+                       marker = true;
+                       break;
+                     }
+                   }
+                 }
+               }
+             }
+           }
+           // if the inner is empty set, convert the NULL marker to FALSE
+           if (marker == null && ht.buildSideIsEmpty) {
+             marker = false;
+           }
+           return true;
+         }
+
+         @Override public void reset() {
+           outers.reset();
+         }
+
+         @Override public void close() {
+           outers.close();
+         }
+       };
+     }
+   };
+ }
+
+ /**
+  * Left mark hash join for join keys that may contain several not null-safe keys.
+  *
+  * <p>As described in {@link #leftMarkHashJoinOptimized}, with multiple not null-safe join keys
+  * the hash table matching alone is insufficient to determine the comparison result. For left
+  * rows that fail to match in the hash table, the equi-predicate is applied against right-side
+  * rows one by one to decide whether the <code>marker</code> is FALSE or NULL.
+  *
+  * @param outer                     Left input
+  * @param inner                     Right input (build side of the hash table)
+  * @param outerKeyNullAwareSelector Extracts the keys of a left row (NULL when a not null-safe
+  *                                  key is NULL)
+  * @param innerKeyNullAwareSelector Extracts the keys of a right row (NULL when a not null-safe
+  *                                  key is NULL)
+  * @param outerNullSafeKeySelector  Extracts the null-safe keys of a left row, or null
+  * @param innerNullSafeKeySelector  Extracts the null-safe keys of a right row, or null
+  * @param resultSelector            Concatenates a left row and its marker
+  * @param comparer                  Compares the keys, or null
+  * @param nullSafeComparer          Compares the null-safe keys, or null
+  * @param nonEquiPredicate          Non-equi predicate that can return NULL, or null
+  * @param equiPredicate             Equi predicate that can return NULL
+  */
+ static <TSource, TInner, TKey, TNsKey, TResult> Enumerable<TResult> leftMarkHashJoinGeneral(
+     final Enumerable<TSource> outer, final Enumerable<TInner> inner,
+     final Function1<TSource, TKey> outerKeyNullAwareSelector,
+     final Function1<TInner, TKey> innerKeyNullAwareSelector,
+     final @Nullable Function1<TSource, TNsKey> outerNullSafeKeySelector,
+     final @Nullable Function1<TInner, TNsKey> innerNullSafeKeySelector,
+     final Function2<TSource, @Nullable Boolean, TResult> resultSelector,
+     final @Nullable EqualityComparer<TKey> comparer,
+     final @Nullable EqualityComparer<TNsKey> nullSafeComparer,
+     final @Nullable NullablePredicate2<TSource, TInner> nonEquiPredicate,
+     final NullablePredicate2<TSource, TInner> equiPredicate) {
+   return new AbstractEnumerable<TResult>() {
+     @Override public Enumerator<TResult> enumerator() {
+       HashTableWithNullSafeKeySet<TKey, TNsKey, TInner> ht =
+           HashTableWithNullSafeKeySet.build(inner, innerKeyNullAwareSelector,
+               innerNullSafeKeySelector, comparer, nullSafeComparer);
+
+       return new Enumerator<TResult>() {
+         Enumerator<TSource> outers = outer.enumerator();
+         @Nullable Boolean marker = false;
+
+         @Override public TResult current() {
+           return resultSelector.apply(outers.current(), marker);
+         }
+
+         @Override public boolean moveNext() {
+           if (!outers.moveNext()) {
+             return false;
+           }
+           marker = false;
+           final TSource left = outers.current();
+           final TKey leftKey = outerKeyNullAwareSelector.apply(left);
+           if (outerNullSafeKeySelector != null
+               && !ht.containsNullSafeKey(outerNullSafeKeySelector.apply(left))) {
+             // Null-safe keys exist but none on the build side equals this row's: FALSE.
+             return true;
+           }
+
+           if (leftKey == null) {
+             // The left row has NULL in at least one not null-safe key: scan every build row
+             // with the equi-predicate; the first NULL result settles the marker as NULL.
+             for (Enumerable<TInner> bucket : ht.lookup.values()) {
+               if (yieldsNullForAnyRow(equiPredicate, left, bucket)) {
+                 marker = null;
+                 break;
+               }
+             }
+           } else {
+             final Enumerable<TInner> matches = ht.lookup.get(leftKey);
+             if (matches != null) {
+               if (nonEquiPredicate == null) {
+                 marker = true;
+               } else {
+                 try (Enumerator<TInner> candidates = matches.enumerator()) {
+                   while (candidates.moveNext()) {
+                     final Boolean outcome =
+                         nonEquiPredicate.apply(left, candidates.current());
+                     if (outcome == null) {
+                       // unknown so far; a later TRUE still wins
+                       marker = null;
+                     } else if (outcome) {
+                       marker = true;
+                       break;
+                     }
+                   }
+                 }
+               }
+             } else {
+               // Hash miss: only build rows with NULL in a not null-safe key can still make
+               // the condition NULL, so only those are checked with the equi-predicate.
+               final Enumerable<TInner> nullKeyed = ht.lookup.get(null);
+               if (nullKeyed != null && yieldsNullForAnyRow(equiPredicate, left, nullKeyed)) {
+                 marker = null;
+               }
+             }
+           }
+           // An empty build side can never produce NULL: downgrade to FALSE.
+           if (marker == null && ht.buildSideIsEmpty) {
+             marker = false;
+           }
+           return true;
+         }
+
+         @Override public void reset() {
+           outers.reset();
+         }
+
+         @Override public void close() {
+           outers.close();
+         }
+       };
+     }
+   };
+ }
+
+ /** Returns whether {@code predicate} evaluates to NULL for {@code outerRow} paired with at
+  * least one row of {@code rows}; stops at the first such row. */
+ private static <TSource, TInner> boolean yieldsNullForAnyRow(
+     NullablePredicate2<TSource, TInner> predicate, TSource outerRow,
+     Enumerable<TInner> rows) {
+   try (Enumerator<TInner> cursor = rows.enumerator()) {
+     while (cursor.moveNext()) {
+       final Boolean outcome = predicate.apply(outerRow, cursor.current());
+       if (outcome == null) {
+         return true;
+       }
+     }
+   }
+   return false;
+ }
+
/**
* For each row of the {@code outer} enumerable returns the correlated rows
* from the {@code inner} enumerable.
@@ -3801,21 +4081,8 @@ static <TSource, TKey, TElement> LookupImpl<TKey,
TElement> toLookup_(
while (os.moveNext()) {
TSource o = os.current();
final TKey key = keySelector.apply(o);
- @SuppressWarnings("nullness")
- List<TElement> list = map.get(key);
- if (list == null) {
- // for first entry, use a singleton list to save space
- list = Collections.singletonList(elementSelector.apply(o));
- } else {
- if (list.size() == 1) {
- // when we go from 1 to 2 elements, switch to array list
- TElement element = list.get(0);
- list = new ArrayList<>();
- list.add(element);
- }
- list.add(elementSelector.apply(o));
- }
- map.put(key, list);
+ final TElement data = elementSelector.apply(o);
+ appendDataForKey(map, key, data);
}
}
return new LookupImpl<>(map);
@@ -3840,6 +4107,24 @@ public static <TSource, TKey, TElement> Lookup<TKey,
TElement> toLookup(
elementSelector);
}
+ /** Appends {@code row} to the list stored under {@code key} in {@code map}, creating the
+  * list on first use and storing the resulting list back into the map. */
+ private static <TKey, TElement> void appendDataForKey(
+     Map<TKey, List<TElement>> map, TKey key, TElement row) {
+   final List<TElement> existing = map.get(key);
+   final List<TElement> updated;
+   if (existing == null) {
+     // first entry for this key: a singleton list saves space
+     updated = Collections.singletonList(row);
+   } else if (existing.size() == 1) {
+     // going from 1 to 2 elements: switch to an array list
+     updated = new ArrayList<>();
+     updated.add(existing.get(0));
+     updated.add(row);
+   } else {
+     updated = existing;
+     updated.add(row);
+   }
+   map.put(key, updated);
+ }
+
/**
* Produces the set union of two sequences by using
* the default equality comparer.
@@ -4024,6 +4309,63 @@ public static <T, C extends Collection<? super T>> C
remove(
return sink;
}
+ /**
+  * Build-side structure of the left mark hash join: a lookup from the null-aware key to the
+  * build rows, plus the set of null-safe keys seen on the build side.
+  *
+  * @param <TKey> key type
+  * @param <TNsKey> null-safe key type
+  * @param <TData> build side row type
+  */
+ static class HashTableWithNullSafeKeySet<TKey, TNsKey, TData> {
+   /** Build rows grouped by their null-aware key. */
+   final Lookup<TKey, TData> lookup;
+   /** Null-safe keys of the build rows, wrapped with {@link #nullSafeComparer}. */
+   final Set<Wrapped<TNsKey>> nullSafeKeySet;
+   /** Comparer used to wrap null-safe keys. */
+   final EqualityComparer<TNsKey> nullSafeComparer;
+   /** Whether the build side is the empty set, i.e. {@link #lookup} is empty. */
+   final boolean buildSideIsEmpty;
+
+   private HashTableWithNullSafeKeySet(
+       Lookup<TKey, TData> lookup,
+       Set<Wrapped<TNsKey>> nullSafeKeySet,
+       EqualityComparer<TNsKey> nullSafeComparer) {
+     this.lookup = lookup;
+     this.nullSafeKeySet = nullSafeKeySet;
+     this.nullSafeComparer = nullSafeComparer;
+     this.buildSideIsEmpty = lookup.isEmpty();
+   }
+
+   /**
+    * Consumes {@code data} once and builds the hash table.
+    *
+    * @param data                 build side rows
+    * @param keyNullAwareSelector extracts the null-aware key of a row
+    * @param nullSafeKeySelector  extracts the null-safe key of a row; when null, the null-safe
+    *                             key set stays empty
+    * @param comparer             comparer for the null-aware keys; when null, a plain
+    *                             {@link HashMap} is used
+    * @param nullSafeComparer     comparer for the null-safe keys; when null,
+    *                             {@code Functions.identityComparer()} is used
+    */
+   static <TKey, TNsKey, TData> HashTableWithNullSafeKeySet<TKey, TNsKey, TData> build(
+       Enumerable<TData> data,
+       Function1<TData, TKey> keyNullAwareSelector,
+       @Nullable Function1<TData, TNsKey> nullSafeKeySelector,
+       @Nullable EqualityComparer<TKey> comparer,
+       @Nullable EqualityComparer<TNsKey> nullSafeComparer) {
+     final Map<TKey, List<TData>> rowsByKey;
+     if (comparer == null) {
+       rowsByKey = new HashMap<>();
+     } else {
+       rowsByKey = new WrapMap<>(() -> new HashMap<Wrapped<TKey>, List<TData>>(), comparer);
+     }
+     final Set<Wrapped<TNsKey>> seenNullSafeKeys = new HashSet<>();
+     final EqualityComparer<TNsKey> effectiveNsComparer;
+     if (nullSafeComparer != null) {
+       effectiveNsComparer = nullSafeComparer;
+     } else {
+       effectiveNsComparer = Functions.identityComparer();
+     }
+
+     try (Enumerator<TData> rows = data.enumerator()) {
+       while (rows.moveNext()) {
+         final TData row = rows.current();
+         appendDataForKey(rowsByKey, keyNullAwareSelector.apply(row), row);
+         if (nullSafeKeySelector != null) {
+           seenNullSafeKeys.add(
+               Wrapped.upAs(effectiveNsComparer, nullSafeKeySelector.apply(row)));
+         }
+       }
+     }
+     return new HashTableWithNullSafeKeySet<>(
+         new LookupImpl<>(rowsByKey), seenNullSafeKeys, effectiveNsComparer);
+   }
+
+   /** Returns whether some build row has a null-safe key equal to {@code key}. */
+   public boolean containsNullSafeKey(TNsKey key) {
+     final Wrapped<TNsKey> probe = Wrapped.upAs(nullSafeComparer, key);
+     return nullSafeKeySet.contains(probe);
+   }
+ }
+
/** Enumerable that implements take-while.
*
* @param <TSource> element type */
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
index 9c11f83dd7..67887e7d08 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
@@ -30,6 +30,7 @@
import org.apache.calcite.linq4j.function.NullableFloatFunction1;
import org.apache.calcite.linq4j.function.NullableIntegerFunction1;
import org.apache.calcite.linq4j.function.NullableLongFunction1;
+import org.apache.calcite.linq4j.function.NullablePredicate2;
import org.apache.calcite.linq4j.function.Predicate1;
import org.apache.calcite.linq4j.function.Predicate2;
@@ -646,6 +647,45 @@ <TInner, TKey, TResult> Enumerable<TResult>
hashJoin(Enumerable<TInner> inner,
boolean generateNullsOnLeft, boolean generateNullsOnRight,
Predicate2<TSource, TInner> predicate);
+ /**
+  * Marks each row of the current enumerable to show whether it has a join partner in
+  * <code>inner</code>. Whether a join partner exists depends on:
+  * <ul>
+  * <li>matching keys</li>
+  * <li>the non-equi predicate (if provided)</li>
+  * </ul>
+  *
+  * <p>Refer to <a href="https://dl.gi.de/items/c5f7c49f-1572-490e-976a-cc4292519bdd">
+  * The Complete Story of Joins (in HyPer)</a>.
+  *
+  * @param inner                     Inner enumerable
+  * @param outerKeyNullAwareSelector Function that extracts keys from the current enumerable
+  *                                  (returns NULL when a not null-safe key has a NULL value)
+  * @param innerKeyNullAwareSelector Function that extracts keys from the inner enumerable
+  *                                  (returns NULL when a not null-safe key has a NULL value)
+  * @param outerNullSafeKeySelector  Function that extracts null-safe keys from the current
+  *                                  enumerable, or null
+  * @param innerNullSafeKeySelector  Function that extracts null-safe keys from the inner
+  *                                  enumerable, or null
+  * @param atMostOneNotNullSafeKey   True when there is at most one not null-safe key among the
+  *                                  join keys
+  * @param resultSelector            Function that concatenates a row of the current enumerable
+  *                                  and its marker
+  * @param comparer                  Function that compares the keys, or null
+  * @param nullSafeComparer          Function that compares the null-safe keys, or null
+  * @param nonEquiPredicate          Non-equi predicate that can return NULL, or null
+  * @param equiPredicate             Equi predicate that can return NULL
+  */
+ <TInner, TKey, TNsKey, TResult> Enumerable<TResult> leftMarkHashJoin(Enumerable<TInner> inner,
+     Function1<TSource, TKey> outerKeyNullAwareSelector,
+     Function1<TInner, TKey> innerKeyNullAwareSelector,
+     @Nullable Function1<TSource, TNsKey> outerNullSafeKeySelector,
+     @Nullable Function1<TInner, TNsKey> innerNullSafeKeySelector,
+     boolean atMostOneNotNullSafeKey,
+     Function2<TSource, @Nullable Boolean, TResult> resultSelector,
+     @Nullable EqualityComparer<TKey> comparer,
+     @Nullable EqualityComparer<TNsKey> nullSafeComparer,
+     @Nullable NullablePredicate2<TSource, TInner> nonEquiPredicate,
+     NullablePredicate2<TSource, TInner> equiPredicate);
+
/**
* For each row of the current enumerable returns the correlated rows
* from the {@code inner} enumerable (nested loops join).
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java
index 4b38f145ea..d5c15ed224 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java
@@ -87,7 +87,32 @@ public enum JoinType {
/**
* The left version of an ASOF join, where each row from the left table is
part of the output.
*/
- LEFT_ASOF;
+ LEFT_ASOF,
+
+ /**
+ * A LEFT MARK JOIN keeps all rows from the left side and creates a new attribute to mark a
+ * tuple as having join partners from the right side or not. Refer to
+ * <a href="https://dl.gi.de/items/c5f7c49f-1572-490e-976a-cc4292519bdd">
+ * The Complete Story of Joins (in HyPer)</a>.
+ *
+ * <p>Example:
+ * <blockquote><pre>
+ * SELECT EMPNO FROM EMP
+ * WHERE EXISTS (SELECT 1 FROM DEPT
+ * WHERE DEPT.DEPTNO = EMP.DEPTNO)
+ * OR EMPNO > 1
+ *
+ * LogicalProject(EMPNO=[$0])
+ * LogicalFilter(condition=[OR($9, >($0, 1))])
+ * LogicalJoin(condition=[IS NOT DISTINCT FROM($7, $9)],
joinType=[left_mark])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ * </pre></blockquote>
+ *
+ * <p>If the marker is used only in conjunctive predicates, the optimizer will try to
+ * translate the mark join into a semi or anti join.
+ */
+ LEFT_MARK;
/**
* Returns whether a join of this type may generate NULL values on the
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/function/NullablePredicate2.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/function/NullablePredicate2.java
new file mode 100644
index 0000000000..f413db4635
--- /dev/null
+++
b/linq4j/src/main/java/org/apache/calcite/linq4j/function/NullablePredicate2.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.linq4j.function;
+
+/**
+ * Function with two parameters returning a {@link Boolean} value that may be null,
+ * i.e. a predicate with three possible outcomes: TRUE, FALSE, or null.
+ *
+ * @param <T0> Type of argument #0
+ * @param <T1> Type of argument #1
+ */
+public interface NullablePredicate2<T0, T1> extends Function2<T0, T1, Boolean>
{
+ /** Evaluates the predicate on the two arguments; the result may be null. */
+ @Override Boolean apply(T0 v0, T1 v1);
+}