This is an automated email from the ASF dual-hosted git repository. mbudiu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
commit fb37dc3705dd8dafb037310ee9b114f4020fd1e5
Author: Mihai Budiu <[email protected]>
AuthorDate: Wed Aug 14 11:47:30 2024 -0700

    [CALCITE-6372] Add ASOF join to the Calcite enumerable

    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../calcite/adapter/enumerable/EnumUtils.java      |   4 +
 .../adapter/enumerable/EnumerableAsofJoin.java     | 243 +++++++++++++++++++++
 .../adapter/enumerable/EnumerableAsofJoinRule.java |  68 ++++++
 .../adapter/enumerable/EnumerableRules.java        |   7 +
 .../org/apache/calcite/util/BuiltInMethod.java     |   7 +
 .../test/enumerable/EnumerableJoinTest.java        |  50 +++++
 core/src/test/resources/sql/asof.iq                | 206 +++++++++++++++++
 .../apache/calcite/linq4j/DefaultEnumerable.java   |  13 ++
 .../apache/calcite/linq4j/EnumerableDefaults.java  | 187 +++++++++++++++-
 .../apache/calcite/linq4j/ExtendedEnumerable.java  |  22 ++
 .../java/org/apache/calcite/linq4j/JoinType.java   |  23 +-
 .../org/apache/calcite/linq4j/test/Linq4jTest.java |  18 ++
 12 files changed, 846 insertions(+), 2 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
