This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 11116f6d03bca7cd97d151033c5d82f24394e229 Author: Juhwan Kim <[email protected]> AuthorDate: Wed Aug 14 22:47:07 2019 -0700 [CALCITE-3111] Add RelBuilder.correlate method, and allow custom implementations of Correlate in RelDecorrelator (Juhwan Kim) In RelDecorrelator, refactor all Logical rels into corresponding abstract rels, and use given RelBuilder when creating a new rel. Add abstract rel visitors in CorelMapBuilder, and change access levels to allow extending RelDecorrelator. Close apache/calcite#1334 --- .../java/org/apache/calcite/plan/RelOptUtil.java | 28 ++ .../apache/calcite/sql2rel/RelDecorrelator.java | 340 ++++++++++----------- .../java/org/apache/calcite/tools/RelBuilder.java | 36 +++ .../org/apache/calcite/test/RelBuilderTest.java | 45 +++ .../org/apache/calcite/test/RelOptRulesTest.java | 43 +++ .../org/apache/calcite/test/SqlToRelTestBase.java | 32 ++ site/_docs/algebra.md | 9 +- site/_docs/history.md | 4 +- 8 files changed, 358 insertions(+), 179 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java index 4645e9c..bccb0f5 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java +++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java @@ -1019,6 +1019,18 @@ public abstract class RelOptUtil { List<RexNode> joinKeys, List<RexNode> correlatedJoinKeys, boolean extractCorrelatedFieldAccess) { + return splitCorrelatedFilterCondition( + (Filter) filter, + joinKeys, + correlatedJoinKeys, + extractCorrelatedFieldAccess); + } + + public static RexNode splitCorrelatedFilterCondition( + Filter filter, + List<RexNode> joinKeys, + List<RexNode> correlatedJoinKeys, + boolean extractCorrelatedFieldAccess) { final List<RexNode> nonEquiList = new ArrayList<>(); splitCorrelatedFilterCondition( @@ -1371,6 +1383,22 @@ public abstract class RelOptUtil { List<RexNode> correlatedJoinKeys, List<RexNode> nonEquiList, boolean extractCorrelatedFieldAccess) 
{ + splitCorrelatedFilterCondition( + (Filter) filter, + condition, + joinKeys, + correlatedJoinKeys, + nonEquiList, + extractCorrelatedFieldAccess); + } + + private static void splitCorrelatedFilterCondition( + Filter filter, + RexNode condition, + List<RexNode> joinKeys, + List<RexNode> correlatedJoinKeys, + List<RexNode> nonEquiList, + boolean extractCorrelatedFieldAccess) { if (condition instanceof RexCall) { RexCall call = (RexCall) condition; if (call.getOperator().getKind() == SqlKind.AND) { diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java index 4facf48..01ed669 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java @@ -29,13 +29,14 @@ import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.rel.BiRel; import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.RelShuttleImpl; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Correlate; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.RelFactories; @@ -47,7 +48,6 @@ import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.logical.LogicalSnapshot; -import org.apache.calcite.rel.logical.LogicalSort; import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; import 
org.apache.calcite.rel.rules.FilterCorrelateRule; @@ -113,6 +113,7 @@ import java.util.Objects; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import java.util.stream.Collectors; import javax.annotation.Nonnull; /** @@ -142,7 +143,7 @@ public class RelDecorrelator implements ReflectiveVisitor { private final RelBuilder relBuilder; // map built during translation - private CorelMap cm; + protected CorelMap cm; private final ReflectUtil.MethodDispatcher<Frame> dispatcher = ReflectUtil.createMethodDispatcher(Frame.class, this, "decorrelateRel", @@ -158,11 +159,11 @@ public class RelDecorrelator implements ReflectiveVisitor { * positions. This is from the view point of the parent rel of a new rel. */ private final Map<RelNode, Frame> map = new HashMap<>(); - private final HashSet<LogicalCorrelate> generatedCorRels = new HashSet<>(); + private final HashSet<Correlate> generatedCorRels = new HashSet<>(); //~ Constructors ----------------------------------------------------------- - private RelDecorrelator( + protected RelDecorrelator( CorelMap cm, Context context, RelBuilder relBuilder) { @@ -188,7 +189,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * @param relBuilder Builder for relational expressions * * @return Equivalent query with all - * {@link org.apache.calcite.rel.logical.LogicalCorrelate} instances removed + * {@link org.apache.calcite.rel.core.Correlate} instances removed */ public static RelNode decorrelateQuery(RelNode rootRel, RelBuilder relBuilder) { @@ -217,18 +218,18 @@ public class RelDecorrelator implements ReflectiveVisitor { return newRootRel; } - private void setCurrent(RelNode root, LogicalCorrelate corRel) { + private void setCurrent(RelNode root, Correlate corRel) { currentRel = corRel; if (corRel != null) { cm = new CorelMapBuilder().build(Util.first(root, corRel)); } } - private RelBuilderFactory relBuilderFactory() { + protected RelBuilderFactory relBuilderFactory() { return 
RelBuilder.proto(relBuilder); } - private RelNode decorrelate(RelNode root) { + protected RelNode decorrelate(RelNode root) { // first adjust count() expression if any final RelBuilderFactory f = relBuilderFactory(); HepProgram program = HepProgram.builder() @@ -280,16 +281,16 @@ public class RelDecorrelator implements ReflectiveVisitor { cm.mapRefRelToCorRef.putAll(newNode, cm.mapRefRelToCorRef.get(oldNode)); } - if (oldNode instanceof LogicalCorrelate - && newNode instanceof LogicalCorrelate) { - LogicalCorrelate oldCor = (LogicalCorrelate) oldNode; + if (oldNode instanceof Correlate + && newNode instanceof Correlate) { + Correlate oldCor = (Correlate) oldNode; CorrelationId c = oldCor.getCorrelationId(); if (cm.mapCorToCorRel.get(c) == oldNode) { cm.mapCorToCorRel.put(c, newNode); } if (generatedCorRels.contains(oldNode)) { - generatedCorRels.add((LogicalCorrelate) newNode); + generatedCorRels.add((Correlate) newNode); } } return null; @@ -387,11 +388,6 @@ public class RelDecorrelator implements ReflectiveVisitor { ImmutableSortedMap.of()); } - /** - * Rewrite Sort. - * - * @param rel Sort to be rewritten - */ public Frame decorrelateRel(Sort rel) { // // Rewrite logic: @@ -424,29 +420,28 @@ public class RelDecorrelator implements ReflectiveVisitor { RelCollation oldCollation = rel.getCollation(); RelCollation newCollation = RexUtil.apply(mapping, oldCollation); - final Sort newSort = - LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch); + final int offset = rel.offset == null ? 0 : RexLiteral.intValue(rel.offset); + final int fetch = rel.fetch == null ? -1 : RexLiteral.intValue(rel.fetch); + // sortLimit treats a negative fetch as "no limit"; fetch 0 would return no rows. + final RelNode newSort = relBuilder + .push(newInput) + .sortLimit(offset, fetch, relBuilder.fields(newCollation)) + .build(); // Sort does not change input ordering return register(rel, newSort, frame.oldToNewOutputs, frame.corDefOutputs); } - /** - * Rewrites a {@link Values}.
- * - * @param rel Values to be rewritten - */ public Frame decorrelateRel(Values rel) { // There are no inputs, so rel does not need to be changed. return null; } - /** - * Rewrites a {@link LogicalAggregate}. - * - * @param rel Aggregate to rewrite - */ public Frame decorrelateRel(LogicalAggregate rel) { + return decorrelateRel((Aggregate) rel); + } + + public Frame decorrelateRel(Aggregate rel) { // // Rewrite logic: // @@ -604,8 +599,8 @@ public class RelDecorrelator implements ReflectiveVisitor { newInputOutputFieldCount + i); } - relBuilder.push( - LogicalAggregate.create(newProject, newGroupSet, newGroupSets, newAggCalls)); + relBuilder.push(newProject).aggregate( + relBuilder.groupKey(newGroupSet, newGroupSets), newAggCalls); if (!omittedConstants.isEmpty()) { final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields()); @@ -665,12 +660,11 @@ public class RelDecorrelator implements ReflectiveVisitor { return null; } - /** - * Rewrite LogicalProject. - * - * @param rel the project rel to rewrite - */ public Frame decorrelateRel(LogicalProject rel) { + return decorrelateRel((Project) rel); + } + + public Frame decorrelateRel(Project rel) { // // Rewrite logic: // @@ -787,10 +781,11 @@ public class RelDecorrelator implements ReflectiveVisitor { assert newInput != null; if (!joinedInputs.contains(newInput)) { - RelNode project = - RelOptUtil.createProject(newInput, - mapNewInputToOutputs.get(newInput)); - RelNode distinct = relBuilder.push(project) + final List<Integer> positions = mapNewInputToOutputs.get(newInput); + final List<String> fieldNames = newInput.getRowType().getFieldNames(); + + RelNode distinct = relBuilder.push(newInput) + .project(relBuilder.fields(positions)) .distinct() .build(); RelOptCluster cluster = distinct.getCluster(); @@ -802,10 +797,8 @@ public class RelDecorrelator implements ReflectiveVisitor { if (r == null) { r = distinct; } else { - r = - LogicalJoin.create(r, distinct, - cluster.getRexBuilder().makeLiteral(true), - 
ImmutableSet.of(), JoinRelType.INNER); + r = relBuilder.push(r).push(distinct) + .join(JoinRelType.INNER, cluster.getRexBuilder().makeLiteral(true)).build(); } } } @@ -950,9 +943,9 @@ public class RelDecorrelator implements ReflectiveVisitor { RelNode valueGen = createValueGenerator(corVarList, leftInputOutputCount, corDefOutputs); - RelNode join = - LogicalJoin.create(frame.r, valueGen, relBuilder.literal(true), - ImmutableSet.of(), JoinRelType.INNER); + RelNode join = relBuilder.push(frame.r).push(valueGen) + .join(JoinRelType.INNER, relBuilder.literal(true), + ImmutableSet.of()).build(); // Join or Filter does not change the old input ordering. All // input fields from newLeftInput (i.e. the original input to the old @@ -1017,11 +1010,6 @@ public class RelDecorrelator implements ReflectiveVisitor { && type.getPrecision() >= type1.getPrecision(); } - /** - * Rewrite LogicalSnapshot. - * - * @param rel the snapshot rel to rewrite - */ public Frame decorrelateRel(LogicalSnapshot rel) { if (RexUtil.containsCorrelation(rel.getPeriod())) { return null; @@ -1029,12 +1017,11 @@ public class RelDecorrelator implements ReflectiveVisitor { return decorrelateRel((RelNode) rel); } - /** - * Rewrite LogicalFilter. - * - * @param rel the filter rel to rewrite - */ public Frame decorrelateRel(LogicalFilter rel) { + return decorrelateRel((Filter) rel); + } + + public Frame decorrelateRel(Filter rel) { // // Rewrite logic: // @@ -1083,12 +1070,11 @@ public class RelDecorrelator implements ReflectiveVisitor { frame.corDefOutputs); } - /** - * Rewrite Correlate into a left outer join. 
- * - * @param rel Correlator - */ public Frame decorrelateRel(LogicalCorrelate rel) { + return decorrelateRel((Correlate) rel); + } + + public Frame decorrelateRel(Correlate rel) { // // Rewrite logic: // @@ -1181,26 +1167,23 @@ public class RelDecorrelator implements ReflectiveVisitor { final RexNode condition = RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions); - RelNode newJoin = - LogicalJoin.create(leftFrame.r, rightFrame.r, condition, - ImmutableSet.of(), rel.getJoinType()); + RelNode newJoin = relBuilder.push(leftFrame.r).push(rightFrame.r) + .join(rel.getJoinType(), condition, ImmutableSet.of()).build(); return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs); } - /** - * Rewrite LogicalJoin. - * - * @param rel Join - */ public Frame decorrelateRel(LogicalJoin rel) { + return decorrelateRel((Join) rel); + } + + public Frame decorrelateRel(Join rel) { // For SEMI/ANTI join decorrelate it's input directly, // because the correlate variables can only be propagated from // the left side, which is not supported yet. 
if (!rel.getJoinType().projectsRight()) { return decorrelateRel((RelNode) rel); } - // // Rewrite logic: // @@ -1219,10 +1202,12 @@ public class RelDecorrelator implements ReflectiveVisitor { return null; } - final RelNode newJoin = - LogicalJoin.create(leftFrame.r, rightFrame.r, - decorrelateExpr(currentRel, map, cm, rel.getCondition()), - ImmutableSet.of(), rel.getJoinType()); + final RelNode newJoin = relBuilder + .push(leftFrame.r) + .push(rightFrame.r) + .join(rel.getJoinType(), decorrelateExpr(currentRel, map, cm, rel.getCondition()), + ImmutableSet.of()) + .build(); // Create the mapping between the output of the old correlation rel // and the new join rel @@ -1311,8 +1296,8 @@ public class RelDecorrelator implements ReflectiveVisitor { * @return the subtree with the new Project at the root */ private RelNode projectJoinOutputWithNullability( - LogicalJoin join, - LogicalProject project, + Join join, + Project project, int nullIndicatorPos) { final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory(); final RelNode left = join.getLeft(); @@ -1371,7 +1356,7 @@ public class RelDecorrelator implements ReflectiveVisitor { */ private RelNode aggregateCorrelatorOutput( Correlate correlate, - LogicalProject project, + Project project, Set<Integer> isCount) { final RelNode left = correlate.getLeft(); final JoinRelType joinType = correlate.getJoinType(); @@ -1419,9 +1404,9 @@ public class RelDecorrelator implements ReflectiveVisitor { * @return true if filter and proj only references corVar provided by corRel */ private boolean checkCorVars( - LogicalCorrelate correlate, - LogicalProject project, - LogicalFilter filter, + Correlate correlate, + Project project, + Filter filter, List<RexFieldAccess> correlatedJoinKeys) { if (filter != null) { assert correlatedJoinKeys != null; @@ -1469,7 +1454,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param correlate Correlate */ - private void removeCorVarFromTree(LogicalCorrelate correlate) 
{ + private void removeCorVarFromTree(Correlate correlate) { if (cm.mapCorToCorRel.get(correlate.getCorrelationId()) == correlate) { cm.mapCorToCorRel.remove(correlate.getCorrelationId()); } @@ -1680,14 +1665,14 @@ public class RelDecorrelator implements ReflectiveVisitor { } @Override public RexNode visitInputRef(RexInputRef inputRef) { - if (currentRel instanceof LogicalCorrelate) { + if (currentRel instanceof Correlate) { // if this rel references corVar // and now it needs to be rewritten // it must have been pulled above the Correlate // replace the input ref to account for the LHS of the // Correlate final int leftInputFieldCount = - ((LogicalCorrelate) currentRel).getLeft().getRowType() + ((Correlate) currentRel).getLeft().getRowType() .getFieldCount(); RelDataType newType = inputRef.getType(); @@ -1787,17 +1772,17 @@ public class RelDecorrelator implements ReflectiveVisitor { RemoveSingleAggregateRule(RelBuilderFactory relBuilderFactory) { super( operand( - LogicalAggregate.class, + Aggregate.class, operand( - LogicalProject.class, - operand(LogicalAggregate.class, any()))), + Project.class, + operand(Aggregate.class, any()))), relBuilderFactory, null); } public void onMatch(RelOptRuleCall call) { - LogicalAggregate singleAggregate = call.rel(0); - LogicalProject project = call.rel(1); - LogicalAggregate aggregate = call.rel(2); + Aggregate singleAggregate = call.rel(0); + Project project = call.rel(1); + Aggregate aggregate = call.rel(2); // check singleAggRel is single_value agg if ((!singleAggregate.getGroupSet().isEmpty()) @@ -1838,19 +1823,19 @@ public class RelDecorrelator implements ReflectiveVisitor { private final class RemoveCorrelationForScalarProjectRule extends RelOptRule { RemoveCorrelationForScalarProjectRule(RelBuilderFactory relBuilderFactory) { super( - operand(LogicalCorrelate.class, + operand(Correlate.class, operand(RelNode.class, any()), - operand(LogicalAggregate.class, - operand(LogicalProject.class, + operand(Aggregate.class, + 
operand(Project.class, operand(RelNode.class, any())))), relBuilderFactory, null); } public void onMatch(RelOptRuleCall call) { - final LogicalCorrelate correlate = call.rel(0); + final Correlate correlate = call.rel(0); final RelNode left = call.rel(1); - final LogicalAggregate aggregate = call.rel(2); - final LogicalProject project = call.rel(3); + final Aggregate aggregate = call.rel(2); + final Project project = call.rel(3); RelNode right = call.rel(4); final RelOptCluster cluster = correlate.getCluster(); @@ -1892,7 +1877,7 @@ public class RelDecorrelator implements ReflectiveVisitor { int nullIndicatorPos; - if ((right instanceof LogicalFilter) + if ((right instanceof Filter) && cm.mapRefRelToCorRef.containsKey(right)) { // rightInput has this shape: // @@ -1903,7 +1888,7 @@ public class RelDecorrelator implements ReflectiveVisitor { // reference, make sure the correlated keys in the filter // condition forms a unique key of the RHS. - LogicalFilter filter = (LogicalFilter) right; + Filter filter = (Filter) right; right = filter.getInput(); assert right instanceof HepRelVertex; @@ -2018,9 +2003,9 @@ public class RelDecorrelator implements ReflectiveVisitor { } // make the new join rel - LogicalJoin join = - LogicalJoin.create(left, right, joinCond, - ImmutableSet.of(), joinType); + Join join = + (Join) relBuilder.push(left).push(right) + .join(joinType, joinCond).build(); RelNode newProject = projectJoinOutputWithNullability(join, project, nullIndicatorPos); @@ -2036,21 +2021,21 @@ public class RelDecorrelator implements ReflectiveVisitor { extends RelOptRule { RemoveCorrelationForScalarAggregateRule(RelBuilderFactory relBuilderFactory) { super( - operand(LogicalCorrelate.class, + operand(Correlate.class, operand(RelNode.class, any()), - operand(LogicalProject.class, - operandJ(LogicalAggregate.class, null, Aggregate::isSimple, - operand(LogicalProject.class, + operand(Project.class, + operandJ(Aggregate.class, null, Aggregate::isSimple, + 
operand(Project.class, operand(RelNode.class, any()))))), relBuilderFactory, null); } public void onMatch(RelOptRuleCall call) { - final LogicalCorrelate correlate = call.rel(0); + final Correlate correlate = call.rel(0); final RelNode left = call.rel(1); - final LogicalProject aggOutputProject = call.rel(2); - final LogicalAggregate aggregate = call.rel(3); - final LogicalProject aggInputProject = call.rel(4); + final Project aggOutputProject = call.rel(2); + final Aggregate aggregate = call.rel(3); + final Project aggInputProject = call.rel(4); RelNode right = call.rel(5); final RelBuilder builder = call.builder(); final RexBuilder rexBuilder = builder.getRexBuilder(); @@ -2105,13 +2090,13 @@ public class RelDecorrelator implements ReflectiveVisitor { } } - if ((right instanceof LogicalFilter) + if ((right instanceof Filter) && cm.mapRefRelToCorRef.containsKey(right)) { // rightInput has this shape: // // Filter (references corVar) // filterInput - LogicalFilter filter = (LogicalFilter) right; + Filter filter = (Filter) right; right = filter.getInput(); assert right instanceof HepRelVertex; @@ -2299,9 +2284,9 @@ public class RelDecorrelator implements ReflectiveVisitor { Pair.of(rexBuilder.makeLiteral(true), "nullIndicator"))); - LogicalJoin join = - LogicalJoin.create(left, right, joinCond, - ImmutableSet.of(), joinType); + Join join = + (Join) relBuilder.push(left).push(right) + .join(joinType, joinCond, ImmutableSet.of()).build(); // To the consumer of joinOutputProjRel, nullIndicator is located // at the end @@ -2372,13 +2357,11 @@ public class RelDecorrelator implements ReflectiveVisitor { ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount); - LogicalAggregate newAggregate = - LogicalAggregate.create(joinOutputProject, groupSet, null, - newAggCalls); + builder.push(joinOutputProject).aggregate(builder.groupKey(groupSet, null), newAggCalls); List<RexNode> newAggOutputProjectList = new ArrayList<>(); for (int i : groupSet) { 
newAggOutputProjectList.add( - rexBuilder.makeInputRef(newAggregate, i)); + rexBuilder.makeInputRef(builder.peek(), i)); } RexNode newAggOutputProjects = @@ -2390,8 +2373,7 @@ public class RelDecorrelator implements ReflectiveVisitor { true), newAggOutputProjects)); - builder.push(newAggregate) - .project(newAggOutputProjectList); + builder.project(newAggOutputProjectList); call.transformTo(builder.build()); removeCorVarFromTree(correlate); @@ -2414,22 +2396,22 @@ public class RelDecorrelator implements ReflectiveVisitor { RelBuilderFactory relBuilderFactory) { super( flavor - ? operand(LogicalCorrelate.class, + ? operand(Correlate.class, operand(RelNode.class, any()), - operand(LogicalProject.class, - operand(LogicalAggregate.class, any()))) - : operand(LogicalCorrelate.class, + operand(Project.class, + operand(Aggregate.class, any()))) + : operand(Correlate.class, operand(RelNode.class, any()), - operand(LogicalAggregate.class, any())), + operand(Aggregate.class, any())), relBuilderFactory, null); this.flavor = flavor; } public void onMatch(RelOptRuleCall call) { - final LogicalCorrelate correlate = call.rel(0); + final Correlate correlate = call.rel(0); final RelNode left = call.rel(1); - final LogicalProject aggOutputProject; - final LogicalAggregate aggregate; + final Project aggOutputProject; + final Aggregate aggregate; if (flavor) { aggOutputProject = call.rel(2); aggregate = call.rel(3); @@ -2446,17 +2428,17 @@ public class RelDecorrelator implements ReflectiveVisitor { final RelBuilder relBuilder = call.builder(); relBuilder.push(aggregate) .projectNamed(Pair.left(projects), Pair.right(projects), true); - aggOutputProject = (LogicalProject) relBuilder.build(); + aggOutputProject = (Project) relBuilder.build(); } onMatch2(call, correlate, left, aggOutputProject, aggregate); } private void onMatch2( RelOptRuleCall call, - LogicalCorrelate correlate, + Correlate correlate, RelNode leftInput, - LogicalProject aggOutputProject, - LogicalAggregate aggregate) { 
+ Project aggOutputProject, + Aggregate aggregate) { if (generatedCorRels.contains(correlate)) { // This Correlate was generated by a previous invocation of // this rule. No further work to do. @@ -2514,10 +2496,15 @@ public class RelDecorrelator implements ReflectiveVisitor { // leftInput // Aggregate(groupby (0), agg0(), agg1()...) // - LogicalCorrelate newCorrelate = - LogicalCorrelate.create(leftInput, aggregate, - correlate.getCorrelationId(), correlate.getRequiredColumns(), - correlate.getJoinType()); + List<RexNode> requiredNodes = + correlate.getRequiredColumns().asList().stream() + .map(ord -> relBuilder.getRexBuilder().makeInputRef(correlate, ord)) + .collect(Collectors.toList()); + Correlate newCorrelate = (Correlate) relBuilder.push(leftInput) + .push(aggregate).correlate(correlate.getJoinType(), + correlate.getCorrelationId(), + requiredNodes).build(); + // remember this rel so we don't fire rule on it again // REVIEW jhyde 29-Oct-2007: rules should not save state; rule @@ -2625,7 +2612,7 @@ public class RelDecorrelator implements ReflectiveVisitor { } /** A map of the locations of - * {@link org.apache.calcite.rel.logical.LogicalCorrelate} + * {@link org.apache.calcite.rel.core.Correlate} * in a tree of {@link RelNode}s. * * <p>It is used to drive the decorrelation process. @@ -2645,7 +2632,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * updated. * * </ol> */ - private static class CorelMap { + protected static class CorelMap { private final Multimap<RelNode, CorRef> mapRefRelToCorRef; private final SortedMap<CorrelationId, RelNode> mapCorToCorRel; private final Map<RexFieldAccess, CorRef> mapFieldAccessToCorRef; @@ -2689,6 +2676,10 @@ public class RelDecorrelator implements ReflectiveVisitor { mapFieldAccessToCorVar); } + public SortedMap<CorrelationId, RelNode> getMapCorToCorRel() { + return mapCorToCorRel; + } + /** * Returns whether there are any correlating variables in this statement. 
* @@ -2700,7 +2691,7 @@ public class RelDecorrelator implements ReflectiveVisitor { } /** Builds a {@link org.apache.calcite.sql2rel.RelDecorrelator.CorelMap}. */ - private static class CorelMapBuilder extends RelShuttleImpl { + public static class CorelMapBuilder extends RelHomogeneousShuttle { final SortedMap<CorrelationId, RelNode> mapCorToCorRel = new TreeMap<>(); @@ -2715,7 +2706,7 @@ public class RelDecorrelator implements ReflectiveVisitor { int corrIdGenerator = 0; /** Creates a CorelMap by iterating over a {@link RelNode} tree. */ - CorelMap build(RelNode... rels) { + public CorelMap build(RelNode... rels) { for (RelNode rel : rels) { stripHep(rel).accept(this); } @@ -2723,14 +2714,40 @@ public class RelDecorrelator implements ReflectiveVisitor { mapFieldAccessToCorVar); } - @Override public RelNode visit(LogicalJoin join) { - try { - stack.push(join); - join.getCondition().accept(rexVisitor(join)); - } finally { - stack.pop(); + @Override public RelNode visit(RelNode other) { + if (other instanceof Join) { + Join join = (Join) other; + try { + stack.push(join); + join.getCondition().accept(rexVisitor(join)); + } finally { + stack.pop(); + } + return visitJoin(join); + } else if (other instanceof Correlate) { + Correlate correlate = (Correlate) other; + mapCorToCorRel.put(correlate.getCorrelationId(), correlate); + return visitJoin(correlate); + } else if (other instanceof Filter) { + Filter filter = (Filter) other; + try { + stack.push(filter); + filter.getCondition().accept(rexVisitor(filter)); + } finally { + stack.pop(); + } + } else if (other instanceof Project) { + Project project = (Project) other; + try { + stack.push(project); + for (RexNode node : project.getProjects()) { + node.accept(rexVisitor(project)); + } + } finally { + stack.pop(); + } } - return visitJoin(join); + return super.visit(other); } @Override protected RelNode visitChild(RelNode parent, int i, @@ -2738,11 +2755,6 @@ public class RelDecorrelator implements ReflectiveVisitor { 
return super.visitChild(parent, i, stripHep(input)); } - @Override public RelNode visit(LogicalCorrelate correlate) { - mapCorToCorRel.put(correlate.getCorrelationId(), correlate); - return visitJoin(correlate); - } - private RelNode visitJoin(BiRel join) { final int x = offset.get(); visitChild(join, 0, join.getLeft()); @@ -2752,28 +2764,6 @@ public class RelDecorrelator implements ReflectiveVisitor { return join; } - @Override public RelNode visit(final LogicalFilter filter) { - try { - stack.push(filter); - filter.getCondition().accept(rexVisitor(filter)); - } finally { - stack.pop(); - } - return super.visit(filter); - } - - @Override public RelNode visit(LogicalProject project) { - try { - stack.push(project); - for (RexNode node : project.getProjects()) { - node.accept(rexVisitor(project)); - } - } finally { - stack.pop(); - } - return super.visit(project); - } - private RexVisitorImpl<Void> rexVisitor(final RelNode rel) { return new RexVisitorImpl<Void>(true) { @Override public Void visitFieldAccess(RexFieldAccess fieldAccess) { diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index aa19724..c0fda03 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -34,6 +34,7 @@ import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Correlate; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Intersect; @@ -1988,6 +1989,41 @@ public class RelBuilder { return this; } + /** Creates a {@link Correlate} + * with a {@link CorrelationId} and an array of fields that are used by correlation. 
*/ + public RelBuilder correlate(JoinRelType joinType, + CorrelationId correlationId, RexNode... requiredFields) { + return correlate(joinType, correlationId, ImmutableList.copyOf(requiredFields)); + } + + /** Creates a {@link Correlate} + * with a {@link CorrelationId} and a list of fields that are used by correlation. */ + public RelBuilder correlate(JoinRelType joinType, + CorrelationId correlationId, Iterable<? extends RexNode> requiredFields) { + Frame right = stack.pop(); + + final Registrar registrar = + new Registrar(fields(), peek().getRowType().getFieldNames()); + + List<Integer> requiredOrdinals = + registrar.registerExpressions(ImmutableList.copyOf(requiredFields)); + + project(registrar.extraNodes); + rename(registrar.names); + Frame left = stack.pop(); + + final RelNode correlate = correlateFactory + .createCorrelate(left.rel, right.rel, correlationId, + ImmutableBitSet.of(requiredOrdinals), joinType); + + final ImmutableList.Builder<Field> fields = ImmutableList.builder(); + fields.addAll(left.fields); + fields.addAll(right.fields); + stack.push(new Frame(correlate, fields.build())); + + return this; + } + /** Creates a {@link Join} using USING syntax. 
* * <p>For each of the field names, both left and right inputs must have a diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index 9e9095f..bcf5265 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -2886,6 +2886,51 @@ public class RelBuilderTest { + " LogicalTableScan(table=[[scott, EMP]])\n"; assertThat(root, hasTree(expected)); } + + @Test public void testCorrelate() { + final RelBuilder builder = RelBuilder.create(config().build()); + final Holder<RexCorrelVariable> v = Holder.of(null); + RelNode root = builder.scan("EMP") + .variable(v) + .scan("DEPT") + .filter( + builder.equals(builder.field(0), + builder.field(v.get(), "DEPTNO"))) + .correlate(JoinRelType.LEFT, v.get().id, builder.field(2, 0, "DEPTNO")) + .build(); + + final String expected = "" + + "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($0, $cor0.DEPTNO)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(root, hasTree(expected)); + } + + @Test public void testCorrelateWithComplexFields() { + final RelBuilder builder = RelBuilder.create(config().build()); + final Holder<RexCorrelVariable> v = Holder.of(null); + RelNode root = builder.scan("EMP") + .variable(v) + .scan("DEPT") + .filter( + builder.equals(builder.field(0), + builder.field(v.get(), "DEPTNO"))) + .correlate(JoinRelType.LEFT, v.get().id, + builder.field(2, 0, "DEPTNO"), + builder.getRexBuilder().makeCall(SqlStdOperatorTable.AS, + builder.field(2, 0, "EMPNO"), + builder.literal("RENAMED_EMPNO"))) + .build(); + + final String expected = "" + + "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0, 7}])\n" + + " LogicalProject(RENAMED_EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], 
COMM=[$6], DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($0, $cor0.DEPTNO)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(root, hasTree(expected)); + } } // End RelBuilderTest.java diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index eb0effd..a003f7f 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -147,6 +147,7 @@ import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlMonotonicity; import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql2rel.RelDecorrelator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.test.catalog.MockCatalogReader; import org.apache.calcite.tools.Program; @@ -3738,6 +3739,48 @@ public class RelOptRulesTest extends RelOptTestBase { .checkUnchanged(); } + /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-3111">[CALCITE-3111] + * Allow custom implementations of Correlate in RelDecorrelator </a> + */ + @Test public void testCustomDecorrelate() { + final String sql = "SELECT e1.empno\n" + + "FROM emp e1, dept d1 where e1.deptno = d1.deptno\n" + + "and e1.deptno < 10 and d1.deptno < 15\n" + + "and e1.sal > (select avg(sal) from emp e2 where e1.empno = e2.empno)"; + + // Convert sql to rel + RelRoot root = tester.convertSqlToRel(sql); + + // Create a duplicate rel tree with a custom correlate instead of logical correlate + LogicalCorrelate logicalCorrelate = (LogicalCorrelate) root.rel.getInput(0).getInput(0); + CustomCorrelate customCorrelate = new CustomCorrelate( + logicalCorrelate.getCluster(), + logicalCorrelate.getTraitSet(), + logicalCorrelate.getLeft(), + logicalCorrelate.getRight(), + 
logicalCorrelate.getCorrelationId(), + logicalCorrelate.getRequiredColumns(), + logicalCorrelate.getJoinType()); + RelNode newRoot = root.rel.copy( + root.rel.getTraitSet(), + ImmutableList.of( + root.rel.getInput(0).copy( + root.rel.getInput(0).getTraitSet(), + ImmutableList.<RelNode>of(customCorrelate)))); + + // Decorrelate both trees using the same relBuilder + final RelBuilder relBuilder = RelBuilder.create(RelBuilderTest.config().build()); + RelNode logicalDecorrelated = RelDecorrelator.decorrelateQuery(root.rel, relBuilder); + RelNode customDecorrelated = RelDecorrelator.decorrelateQuery(newRoot, relBuilder); + String logicalDecorrelatedPlan = NL + RelOptUtil.toString(logicalDecorrelated); + String customDecorrelatedPlan = NL + RelOptUtil.toString(customDecorrelated); + + // Ensure that the plans are equal + getDiffRepos().assertEquals("Comparing Plans from LogicalCorrelate and CustomCorrelate", + logicalDecorrelatedPlan, customDecorrelatedPlan); + } + @Test public void testProjectWindowTransposeRule() { HepProgram program = new HepProgramBuilder() .addRuleInstance(ProjectToWindowRule.PROJECT) diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java index 00c3a20..111f896 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java @@ -26,6 +26,7 @@ import org.apache.calcite.plan.RelOptSchema; import org.apache.calcite.plan.RelOptSchemaWithSampling; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; @@ -35,6 +36,10 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelReferentialConstraint; import 
org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.logical.LogicalTableScan; import org.apache.calcite.rel.type.RelDataType; @@ -883,6 +888,33 @@ public abstract class SqlToRelTestBase { } } } + + /** + * Custom implementation of Correlate for testing. + */ + public static class CustomCorrelate extends Correlate { + public CustomCorrelate( + RelOptCluster cluster, + RelTraitSet traits, + RelNode left, + RelNode right, + CorrelationId correlationId, + ImmutableBitSet requiredColumns, + JoinRelType joinType) { + super(cluster, traits, left, right, correlationId, requiredColumns, joinType); + } + + @Override public Correlate copy(RelTraitSet traitSet, + RelNode left, RelNode right, CorrelationId correlationId, + ImmutableBitSet requiredColumns, JoinRelType joinType) { + return new CustomCorrelate(getCluster(), traitSet, left, right, + correlationId, requiredColumns, joinType); + } + + @Override public RelNode accept(RelShuttle shuttle) { + return shuttle.visit(this); + } + } } // End SqlToRelTestBase.java diff --git a/site/_docs/algebra.md b/site/_docs/algebra.md index f8294d7..70301a1 100644 --- a/site/_docs/algebra.md +++ b/site/_docs/algebra.md @@ -321,6 +321,7 @@ return the `RelBuilder`. | `limit(offset, fetch)` | Creates a [Sort]({{ site.apiRoot }}/org/apache/calcite/rel/core/Sort.html) that does not sort, only applies with offset and limit. | `exchange(distribution)` | Creates an [Exchange]({{ site.apiRoot }}/org/apache/calcite/rel/core/Exchange.html). | `sortExchange(distribution, collation)` | Creates a [SortExchange]({{ site.apiRoot }}/org/apache/calcite/rel/core/SortExchange.html). 
+| `correlate(joinType, correlationId, requiredField...)`<br/>`correlate(joinType, correlationId, requiredFieldList)` | Creates a [Correlate]({{ site.apiRoot }}/org/apache/calcite/rel/core/Correlate.html) of the two most recent relational expressions, with a variable name and required field expressions for the left relation. | `join(joinType, expr...)`<br/>`join(joinType, exprList)`<br/>`join(joinType, fieldName...)` | Creates a [Join]({{ site.apiRoot }}/org/apache/calcite/rel/core/Join.html) of the two most recent relational expressions.<br/><br/>The first form joins on a boolean expression (multiple conditions are combined using AND).<br/><br/>The last form joins on named fields; each side must have a field of each name. | `semiJoin(expr)` | Creates a [Join]({{ site.apiRoot }}/org/apache/calcite/rel/core/Join.html) with SEMI join type of the two most recent relational expressions. | `antiJoin(expr)` | Creates a [Join]({{ site.apiRoot }}/org/apache/calcite/rel/core/Join.html) with ANTI join type of the two most recent relational expressions. @@ -334,8 +335,10 @@ return the `RelBuilder`. 
Argument types: * `expr`, `interval` [RexNode]({{ site.apiRoot }}/org/apache/calcite/rex/RexNode.html) -* `expr...` Array of [RexNode]({{ site.apiRoot }}/org/apache/calcite/rex/RexNode.html) -* `exprList`, `measureList`, `partitionKeys`, `orderKeys` Iterable of +* `expr...`, `requiredField...` Array of + [RexNode]({{ site.apiRoot }}/org/apache/calcite/rex/RexNode.html) +* `exprList`, `measureList`, `partitionKeys`, `orderKeys`, + `requiredFieldList` Iterable of [RexNode]({{ site.apiRoot }}/org/apache/calcite/rex/RexNode.html) * `fieldOrdinal` Ordinal of a field within its row (starting from 0) * `fieldName` Name of a field, unique within its row @@ -350,6 +353,7 @@ Argument types: * `tupleList` Iterable of List of [RexLiteral]({{ site.apiRoot }}/org/apache/calcite/rex/RexLiteral.html) * `all`, `distinct`, `strictStart`, `strictEnd`, `allRows` boolean * `alias` String +* `correlationId` [CorrelationId]({{ site.apiRoot }}/org/apache/calcite/rel/core/CorrelationId.html) * `variablesSet` Iterable of [CorrelationId]({{ site.apiRoot }}/org/apache/calcite/rel/core/CorrelationId.html) * `varHolder` [Holder]({{ site.apiRoot }}/org/apache/calcite/util/Holder.html) of [RexCorrelVariable]({{ site.apiRoot }}/org/apache/calcite/rex/RexCorrelVariable.html) @@ -358,6 +362,7 @@ Argument types: * `distribution` [RelDistribution]({{ site.apiRoot }}/org/apache/calcite/rel/RelDistribution.html) * `collation` [RelCollation]({{ site.apiRoot }}/org/apache/calcite/rel/RelCollation.html) * `operator` [SqlOperator]({{ site.apiRoot }}/org/apache/calcite/sql/SqlOperator.html) +* `joinType` [JoinRelType]({{ site.apiRoot }}/org/apache/calcite/rel/core/JoinRelType.html) The builder methods perform various optimizations, including: diff --git a/site/_docs/history.md b/site/_docs/history.md index 09797b4..db5bc6f 100644 --- a/site/_docs/history.md +++ b/site/_docs/history.md @@ -32,8 +32,8 @@ Downloads are available on the #### Breaking Changes -* core parser config.fmpp#dataTypeParserMethods 
should return "SqlTypeNameSpec" -instead of "SqlIdentifier". +* core parser config.fmpp#dataTypeParserMethods should return `SqlTypeNameSpec` +instead of `SqlIdentifier`. ## <a href="https://github.com/apache/calcite/releases/tag/calcite-1.20.0">1.20.0</a> / 2019-06-24 {: #v1-20-0}