index 224784b405..1020838443 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumUtils.java
@@ -842,6 +842,10 @@ public class EnumUtils {
       return JoinType.SEMI;
     case ANTI:
       return JoinType.ANTI;
+    case ASOF:
+      return JoinType.ASOF;
+    case LEFT_ASOF:
+      return JoinType.LEFT_ASOF;
     default:
       break;
     }
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAsofJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAsofJoin.java
new file mode 100644
index 0000000000..f59ea1e409
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAsofJoin.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.adapter.enumerable; + +import org.apache.calcite.linq4j.tree.BlockBuilder; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollationTraitDef; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.AsofJoin; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalAsofJoin; +import org.apache.calcite.rel.metadata.RelMdCollation; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.util.BuiltInMethod; +import org.apache.calcite.util.Pair; + +import com.google.common.collect.ImmutableList; + 
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+/** Implementation of {@link LogicalAsofJoin} in
+ * {@link EnumerableConvention enumerable calling convention}. */
+public class EnumerableAsofJoin extends AsofJoin implements EnumerableRel {
+  /** Creates an EnumerableAsofJoin.
+   *
+   * <p>Use {@link #create} unless you know what you're doing. */
+  protected EnumerableAsofJoin(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode left,
+      RelNode right,
+      RexNode condition,
+      RexNode matchCondition,
+      Set<CorrelationId> variablesSet,
+      JoinRelType joinType) {
+    super(
+        cluster,
+        traits,
+        ImmutableList.of(),
+        left,
+        right,
+        condition,
+        matchCondition,
+        variablesSet,
+        joinType);
+  }
+
+  /** Creates an EnumerableAsofJoin. */
+  public static EnumerableAsofJoin create(
+      RelNode left,
+      RelNode right,
+      RexNode condition,
+      RexNode matchCondition,
+      Set<CorrelationId> variablesSet,
+      JoinRelType joinType) {
+    final RelOptCluster cluster = left.getCluster();
+    final RelMetadataQuery mq = cluster.getMetadataQuery();
+    final RelTraitSet traitSet =
+        cluster.traitSetOf(EnumerableConvention.INSTANCE)
+            .replaceIfs(RelCollationTraitDef.INSTANCE,
+                () -> RelMdCollation.enumerableHashJoin(mq, left, right, joinType));
+    return new EnumerableAsofJoin(cluster, traitSet, left, right, condition, matchCondition,
+        variablesSet, joinType);
+  }
+
+  @Override public EnumerableAsofJoin copy(RelTraitSet traitSet, RexNode condition,
+      RelNode left, RelNode right, JoinRelType joinType,
+      boolean semiJoinDone) {
+    // This method does not know about the matchCondition, so it should not be called
+    throw new UnsupportedOperationException("This method should not be called");
+  }
+
+  @Override public Join copy(RelTraitSet traitSet, List<RelNode> inputs) {
+    assert inputs.size() == 2;
+    return new EnumerableAsofJoin(
+        getCluster(), traitSet, inputs.get(0), inputs.get(1),
+        getCondition(), matchCondition,
+        variablesSet, joinType);
+  }
+
+  @Override public @Nullable Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
+      final RelTraitSet required) {
+    return EnumerableTraitsUtils.passThroughTraitsForJoin(
+        required, joinType, left.getRowType().getFieldCount(), getTraitSet());
+  }
+
+  @Override public @Nullable Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
+      final RelTraitSet childTraits, final int childId) {
+    // should only derive traits (limited to collation for now) from left join input.
+    return EnumerableTraitsUtils.deriveTraitsForJoin(
+        childTraits, childId, joinType, getTraitSet(), right.getTraitSet());
+  }
+
+  @Override public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
+      RelMetadataQuery mq) {
+    double rowCount = mq.getRowCount(this);
+    return planner.getCostFactory().makeCost(rowCount, 0, 0);
+  }
+
+  /** Generates the function that compares two rows from the right collection on their
+   * timestamp field.
+   *
+   * @param rightCollectionType Type of data in right collection.
+   * @param kind Comparison kind, normalized to the form "inner (right) op outer (left)".
+   * @param timestampFieldIndex Index of the field that is the timestamp field.
+   */
+  private Expression generateTimestampComparator(
+      PhysType rightCollectionType, SqlKind kind, int timestampFieldIndex) {
+    RelFieldCollation.Direction direction;
+    switch (kind) {
+    case LESS_THAN:
+    case LESS_THAN_OR_EQUAL:
+      direction = RelFieldCollation.Direction.ASCENDING;
+      break;
+    case GREATER_THAN:
+    case GREATER_THAN_OR_EQUAL:
+      direction = RelFieldCollation.Direction.DESCENDING;
+      break;
+    default:
+      throw new RuntimeException("Unexpected timestamp comparison in ASOF join " + kind);
+    }
+
+    final List<RelFieldCollation> fieldCollations = new ArrayList<>(1);
+    fieldCollations.add(
+        new RelFieldCollation(timestampFieldIndex, direction,
+            RelFieldCollation.NullDirection.FIRST));
+    final RelCollation collation = RelCollations.of(fieldCollations);
+    return rightCollectionType.generateComparator(collation);
+  }
+
+  /** Returns whether the first operand of the comparison 'call' is a field of
+   * the inner (right) collection. */
+  private boolean isInnerOperandFirst(RexCall call) {
+    RexNode first = call.getOperands().get(0);
+    assert first instanceof RexInputRef;
+    return ((RexInputRef) first).getIndex() >= left.getRowType().getFieldCount();
+  }
+
+  /** Extract from a comparison 'call' the index of the field from
+   * the inner collection that is used in the comparison. */
+  private int getTimestampFieldIndex(RexCall call) {
+    int timestampFieldIndex;
+    int leftFieldCount = left.getRowType().getFieldCount();
+    List<RexNode> operands = call.getOperands();
+    assert operands.size() == 2;
+    RexNode compareLeft = operands.get(0);
+    RexNode compareRight = operands.get(1);
+    assert compareLeft instanceof RexInputRef;
+    assert compareRight instanceof RexInputRef;
+    RexInputRef leftInputRef = (RexInputRef) compareLeft;
+    RexInputRef rightInputRef = (RexInputRef) compareRight;
+    // We know for sure that these two come from the inner and outer collection respectively,
+    // but we don't know which is which
+    if (leftInputRef.getIndex() < leftFieldCount) {
+      // Left input comes from the left collection
+      timestampFieldIndex = rightInputRef.getIndex() - leftFieldCount;
+    } else {
+      // Left input comes from the right collection
+      timestampFieldIndex = leftInputRef.getIndex() - leftFieldCount;
+    }
+    return timestampFieldIndex;
+  }
+
+  @Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
+    BlockBuilder builder = new BlockBuilder();
+    final Result leftResult =
+        implementor.visitChild(this, 0, (EnumerableRel) left, pref);
+    Expression leftExpression =
+        builder.append(
+            "left", leftResult.block);
+    final Result rightResult =
+        implementor.visitChild(this, 1, (EnumerableRel) right, pref);
+    Expression rightExpression =
+        builder.append(
+            "right", rightResult.block);
+    final PhysType physType =
+        PhysTypeImpl.of(
+            implementor.getTypeFactory(), getRowType(), pref.preferArray());
+    // ASOF joins conditions are restricted to equalities
+    assert joinInfo.nonEquiConditions.isEmpty();
+
+    // From the match condition we need to find out the kind of comparison performed
+    // and the timestamp field in the right collection.
+    assert matchCondition instanceof RexCall;
+    RexCall call = (RexCall) matchCondition;
+    // generateTimestampComparator expects the comparison in the form
+    // "inner <op> outer"; reverse the operator if the outer field comes first.
+    SqlKind kind = isInnerOperandFirst(call) ? call.getKind() : call.getKind().reverse();
+    int timestampFieldIndex = getTimestampFieldIndex(call);
+
+    Expression timestampComparator =
+        generateTimestampComparator(rightResult.physType, kind, timestampFieldIndex);
+
+    Expression matchPredicate =
+        EnumUtils.generatePredicate(implementor, getCluster().getRexBuilder(),
+            left, right, leftResult.physType, rightResult.physType, matchCondition);
+    return implementor.result(
+        physType,
+        builder.append(
+            Expressions.call(
+                leftExpression,
+                BuiltInMethod.ASOF_JOIN.method,
+                Expressions.list(
+                    rightExpression,
+                    // outer key selector
+                    leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
+                    // inner key selector
+                    rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
+                    // result selector
+                    EnumUtils.joinSelector(joinType,
+                        physType,
+                        ImmutableList.of(
+                            leftResult.physType, rightResult.physType)))
+                    // match comparator
+                    .append(matchPredicate)
+                    // comparator for the columns used as "timestamps"
+                    .append(timestampComparator)
+                    // generatesNullOnRight
+                    .append(Expressions.constant(joinType.generatesNullsOnRight()))))
+            .toBlock());
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAsofJoinRule.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAsofJoinRule.java
new file mode 100644
index 0000000000..4e975e601d
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAsofJoinRule.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.enumerable;
+
+import org.apache.calcite.plan.Convention;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.convert.ConverterRule;
+import org.apache.calcite.rel.logical.LogicalAsofJoin;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Planner rule that converts a
+ * {@link LogicalAsofJoin} relational expression to
+ * {@link EnumerableConvention enumerable calling convention}.
+ *
+ * @see EnumerableRules#ENUMERABLE_ASOFJOIN_RULE */
+class EnumerableAsofJoinRule extends ConverterRule {
+  /** Default configuration. */
+  public static final Config DEFAULT_CONFIG = Config.INSTANCE
+      .withConversion(LogicalAsofJoin.class, Convention.NONE,
+          EnumerableConvention.INSTANCE, "EnumerableAsofJoinRule")
+      .withRuleFactory(EnumerableAsofJoinRule::new);
+
+  /** Called from the Config. */
+  protected EnumerableAsofJoinRule(Config config) {
+    super(config);
+  }
+
+  @Override public RelNode convert(RelNode rel) {
+    LogicalAsofJoin join = (LogicalAsofJoin) rel;
+    List<RelNode> newInputs = new ArrayList<>();
+    for (RelNode input : join.getInputs()) {
+      if (!(input.getConvention() instanceof EnumerableConvention)) {
+        input =
+            convert(
+                input,
+                input.getTraitSet()
+                    .replace(EnumerableConvention.INSTANCE));
+      }
+      newInputs.add(input);
+    }
+    final RelNode left = newInputs.get(0);
+    final RelNode right = newInputs.get(1);
+
+    return EnumerableAsofJoin.create(
+        left,
+        right,
+        join.getCondition(),
+        join.getMatchCondition(),
+        join.getVariablesSet(),
+        join.getJoinType());
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
index a597108964..a575d18943 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java
@@ -48,6 +48,12 @@ public class EnumerableRules {
   public static final RelOptRule ENUMERABLE_JOIN_RULE =
       EnumerableJoinRule.DEFAULT_CONFIG.toRule(EnumerableJoinRule.class);
 
+  /** Rule that converts a
+   * {@link org.apache.calcite.rel.logical.LogicalAsofJoin} to
+   * {@link EnumerableConvention enumerable calling convention}. */
+  public static final RelOptRule ENUMERABLE_ASOFJOIN_RULE =
+      EnumerableAsofJoinRule.DEFAULT_CONFIG.toRule(EnumerableAsofJoinRule.class);
+
   /** Rule that converts a
    * {@link org.apache.calcite.rel.logical.LogicalJoin} to
    * {@link EnumerableConvention enumerable calling convention}. */
@@ -205,6 +211,7 @@ public class EnumerableRules {
 
   public static final List<RelOptRule> ENUMERABLE_RULES =
       ImmutableList.of(EnumerableRules.ENUMERABLE_JOIN_RULE,
+          EnumerableRules.ENUMERABLE_ASOFJOIN_RULE,
           EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE,
           EnumerableRules.ENUMERABLE_CORRELATE_RULE,
           EnumerableRules.ENUMERABLE_PROJECT_RULE,
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 57d1f25a11..7d74ccd37a 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -198,6 +198,13 @@ public enum BuiltInMethod {
       Function1.class, Function1.class, Function2.class,
       EqualityComparer.class, boolean.class, boolean.class,
       Predicate2.class),
+  ASOF_JOIN(ExtendedEnumerable.class, "asofJoin", Enumerable.class,
+      Function1.class, // outer key selector
+      Function1.class, // inner key selector
+      Function2.class, // result selector
+      Predicate2.class, // match comparator
+      Comparator.class, // timestamp comparator
+      boolean.class), // generateNullsOnRight
   MATCH(Enumerables.class, "match", Enumerable.class,
       Function1.class, Matcher.class, Enumerables.Emitter.class, int.class, int.class),
   PATTERN_BUILDER(Utilities.class, "patternBuilder"),
diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
index a15af886d6..a93bd94692 100644
--- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
+++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
@@ -66,6 +66,56 @@ class EnumerableJoinTest {
             "deptno=40; name=HR");
   }
 
+  @Test void asofJoinTest() {
+    tester(false, new HrSchema())
+        .withRel(
+            // select d.deptno, e.name, e.empid from depts d left asof join emps e
+            // match_condition d.name < e.name
+            // on d.deptno = e.deptno
+            builder -> builder
+                .scan("s", "depts").as("d")
+                .scan("s", "emps").as("e")
+                .asofJoin(JoinRelType.LEFT_ASOF,
+                    builder.equals(
+                        builder.field(2, "d", "deptno"),
+                        builder.field(2, "e", "deptno")),
+                    builder.lessThan(
+                        builder.field(2, "d", "name"),
+                        builder.field(2, "e", "name")))
+                .project(
+                    builder.field("deptno"),
+                    builder.field("e", "name"),
+                    builder.field("empid"))
+                .build())
+        .returnsUnordered(
+            "deptno=10; name=Sebastian; empid=150",
+            "deptno=30; name=null; empid=null",
+            "deptno=40; name=null; empid=null");
+
+    tester(false, new HrSchema())
+        .withRel(
+            // select d.deptno, e.name, e.empid from depts d asof join emps e
+            // match_condition e.name < d.name
+            // on d.deptno = e.deptno
+            builder -> builder
+                .scan("s", "depts").as("d")
+                .scan("s", "emps").as("e")
+                .asofJoin(JoinRelType.ASOF,
+                    builder.equals(
+                        builder.field(2, "d", "deptno"),
+                        builder.field(2, "e", "deptno")),
+                    builder.lessThan(
+                        builder.field(2, "e", "name"),
+                        builder.field(2, "d", "name")))
+                .project(
+                    builder.field("deptno"),
+                    builder.field("e", "name"),
+                    builder.field("empid"))
+                .build())
+        .returnsUnordered(
+            "deptno=10; name=Bill; empid=100");
+  }
+
   /** Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-2968">[CALCITE-2968]
    * New AntiJoin relational expression</a>. */
diff --git a/core/src/test/resources/sql/asof.iq b/core/src/test/resources/sql/asof.iq
new file mode 100644
index 0000000000..287681ab20
--- /dev/null
+++ b/core/src/test/resources/sql/asof.iq
@@ -0,0 +1,206 @@
+# asof.iq - ASOF Join query tests
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+!use post
+!set outputformat mysql
+
+# These results have been validated against DuckDB.
+# Note that DuckDB has a slightly different syntax for ASOF joins,
+# so the queries have to be rewritten.
+# Also, DuckDB compares nulls in keys as larger than any value,
+# so the results differ for tuples with null keys. We believe that
+# the behavior of DuckDB is wrong, since logically the result of an
+# ASOF JOIN should always be a subset of the result produced by
+# the corresponding normal JOIN.
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t)
+ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t)
+MATCH_CONDITION t2.t < t1.t
+ON t1.k = t2.k;
++---+---+----+----+
+| K | T | K0 | T0 |
++---+---+----+----+
+| 1 | 3 |  1 |  2 |
+| 1 | 4 |  1 |  3 |
+| 2 | 3 |  2 |  0 |
++---+---+----+----+
+(3 rows)
+
+!ok
+
+# Same test, with column names that are not qualified by a table alias
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k1, ts1)
+ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k2, ts2)
+MATCH_CONDITION ts2 < ts1
+ON k1 = k2;
++----+-----+----+-----+
+| K1 | TS1 | K2 | TS2 |
++----+-----+----+-----+
+|  1 |   3 |  1 |   2 |
+|  1 |   4 |  1 |   3 |
+|  2 |   3 |  2 |   0 |
++----+-----+----+-----+
+(3 rows)
+
+!ok
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t)
+ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t)
+MATCH_CONDITION t2.t > t1.t
+ON t1.k = t2.k;
++---+---+----+----+
+| K | T | K0 | T0 |
++---+---+----+----+
+| 1 | 0 |  1 |  2 |
+| 1 | 1 |  1 |  2 |
+| 1 | 2 |  1 |  3 |
+| 2 | 3 |  2 | 10 |
++---+---+----+----+
+(4 rows)
+
+!ok
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t)
+ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t)
+MATCH_CONDITION t2.t >= t1.t
+ON t1.k = t2.k;
++---+---+----+----+
+| K | T | K0 | T0 |
++---+---+----+----+
+| 1 | 0 |  1 |  2 |
+| 1 | 1 |  1 |  2 |
+| 1 | 2 |  1 |  2 |
+| 1 | 3 |  1 |  3 |
+| 2 | 3 |  2 | 10 |
++---+---+----+----+
+(5 rows)
+
+!ok
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t)
+ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t)
+MATCH_CONDITION t2.t <= t1.t
+ON t1.k = t2.k;
++---+---+----+----+
+| K | T | K0 | T0 |
++---+---+----+----+
+| 1 | 2 |  1 |  2 |
+| 1 | 3 |  1 |  3 |
+| 1 | 4 |  1 |  3 |
+| 2 | 3 |  2 |  0 |
++---+---+----+----+
+(4 rows)
+
+!ok
+
+# Same tests with LEFT ASOF JOIN
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t)
+LEFT ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t)
+MATCH_CONDITION t2.t < t1.t
+ON t1.k = t2.k;
++---+---+----+----+
+| K | T | K0 | T0 |
++---+---+----+----+
+|   | 0 |    |    |
+| 1 |   |    |    |
+| 1 | 0 |    |    |
+| 1 | 1 |    |    |
+| 1 | 2 |    |    |
+| 1 | 3 |  1 |  2 |
+| 1 | 4 |  1 |  3 |
+| 2 | 3 |  2 |  0 |
+| 3 | 4 |    |    |
++---+---+----+----+
+(9 rows)
+
+!ok
+
+SELECT *
+FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t)
+LEFT ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t)
+MATCH_CONDITION t2.t > t1.t
+ON t1.k = t2.k;
++---+---+----+----+
+| K | T | K0 | T0 |
++---+---+----+----+
+|   | 0 |    |    |
+| 1 |   |    |    |
+| 1 | 0 |  1 |  2 |
+| 1 | 1 |  1 |  2 |
+| 1 | 2 |  1 |  3 |
+| 1 | 3 |    |    |
+| 1 | 4 |    |    |
+| 2 | 3 |  2 | 10 |
+| 3
| 4 | | | ++---+---+----+----+ +(9 rows) + +!ok + +SELECT * +FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t) +LEFT ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t) +MATCH_CONDITION t2.t >= t1.t +ON t1.k = t2.k; ++---+---+----+----+ +| K | T | K0 | T0 | ++---+---+----+----+ +| | 0 | | | +| 1 | | | | +| 1 | 0 | 1 | 2 | +| 1 | 1 | 1 | 2 | +| 1 | 2 | 1 | 2 | +| 1 | 3 | 1 | 3 | +| 1 | 4 | | | +| 2 | 3 | 2 | 10 | +| 3 | 4 | | | ++---+---+----+----+ +(9 rows) + +!ok + +SELECT * +FROM (VALUES (NULL, 0), (1, NULL), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 3), (3, 4)) AS t1(k, t) +LEFT ASOF JOIN (VALUES (1, NULL), (1, 2), (1, 3), (2, 10), (2, 0)) AS t2(k, t) +MATCH_CONDITION t2.t <= t1.t +ON t1.k = t2.k; ++---+---+----+----+ +| K | T | K0 | T0 | ++---+---+----+----+ +| | 0 | | | +| 1 | | | | +| 1 | 0 | | | +| 1 | 1 | | | +| 1 | 2 | 1 | 2 | +| 1 | 3 | 1 | 3 | +| 1 | 4 | 1 | 3 | +| 2 | 3 | 2 | 0 | +| 3 | 4 | | | ++---+---+----+----+ +(9 rows) + +!ok + +# End asof.iq diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java index 50b14b9ca9..3dd31fb502 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java @@ -393,6 +393,19 @@ public abstract class DefaultEnumerable<T> implements OrderedEnumerable<T> { innerKeySelector, resultSelector, comparer); } + @Override public <TInner, TKey, TResult> Enumerable<TResult> asofJoin( + Enumerable<TInner> inner, + Function1<T, TKey> outerKeySelector, + Function1<TInner, TKey> innerKeySelector, + Function2<T, @Nullable TInner, TResult> resultSelector, + Predicate2<T, TInner> matchComparator, + Comparator<TInner> timestampComparator, + boolean generateNullsOnRight) { + return EnumerableDefaults.asofJoin(getThis(), inner, outerKeySelector, + innerKeySelector, 
resultSelector, matchComparator, + timestampComparator, generateNullsOnRight); + } + @Override public <TInner, TKey, TResult> Enumerable<TResult> hashJoin( Enumerable<TInner> inner, Function1<T, TKey> outerKeySelector, Function1<TInner, TKey> innerKeySelector, diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 34819551d8..5afa372e27 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -858,6 +858,192 @@ public abstract class EnumerableDefaults { }; } + /** + * ASOF join implementation. For each row in the source enumerable produce exactly one + * result, similar to a LEFT join. The row on the right is the one that is the "largest" + * according to the timestampComparer that matches in key and satisfies the timestampComparator. + * + * @param outer Left input + * @param inner Right input + * @param outerKeySelector Selects a key from a left input record + * @param innerKeySelector Selects a key from the right input record + * @param resultSelector Produces the result from a pair (left, right) + * @param matchComparator Compares an element from the left input with one from the right + * input and returns 'true' if the timestamp are appropriate + * @param timestampComparator Compares two elements from the right input and returns + * true if the second is in the right order with respect + * to the ASOF comparison. + * @param emitNullsOnRight If true this is a left join. 
+ */ + public static <TResult, TSource, TInner, TKey> Enumerable<TResult> asofJoin( + Enumerable<TSource> outer, Enumerable<TInner> inner, + Function1<TSource, TKey> outerKeySelector, + Function1<TInner, TKey> innerKeySelector, + Function2<TSource, @Nullable TInner, TResult> resultSelector, + Predicate2<TSource, TInner> matchComparator, + Comparator<TInner> timestampComparator, + boolean emitNullsOnRight) { + + // The basic algorithm is simple: + // - scan and index left collection by key + // - for each left record keep the best right record, initialized to 'null' + // - scan the right collection and for each record + // - match it against all left collection records with the same key + // - if the timestamp is closer, update the right record + // - emit all items in the index + Map<TKey, List<TSource>> leftIndex = new HashMap<>(); + // For each left element the corresponding best right element + Map<TKey, List<@Nullable TInner>> rightIndex = new HashMap<>(); + // Outer elements that have null keys. Will remain empty if !emitNullsOnRight. 
+ List<TSource> outerWithNullKeys = new ArrayList<>(); + try (Enumerator<TSource> os = outer.enumerator()) { + while (os.moveNext()) { + TSource l = os.current(); + TKey key = outerKeySelector.apply(l); + if (key == null) { + // key contains null fields (result of key selector is null) + if (emitNullsOnRight) { + outerWithNullKeys.add(l); + } + } else { + List<TSource> left; + List<@Nullable TInner> right; + if (!leftIndex.containsKey(key)) { + left = new ArrayList<>(); + right = new ArrayList<>(); + leftIndex.put(key, left); + rightIndex.put(key, right); + } else { + left = leftIndex.get(key); + right = rightIndex.get(key); + } + left.add(l); + Objects.requireNonNull(right, "right").add(null); + } + } + } + // Scan right collection + try (Enumerator<TInner> is = inner.enumerator()) { + while (is.moveNext()) { + TInner r = is.current(); + TKey key = innerKeySelector.apply(r); + if (key == null) { + // key contains null fields (result of key selector is null) + continue; + } + List<TSource> left = leftIndex.get(key); + if (left == null) { + continue; + } + assert !left.isEmpty(); + List<@Nullable TInner> best = Objects.requireNonNull(rightIndex.get(key)); + assert left.size() == best.size(); + for (int i = 0; i < left.size(); i++) { + TSource leftElement = left.get(i); + boolean matches = matchComparator.apply(leftElement, r); + if (!matches) { + continue; + } + @Nullable TInner bestElement = best.get(i); + if (bestElement == null) { + best.set(i, r); + } else { + boolean isCloser = timestampComparator.compare(bestElement, r) < 0; + if (isCloser) { + best.set(i, r); + } + } + } + } + } + + return new AbstractEnumerable<TResult>() { + @Override public Enumerator<TResult> enumerator() { + return new Enumerator<TResult>() { + final Enumerator<Map.Entry<TKey, List<TSource>>> enumerator = + new Linq4j.IterableEnumerator<>(leftIndex.entrySet()); + + boolean emittingNullKeys = false; // True when we emit the records with null keys + @Nullable Enumerator<TSource> left = 
null; // Iterates over values with same key + @Nullable Enumerator<@Nullable TInner> right = null; + final Enumerator<TSource> leftNullKeys = // not used for inner ASOF joins + new Linq4j.IterableEnumerator<>(outerWithNullKeys); + + // This is a small state machine + // if (emittingNullKeys) { + // we are iterating over 'outerWithNullKeys' using 'leftNullKeys' + // } else { + // we are iterating over the 'leftIndex' using 'enumerator' + // for each value of the key we iterate advancing + // concurrently using 'left' and 'right' + // when finished set emittingNullKeys = true + // } + + @Override public TResult current() { + if (emittingNullKeys) { + TSource l = leftNullKeys.current(); + return resultSelector.apply(l, null); + } + + TSource l = Objects.requireNonNull(left, "left").current(); + @Nullable TInner r = Objects.requireNonNull(right, "right").current(); + return resultSelector.apply(l, r); + } + + @Override public boolean moveNext() { + while (true) { + boolean hasNext = false; + if (emittingNullKeys) { + return leftNullKeys.moveNext(); + } else { + if (left != null) { + // Advance left, right + hasNext = left.moveNext(); + boolean rightHasNext = Objects.requireNonNull(right, "right").moveNext(); + assert hasNext == rightHasNext; + } + if (hasNext) { + if (!emitNullsOnRight) { + @Nullable TInner r = Objects.requireNonNull(right, "right").current(); + if (r == null) { + continue; + } + } + return true; + } + // Advance enumerator + hasNext = enumerator.moveNext(); + if (hasNext) { + Map.Entry<TKey, List<TSource>> current = enumerator.current(); + TKey key = current.getKey(); + List<TSource> value = current.getValue(); + left = new Linq4j.IterableEnumerator<>(value); + List<@Nullable TInner> rightList = Objects.requireNonNull(rightIndex.get(key)); + right = new Linq4j.IterableEnumerator<>(rightList); + } else { + // Done with the data, start emitting records with null keys + emittingNullKeys = true; + } + } + } + } + + @Override public void reset() { + 
enumerator.reset(); + left = null; + right = null; + } + + @Override public void close() { + enumerator.close(); + left = null; + right = null; + } + }; + } + }; + } + /** Enumerator that evaluates aggregate functions over an input that is sorted * by the group key. * @@ -4778,5 +4964,4 @@ public abstract class EnumerableDefaults { } }; } - } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java index a375ff250b..b1c271ccea 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java @@ -567,6 +567,28 @@ public interface ExtendedEnumerable<TSource> { Function2<TSource, TInner, TResult> resultSelector, EqualityComparer<TKey> comparer); + /** + * Correlates elements of two sequences based on + * - matching keys + * - a comparator for timestamps. + * + * @param inner Inner sequence + * @param outerKeySelector Function that extracts a key from the outer collection + * @param innerKeySelector Function that extracts a key from the inner collection + * @param resultSelector Function that computes the join result + * @param matchComparator Function that compares an outer row and an inner row for timestamp + * @param timestampComparator Function that compares two inner rows for timestamp + * @param generateNullsOnRight If true, this a left join + */ + <TInner, TKey, TResult> Enumerable<TResult> asofJoin( + Enumerable<TInner> inner, + Function1<TSource, TKey> outerKeySelector, + Function1<TInner, TKey> innerKeySelector, + Function2<TSource, @Nullable TInner, TResult> resultSelector, + Predicate2<TSource, TInner> matchComparator, + Comparator<TInner> timestampComparator, + boolean generateNullsOnRight); + /** * Correlates the elements of two sequences based on matching keys, with * optional outer join semantics. 
A specified diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java b/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java index f694dd78cb..4b38f145ea 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/JoinType.java @@ -66,7 +66,28 @@ public enum JoinType { * WHERE DEPT.DEPTNO = EMP.DEPTNO)</pre> * </blockquote> */ - ANTI; + ANTI, + + /** + * An ASOF JOIN operation combines rows from two tables based on comparable timestamp values. + * For each row in the left table, the join finds at most one row in the right table that has the + * "closest" timestamp value. The matched row on the right side is the closest match, + * whose timestamp may be less than or equal to, or greater than or equal to, the timestamp + * of the left row, as specified by the comparison operator. + * + * <p>Example: + * <blockquote><pre> + * FROM left_table ASOF JOIN right_table + * MATCH_CONDITION ( left_table.timecol &lt;= right_table.timecol ) + * ON left_table.col = right_table.col</pre> + * </blockquote> + */ + ASOF, + + /** + * The left version of an ASOF join, where each row from the left table is part of the output.
+ */ + LEFT_ASOF; /** * Returns whether a join of this type may generate NULL values on the diff --git a/linq4j/src/test/java/org/apache/calcite/linq4j/test/Linq4jTest.java b/linq4j/src/test/java/org/apache/calcite/linq4j/test/Linq4jTest.java index db4dce7f35..57eccf23a0 100644 --- a/linq4j/src/test/java/org/apache/calcite/linq4j/test/Linq4jTest.java +++ b/linq4j/src/test/java/org/apache/calcite/linq4j/test/Linq4jTest.java @@ -48,6 +48,7 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -1408,6 +1409,23 @@ public class Linq4jTest { assertEquals(depts[1], deptList.get(1)); } + @Test void testAsofJoin() { + // Checks the match condition and the LEFT ASOF null padding. The result + // is sorted because output order follows the join's internal key index. + Enumerable<Employee> employees = Linq4j.asEnumerable(emps); + Enumerable<Department> departments = Linq4j.asEnumerable(depts); + Enumerable<String> result = + employees.asofJoin(departments, // inner + e -> e.deptno, // outerKeySelector + d -> d.deptno, // innerKeySelector + (e, d) -> e.name + ":" + (d != null ? d.name : "null"), // resultSelector + (e, d) -> e.name.charAt(1) <= d.name.charAt(1), // matchComparator + Comparator.comparing(d0 -> d0.name), // timestampComparator + true); + // Only Janet ('a' <= 'a' against "Sales") satisfies the match condition; + // the other employees are kept, paired with a null department. + assertEquals("[Bill:null, Eric:null, Fred:null, Janet:Sales]", + result.orderBy(s -> s).toList().toString()); + } + @Test void testTakeWhileNNoMatch() { final Queryable<Department> queryableDepts = Linq4j.asEnumerable(depts).asQueryable();
