http://git-wip-us.apache.org/repos/asf/calcite/blob/505a9064/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java index 2f1d6b9..2812851 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java @@ -31,12 +31,16 @@ import org.apache.calcite.rel.BiRel; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelShuttleImpl; -import org.apache.calcite.rel.RelVisitor; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Correlate; import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.Values; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalCorrelate; import org.apache.calcite.rel.logical.LogicalFilter; @@ -58,6 +62,7 @@ import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.sql.SqlExplainLevel; @@ -67,20 +72,25 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlCountAggFunction; import org.apache.calcite.sql.fun.SqlSingleValueAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import 
org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.Bug; import org.apache.calcite.util.Holder; import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Litmus; import org.apache.calcite.util.Pair; import org.apache.calcite.util.ReflectUtil; -import org.apache.calcite.util.ReflectiveVisitDispatcher; import org.apache.calcite.util.ReflectiveVisitor; +import org.apache.calcite.util.Stacks; import org.apache.calcite.util.Util; import org.apache.calcite.util.mapping.Mappings; import org.apache.calcite.util.trace.CalciteTrace; +import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Multimap; @@ -96,6 +106,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NavigableMap; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; @@ -127,10 +138,14 @@ public class RelDecorrelator implements ReflectiveVisitor { //~ Instance fields -------------------------------------------------------- + private final RelBuilder relBuilder; + // map built during translation private CorelMap cm; - private final DecorrelateRelVisitor decorrelateVisitor; + private final ReflectUtil.MethodDispatcher<Frame> dispatcher = + ReflectUtil.createMethodDispatcher(Frame.class, this, "decorrelateRel", + RelNode.class); private final RexBuilder rexBuilder; @@ -139,31 +154,24 @@ public class RelDecorrelator implements ReflectiveVisitor { private final Context context; - // maps built during decorrelation - private final Map<RelNode, RelNode> mapOldToNewRel = Maps.newHashMap(); - - // map rel to all the newly created correlated variables in its output 
- private final Map<RelNode, SortedMap<Correlation, Integer>> - mapNewRelToMapCorVarToOutputPos = Maps.newHashMap(); - - // another map to map old input positions to new input positions - // this is from the view point of the parent rel of a new rel. - private final Map<RelNode, Map<Integer, Integer>> - mapNewRelToMapOldToNewOutputPos = Maps.newHashMap(); + /** Built during decorrelation, of rel to all the newly created correlated + * variables in its output, and to map old input positions to new input + * positions. This is from the view point of the parent rel of a new rel. */ + private final Map<RelNode, Frame> map = new HashMap<>(); private final HashSet<LogicalCorrelate> generatedCorRels = Sets.newHashSet(); //~ Constructors ----------------------------------------------------------- private RelDecorrelator( - RexBuilder rexBuilder, + RelOptCluster cluster, CorelMap cm, Context context) { this.cm = cm; - this.rexBuilder = rexBuilder; + this.rexBuilder = cluster.getRexBuilder(); this.context = context; + relBuilder = RelFactories.LOGICAL_BUILDER.create(cluster, null); - decorrelateVisitor = new DecorrelateRelVisitor(); } //~ Methods ---------------------------------------------------------------- @@ -178,18 +186,16 @@ public class RelDecorrelator implements ReflectiveVisitor { * {@link org.apache.calcite.rel.logical.LogicalCorrelate} instances removed */ public static RelNode decorrelateQuery(RelNode rootRel) { - final CorelMap corelMap = CorelMap.build(rootRel); + final CorelMap corelMap = new CorelMapBuilder().build(rootRel); if (!corelMap.hasCorrelation()) { return rootRel; } final RelOptCluster cluster = rootRel.getCluster(); - final RexBuilder rexBuilder = cluster.getRexBuilder(); final RelDecorrelator decorrelator = - new RelDecorrelator(rexBuilder, corelMap, + new RelDecorrelator(cluster, corelMap, cluster.getPlanner().getContext()); - RelNode newRootRel = decorrelator.removeCorrelationViaRule(rootRel); if (SQL2REL_LOGGER.isLoggable(Level.FINE)) { @@ 
-211,7 +217,7 @@ public class RelDecorrelator implements ReflectiveVisitor { private void setCurrent(RelNode root, LogicalCorrelate corRel) { currentRel = corRel; if (corRel != null) { - cm = CorelMap.build(Util.first(root, corRel)); + cm = new CorelMapBuilder().build(Util.first(root, corRel)); } } @@ -231,13 +237,10 @@ public class RelDecorrelator implements ReflectiveVisitor { root = planner.findBestExp(); // Perform decorrelation. - mapOldToNewRel.clear(); - mapNewRelToMapCorVarToOutputPos.clear(); - mapNewRelToMapOldToNewOutputPos.clear(); - - decorrelateVisitor.visit(root, 0, null); + map.clear(); - if (mapOldToNewRel.containsKey(root)) { + final Frame frame = getInvoke(root, null); + if (frame != null) { // has been rewritten; apply rules post-decorrelation final HepProgram program2 = HepProgram.builder() .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN) @@ -245,7 +248,7 @@ public class RelDecorrelator implements ReflectiveVisitor { .build(); final HepPlanner planner2 = createPlanner(program2); - final RelNode newRoot = mapOldToNewRel.get(root); + final RelNode newRoot = frame.r; planner2.setRoot(newRoot); return planner2.findBestExp(); } @@ -265,7 +268,7 @@ public class RelDecorrelator implements ReflectiveVisitor { LogicalCorrelate oldCor = (LogicalCorrelate) oldNode; CorrelationId c = oldCor.getCorrelationId(); if (cm.mapCorVarToCorRel.get(c) == oldNode) { - cm.mapCorVarToCorRel.put(c, (LogicalCorrelate) newNode); + cm.mapCorVarToCorRel.put(c, newNode); } if (generatedCorRels.contains(oldNode)) { @@ -298,9 +301,7 @@ public class RelDecorrelator implements ReflectiveVisitor { HepPlanner planner = createPlanner(program); planner.setRoot(root); - RelNode newRootRel = planner.findBestExp(); - - return newRootRel; + return planner.findBestExp(); } protected RexNode decorrelateExpr(RexNode exp) { @@ -312,9 +313,8 @@ public class RelDecorrelator implements ReflectiveVisitor { RexNode exp, boolean projectPulledAboveLeftCorrelator) { RemoveCorrelationRexShuttle 
shuttle = - new RemoveCorrelationRexShuttle( - rexBuilder, - projectPulledAboveLeftCorrelator); + new RemoveCorrelationRexShuttle(rexBuilder, + projectPulledAboveLeftCorrelator, null, ImmutableSet.<Integer>of()); return exp.accept(shuttle); } @@ -323,10 +323,9 @@ public class RelDecorrelator implements ReflectiveVisitor { boolean projectPulledAboveLeftCorrelator, RexInputRef nullIndicator) { RemoveCorrelationRexShuttle shuttle = - new RemoveCorrelationRexShuttle( - rexBuilder, - projectPulledAboveLeftCorrelator, - nullIndicator); + new RemoveCorrelationRexShuttle(rexBuilder, + projectPulledAboveLeftCorrelator, nullIndicator, + ImmutableSet.<Integer>of()); return exp.accept(shuttle); } @@ -335,30 +334,27 @@ public class RelDecorrelator implements ReflectiveVisitor { boolean projectPulledAboveLeftCorrelator, Set<Integer> isCount) { RemoveCorrelationRexShuttle shuttle = - new RemoveCorrelationRexShuttle( - rexBuilder, - projectPulledAboveLeftCorrelator, - isCount); + new RemoveCorrelationRexShuttle(rexBuilder, + projectPulledAboveLeftCorrelator, null, isCount); return exp.accept(shuttle); } - public void decorrelateRelGeneric(RelNode rel) { + /** Fallback if none of the other {@code decorrelateRel} methods match. 
*/ + public Frame decorrelateRel(RelNode rel) { RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs()); if (rel.getInputs().size() > 0) { List<RelNode> oldInputs = rel.getInputs(); List<RelNode> newInputs = Lists.newArrayList(); for (int i = 0; i < oldInputs.size(); ++i) { - RelNode newInputRel = mapOldToNewRel.get(oldInputs.get(i)); - if ((newInputRel == null) - || mapNewRelToMapCorVarToOutputPos.containsKey(newInputRel)) { - // if child is not rewritten, or if it produces correlated + final Frame frame = getInvoke(oldInputs.get(i), rel); + if (frame == null || !frame.corVarOutputPos.isEmpty()) { + // if input is not rewritten, or if it produces correlated // variables, terminate rewrite - return; - } else { - newInputs.add(newInputRel); - newRel.replaceInput(i, newInputRel); + return null; } + newInputs.add(frame.r); + newRel.replaceInput(i, frame.r); } if (!Util.equalShallow(oldInputs, newInputs)) { @@ -368,12 +364,8 @@ public class RelDecorrelator implements ReflectiveVisitor { // the output position should not change since there are no corVars // coming from below. - Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap(); - for (int i = 0; i < rel.getRowType().getFieldCount(); i++) { - mapOldToNewOutputPos.put(i, i); - } - mapOldToNewRel.put(rel, newRel); - mapNewRelToMapOldToNewOutputPos.put(newRel, mapOldToNewOutputPos); + return register(rel, newRel, identityMap(rel.getRowType().getFieldCount()), + ImmutableSortedMap.<Correlation, Integer>of()); } /** @@ -381,7 +373,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param rel Sort to be rewritten */ - public void decorrelateRel(Sort rel) { + public Frame decorrelateRel(Sort rel) { // // Rewrite logic: // @@ -397,33 +389,39 @@ public class RelDecorrelator implements ReflectiveVisitor { // Its output does not change the input ordering, so there's no // need to call propagateExpr. 
- RelNode oldChildRel = rel.getInput(); - - RelNode newChildRel = mapOldToNewRel.get(oldChildRel); - if (newChildRel == null) { - // If child has not been rewritten, do not rewrite this rel. - return; + final RelNode oldInput = rel.getInput(); + final Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; } + final RelNode newInput = frame.r; - Map<Integer, Integer> childMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newChildRel); - assert childMapOldToNewOutputPos != null; Mappings.TargetMapping mapping = Mappings.target( - childMapOldToNewOutputPos, - oldChildRel.getRowType().getFieldCount(), - newChildRel.getRowType().getFieldCount()); + frame.oldToNewOutputPos, + oldInput.getRowType().getFieldCount(), + newInput.getRowType().getFieldCount()); RelCollation oldCollation = rel.getCollation(); RelCollation newCollation = RexUtil.apply(mapping, oldCollation); - final Sort newRel = - LogicalSort.create(newChildRel, newCollation, rel.offset, rel.fetch); - - mapOldToNewRel.put(rel, newRel); + final Sort newSort = + LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch); // Sort does not change input ordering - mapNewRelToMapOldToNewOutputPos.put(newRel, childMapOldToNewOutputPos); + return register(rel, newSort, frame.oldToNewOutputPos, + frame.corVarOutputPos); + } + + /** + * Rewrites a {@link Values}. + * + * @param rel Values to be rewritten + */ + public Frame decorrelateRel(Values rel) { + // There are no inputs, so rel does not need to be changed. 
+ return null; } /** @@ -431,7 +429,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param rel Aggregate to rewrite */ - public void decorrelateRel(LogicalAggregate rel) { + public Frame decorrelateRel(LogicalAggregate rel) { if (rel.getGroupType() != Aggregate.Group.SIMPLE) { throw new AssertionError(Bug.CALCITE_461_FIXED); } @@ -439,7 +437,7 @@ public class RelDecorrelator implements ReflectiveVisitor { // Rewrite logic: // // 1. Permute the group by keys to the front. - // 2. If the child of an aggregate produces correlated variables, + // 2. If the input of an aggregate produces correlated variables, // add them to the group list. // 3. Change aggCalls to reference the new project. // @@ -447,117 +445,107 @@ public class RelDecorrelator implements ReflectiveVisitor { // Aggregate itself should not reference cor vars. assert !cm.mapRefRelToCorVar.containsKey(rel); - RelNode oldChildRel = rel.getInput(); - - RelNode newChildRel = mapOldToNewRel.get(oldChildRel); - if (newChildRel == null) { - // If child has not been rewritten, do not rewrite this rel. - return; + final RelNode oldInput = rel.getInput(); + final Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; } + assert !frame.corVarOutputPos.isEmpty(); + final RelNode newInput = frame.r; - Map<Integer, Integer> childMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newChildRel); - assert childMapOldToNewOutputPos != null; - - // map from newChildRel - Map<Integer, Integer> mapNewChildToProjOutputPos = Maps.newHashMap(); + // map from newInput + Map<Integer, Integer> mapNewInputToProjOutputPos = Maps.newHashMap(); final int oldGroupKeyCount = rel.getGroupSet().cardinality(); - // LogicalProject projects the original expressions, - // plus any correlated variables the child wants to pass along. 
+ // Project projects the original expressions, + // plus any correlated variables the input wants to pass along. final List<Pair<RexNode, String>> projects = Lists.newArrayList(); - List<RelDataTypeField> newChildOutput = - newChildRel.getRowType().getFieldList(); + List<RelDataTypeField> newInputOutput = + newInput.getRowType().getFieldList(); - int newPos; + int newPos = 0; - // oldChildRel has the original group by keys in the front. - for (newPos = 0; newPos < oldGroupKeyCount; newPos++) { - int newChildPos = childMapOldToNewOutputPos.get(newPos); - projects.add(RexInputRef.of2(newChildPos, newChildOutput)); - mapNewChildToProjOutputPos.put(newChildPos, newPos); + // oldInput has the original group by keys in the front. + final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>(); + for (int i = 0; i < oldGroupKeyCount; i++) { + final RexLiteral constant = projectedLiteral(newInput, i); + if (constant != null) { + // Exclude constants. Aggregate({true}) occurs because Aggregate({}) + // would generate 1 row even when applied to an empty table. + omittedConstants.put(i, constant); + continue; + } + int newInputPos = frame.oldToNewOutputPos.get(i); + projects.add(RexInputRef.of2(newInputPos, newInputOutput)); + mapNewInputToProjOutputPos.put(newInputPos, newPos); + newPos++; } - SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap(); - - boolean produceCorVar = - mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel); - if (produceCorVar) { - // If child produces correlated variables, move them to the front, - // right after any existing groupby fields. + final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>(); + if (!frame.corVarOutputPos.isEmpty()) { + // If input produces correlated variables, move them to the front, + // right after any existing GROUP BY fields. 
- SortedMap<Correlation, Integer> childMapCorVarToOutputPos = - mapNewRelToMapCorVarToOutputPos.get(newChildRel); - - // Now add the corVars from the child, starting from + // Now add the corVars from the input, starting from // position oldGroupKeyCount. - for (Correlation corVar - : childMapCorVarToOutputPos.keySet()) { - int newChildPos = childMapCorVarToOutputPos.get(corVar); - projects.add(RexInputRef.of2(newChildPos, newChildOutput)); + for (Map.Entry<Correlation, Integer> entry + : frame.corVarOutputPos.entrySet()) { + projects.add(RexInputRef.of2(entry.getValue(), newInputOutput)); - mapCorVarToOutputPos.put(corVar, newPos); - mapNewChildToProjOutputPos.put(newChildPos, newPos); + mapCorVarToOutputPos.put(entry.getKey(), newPos); + mapNewInputToProjOutputPos.put(entry.getValue(), newPos); newPos++; } } // add the remaining fields final int newGroupKeyCount = newPos; - for (int i = 0; i < newChildOutput.size(); i++) { - if (!mapNewChildToProjOutputPos.containsKey(i)) { - projects.add(RexInputRef.of2(i, newChildOutput)); - mapNewChildToProjOutputPos.put(i, newPos); + for (int i = 0; i < newInputOutput.size(); i++) { + if (!mapNewInputToProjOutputPos.containsKey(i)) { + projects.add(RexInputRef.of2(i, newInputOutput)); + mapNewInputToProjOutputPos.put(i, newPos); newPos++; } } - assert newPos == newChildOutput.size(); + assert newPos == newInputOutput.size(); - // This LogicalProject will be what the old child maps to, - // replacing any previous mapping from old child). - RelNode newProjectRel = - RelOptUtil.createProject(newChildRel, projects, false); + // This Project will be what the old input maps to, + // replacing any previous mapping from old input.
+ RelNode newProject = + RelOptUtil.createProject(newInput, projects, false); // update mappings: - // oldChildRel ----> newChildRel + // oldInput ----> newInput // - // newProjectRel - // | - // oldChildRel ----> newChildRel + // newProject + // | + // oldInput ----> newInput // // is transformed to // - // oldChildRel ----> newProjectRel - // | - // newChildRel + // oldInput ----> newProject + // | + // newInput Map<Integer, Integer> combinedMap = Maps.newHashMap(); - for (Integer oldChildPos : childMapOldToNewOutputPos.keySet()) { - combinedMap.put( - oldChildPos, - mapNewChildToProjOutputPos.get( - childMapOldToNewOutputPos.get(oldChildPos))); + for (Integer oldInputPos : frame.oldToNewOutputPos.keySet()) { + combinedMap.put(oldInputPos, + mapNewInputToProjOutputPos.get( + frame.oldToNewOutputPos.get(oldInputPos))); } - mapOldToNewRel.put(oldChildRel, newProjectRel); - mapNewRelToMapOldToNewOutputPos.put(newProjectRel, combinedMap); + register(oldInput, newProject, combinedMap, mapCorVarToOutputPos); - if (produceCorVar) { - mapNewRelToMapCorVarToOutputPos.put( - newProjectRel, - mapCorVarToOutputPos); - } - - // now it's time to rewrite LogicalAggregate + // now it's time to rewrite the Aggregate + final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount); List<AggregateCall> newAggCalls = Lists.newArrayList(); List<AggregateCall> oldAggCalls = rel.getAggCallList(); - // LogicalAggregate.Call oldAggCall; - int oldChildOutputFieldCount = oldChildRel.getRowType().getFieldCount(); - int newChildOutputFieldCount = - newProjectRel.getRowType().getFieldCount(); + int oldInputOutputFieldCount = rel.getGroupSet().cardinality(); + int newInputOutputFieldCount = newGroupSet.cardinality(); int i = -1; for (AggregateCall oldAggCall : oldAggCalls) { @@ -567,7 +555,7 @@ public class RelDecorrelator implements ReflectiveVisitor { List<Integer> aggArgs = Lists.newArrayList(); // Adjust the aggregator argument positions. 
- // Note aggregator does not change input ordering, so the child + // Note aggregator does not change input ordering, so the input // output position mapping can be used to derive the new positions // for the argument. for (int oldPos : oldAggArgs) { @@ -577,34 +565,57 @@ public class RelDecorrelator implements ReflectiveVisitor { : combinedMap.get(oldAggCall.filterArg); newAggCalls.add( - oldAggCall.adaptTo(newProjectRel, aggArgs, filterArg, + oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount)); // The old to new output position mapping will be the same as that - // of newProjectRel, plus any aggregates that the oldAgg produces. + // of newProject, plus any aggregates that the oldAgg produces. combinedMap.put( - oldChildOutputFieldCount + i, - newChildOutputFieldCount + i); + oldInputOutputFieldCount + i, + newInputOutputFieldCount + i); } - LogicalAggregate newAggregate = - LogicalAggregate.create(newProjectRel, + relBuilder.push( + LogicalAggregate.create(newProject, false, - ImmutableBitSet.range(newGroupKeyCount), + newGroupSet, null, - newAggCalls); + newAggCalls)); + + if (!omittedConstants.isEmpty()) { + final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields()); + for (Map.Entry<Integer, RexLiteral> entry + : omittedConstants.descendingMap().entrySet()) { + postProjects.add(entry.getKey() + frame.corVarOutputPos.size(), + entry.getValue()); + } + relBuilder.project(postProjects); + } - mapOldToNewRel.put(rel, newAggregate); + // Aggregate does not change input ordering so corVars will be + // located at the same position as the input newProject. 
+ return register(rel, relBuilder.build(), combinedMap, mapCorVarToOutputPos); + } - mapNewRelToMapOldToNewOutputPos.put(newAggregate, combinedMap); + public Frame getInvoke(RelNode r, RelNode parent) { + final Frame frame = dispatcher.invoke(r); + if (frame != null) { + map.put(r, frame); + } + currentRel = parent; + return frame; + } - if (produceCorVar) { - // LogicalAggregate does not change input ordering so corVars will be - // located at the same position as the input newProjectRel. - mapNewRelToMapCorVarToOutputPos.put( - newAggregate, - mapCorVarToOutputPos); + /** Returns a literal output field, or null if it is not literal. */ + private static RexLiteral projectedLiteral(RelNode rel, int i) { + if (rel instanceof Project) { + final Project project = (Project) rel; + final RexNode node = project.getProjects().get(i); + if (node instanceof RexLiteral) { + return (RexLiteral) node; + } } + return null; } /** @@ -612,34 +623,24 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param rel the project rel to rewrite */ - public void decorrelateRel(LogicalProject rel) { + public Frame decorrelateRel(LogicalProject rel) { // // Rewrite logic: // - // 1. Pass along any correlated variables coming from the child. + // 1. Pass along any correlated variables coming from the input. // - RelNode oldChildRel = rel.getInput(); - - RelNode newChildRel = mapOldToNewRel.get(oldChildRel); - if (newChildRel == null) { - // If child has not been rewritten, do not rewrite this rel. - return; + final RelNode oldInput = rel.getInput(); + Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. 
+ return null; } - List<RexNode> oldProj = rel.getProjects(); - List<RelDataTypeField> relOutput = rel.getRowType().getFieldList(); - - Map<Integer, Integer> childMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newChildRel); - assert childMapOldToNewOutputPos != null; - - Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap(); - - boolean produceCorVar = - mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel); + final List<RexNode> oldProjects = rel.getProjects(); + final List<RelDataTypeField> relOutput = rel.getRowType().getFieldList(); // LogicalProject projects the original expressions, - // plus any correlated variables the child wants to pass along. + // plus any correlated variables the input wants to pass along. final List<Pair<RexNode, String>> projects = Lists.newArrayList(); // If this LogicalProject has correlated reference, create value generator @@ -647,55 +648,38 @@ public class RelDecorrelator implements ReflectiveVisitor { if (cm.mapRefRelToCorVar.containsKey(rel)) { decorrelateInputWithValueGenerator(rel); - // The old child should be mapped to the LogicalJoin created by + // The old input should be mapped to the LogicalJoin created by // rewriteInputWithValueGenerator(). - newChildRel = mapOldToNewRel.get(oldChildRel); - produceCorVar = true; + frame = map.get(oldInput); } // LogicalProject projects the original expressions + final Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap(); int newPos; - for (newPos = 0; newPos < oldProj.size(); newPos++) { + for (newPos = 0; newPos < oldProjects.size(); newPos++) { projects.add( newPos, Pair.of( - decorrelateExpr(oldProj.get(newPos)), + decorrelateExpr(oldProjects.get(newPos)), relOutput.get(newPos).getName())); mapOldToNewOutputPos.put(newPos, newPos); } - SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap(); - - // Project any correlated variables the child wants to pass along. 
- if (produceCorVar) { - SortedMap<Correlation, Integer> childMapCorVarToOutputPos = - mapNewRelToMapCorVarToOutputPos.get(newChildRel); - - // propagate cor vars from the new child - List<RelDataTypeField> newChildOutput = - newChildRel.getRowType().getFieldList(); - for (Correlation corVar - : childMapCorVarToOutputPos.keySet()) { - int corVarPos = childMapCorVarToOutputPos.get(corVar); - projects.add(RexInputRef.of2(corVarPos, newChildOutput)); - mapCorVarToOutputPos.put(corVar, newPos); - newPos++; - } + // Project any correlated variables the input wants to pass along. + final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>(); + for (Map.Entry<Correlation, Integer> entry : frame.corVarOutputPos.entrySet()) { + projects.add( + RexInputRef.of2(entry.getValue(), + frame.r.getRowType().getFieldList())); + mapCorVarToOutputPos.put(entry.getKey(), newPos); + newPos++; } - RelNode newProjectRel = - RelOptUtil.createProject(newChildRel, projects, false); - - mapOldToNewRel.put(rel, newProjectRel); - mapNewRelToMapOldToNewOutputPos.put( - newProjectRel, - mapOldToNewOutputPos); + RelNode newProject = + RelOptUtil.createProject(frame.r, projects, false); - if (produceCorVar) { - mapNewRelToMapCorVarToOutputPos.put( - newProjectRel, - mapCorVarToOutputPos); - } + return register(rel, newProject, mapOldToNewOutputPos, + mapCorVarToOutputPos); } /** @@ -712,44 +696,37 @@ public class RelDecorrelator implements ReflectiveVisitor { Iterable<Correlation> correlations, int valueGenFieldOffset, SortedMap<Correlation, Integer> mapCorVarToOutputPos) { - RelNode resultRel = null; + final Map<RelNode, List<Integer>> mapNewInputToOutputPos = + new HashMap<>(); - Map<RelNode, List<Integer>> mapNewInputRelToOutputPos = Maps.newHashMap(); - - Map<RelNode, Integer> mapNewInputRelToNewOffset = Maps.newHashMap(); - - RelNode oldInputRel; - RelNode newInputRel; - List<Integer> newLocalOutputPosList; + final Map<RelNode, Integer> mapNewInputToNewOffset = new 
HashMap<>(); // inputRel provides the definition of a correlated variable. // Add to map all the referenced positions(relative to each input rel) for (Correlation corVar : correlations) { - int oldCorVarOffset = corVar.field; + final int oldCorVarOffset = corVar.field; - oldInputRel = cm.mapCorVarToCorRel.get(corVar.corr).getInput(0); - assert oldInputRel != null; - newInputRel = mapOldToNewRel.get(oldInputRel); - assert newInputRel != null; + final RelNode oldInput = getCorRel(corVar); + assert oldInput != null; + final Frame frame = map.get(oldInput); + assert frame != null; + final RelNode newInput = frame.r; - if (!mapNewInputRelToOutputPos.containsKey(newInputRel)) { + final List<Integer> newLocalOutputPosList; + if (!mapNewInputToOutputPos.containsKey(newInput)) { newLocalOutputPosList = Lists.newArrayList(); } else { newLocalOutputPosList = - mapNewInputRelToOutputPos.get(newInputRel); + mapNewInputToOutputPos.get(newInput); } - Map<Integer, Integer> mapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newInputRel); - assert mapOldToNewOutputPos != null; - - int newCorVarOffset = mapOldToNewOutputPos.get(oldCorVarOffset); + final int newCorVarOffset = frame.oldToNewOutputPos.get(oldCorVarOffset); // Add all unique positions referenced. if (!newLocalOutputPosList.contains(newCorVarOffset)) { newLocalOutputPosList.add(newCorVarOffset); } - mapNewInputRelToOutputPos.put(newInputRel, newLocalOutputPosList); + mapNewInputToOutputPos.put(newInput, newLocalOutputPosList); } int offset = 0; @@ -759,33 +736,34 @@ public class RelDecorrelator implements ReflectiveVisitor { // To make sure the plan does not change in terms of join order, // join these rels based on their occurrence in cor var list which // is sorted. 
- Set<RelNode> joinedInputRelSet = Sets.newHashSet(); + final Set<RelNode> joinedInputRelSet = Sets.newHashSet(); + RelNode r = null; for (Correlation corVar : correlations) { - oldInputRel = cm.mapCorVarToCorRel.get(corVar.corr).getInput(0); - assert oldInputRel != null; - newInputRel = mapOldToNewRel.get(oldInputRel); - assert newInputRel != null; + final RelNode oldInput = getCorRel(corVar); + assert oldInput != null; + final RelNode newInput = map.get(oldInput).r; + assert newInput != null; - if (!joinedInputRelSet.contains(newInputRel)) { - RelNode projectRel = + if (!joinedInputRelSet.contains(newInput)) { + RelNode project = RelOptUtil.createProject( - newInputRel, - mapNewInputRelToOutputPos.get(newInputRel)); - RelNode distinctRel = RelOptUtil.createDistinctRel(projectRel); - RelOptCluster cluster = distinctRel.getCluster(); + newInput, + mapNewInputToOutputPos.get(newInput)); + RelNode distinct = RelOptUtil.createDistinctRel(project); + RelOptCluster cluster = distinct.getCluster(); - joinedInputRelSet.add(newInputRel); - mapNewInputRelToNewOffset.put(newInputRel, offset); - offset += distinctRel.getRowType().getFieldCount(); + joinedInputRelSet.add(newInput); + mapNewInputToNewOffset.put(newInput, offset); + offset += distinct.getRowType().getFieldCount(); - if (resultRel == null) { - resultRel = distinctRel; + if (r == null) { + r = distinct; } else { - resultRel = - LogicalJoin.create(resultRel, distinctRel, + r = + LogicalJoin.create(r, distinct, cluster.getRexBuilder().makeLiteral(true), - JoinRelType.INNER, ImmutableSet.<String>of()); + ImmutableSet.<CorrelationId>of(), JoinRelType.INNER); } } } @@ -794,27 +772,26 @@ public class RelDecorrelator implements ReflectiveVisitor { // the join output, leaving room for valueGenFieldOffset because // valueGenerators are joined with the original left input of the rel // referencing correlated variables. 
- int newOutputPos; - int newLocalOutputPos; for (Correlation corVar : correlations) { - // The first child of a correlatorRel is always the rel defining + // The first input of a Correlator is always the rel defining // the correlated variables. - newInputRel = - mapOldToNewRel.get(cm.mapCorVarToCorRel.get(corVar.corr).getInput(0)); - newLocalOutputPosList = mapNewInputRelToOutputPos.get(newInputRel); + final RelNode oldInput = getCorRel(corVar); + assert oldInput != null; + final Frame frame = map.get(oldInput); + final RelNode newInput = frame.r; + assert newInput != null; - Map<Integer, Integer> mapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newInputRel); - assert mapOldToNewOutputPos != null; + final List<Integer> newLocalOutputPosList = + mapNewInputToOutputPos.get(newInput); - newLocalOutputPos = mapOldToNewOutputPos.get(corVar.field); + final int newLocalOutputPos = frame.oldToNewOutputPos.get(corVar.field); // newOutputPos is the index of the cor var in the referenced // position list plus the offset of referenced position list of - // each newInputRel. - newOutputPos = + // each newInput. 
+ final int newOutputPos = newLocalOutputPosList.indexOf(newLocalOutputPos) - + mapNewInputRelToNewOffset.get(newInputRel) + + mapNewInputToNewOffset.get(newInput) + valueGenFieldOffset; if (mapCorVarToOutputPos.containsKey(corVar)) { @@ -823,53 +800,47 @@ public class RelDecorrelator implements ReflectiveVisitor { mapCorVarToOutputPos.put(corVar, newOutputPos); } - return resultRel; + return r; } - private void decorrelateInputWithValueGenerator( - RelNode rel) { - // currently only handles one child input - assert rel.getInputs().size() == 1; - RelNode oldChildRel = rel.getInput(0); - RelNode newChildRel = mapOldToNewRel.get(oldChildRel); - - Map<Integer, Integer> childMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newChildRel); - assert childMapOldToNewOutputPos != null; + private RelNode getCorRel(Correlation corVar) { + final RelNode r = cm.mapCorVarToCorRel.get(corVar.corr); + RelNode r2 = r.getInput(0); + if (r2 instanceof Join) { + r2 = r2.getInput(0); + } + return r2; + } - SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap(); + private void decorrelateInputWithValueGenerator(RelNode rel) { + // currently only handles one input input + assert rel.getInputs().size() == 1; + RelNode oldInput = rel.getInput(0); + final Frame frame = map.get(oldInput); - if (mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel)) { - mapCorVarToOutputPos.putAll( - mapNewRelToMapCorVarToOutputPos.get(newChildRel)); - } + final SortedMap<Correlation, Integer> mapCorVarToOutputPos = + new TreeMap<>(frame.corVarOutputPos); final Collection<Correlation> corVarList = cm.mapRefRelToCorVar.get(rel); - RelNode newLeftChildRel = newChildRel; - - int leftChildOutputCount = newLeftChildRel.getRowType().getFieldCount(); + int leftInputOutputCount = frame.r.getRowType().getFieldCount(); // can directly add positions into mapCorVarToOutputPos since join - // does not change the output ordering from the children. 
- RelNode valueGenRel = + // does not change the output ordering from the inputs. + RelNode valueGen = createValueGenerator( corVarList, - leftChildOutputCount, + leftInputOutputCount, mapCorVarToOutputPos); - final Set<String> variablesStopped = Collections.emptySet(); - RelNode joinRel = - LogicalJoin.create(newLeftChildRel, valueGenRel, - rexBuilder.makeLiteral(true), JoinRelType.INNER, variablesStopped); - - mapOldToNewRel.put(oldChildRel, joinRel); - mapNewRelToMapCorVarToOutputPos.put(joinRel, mapCorVarToOutputPos); + RelNode join = + LogicalJoin.create(frame.r, valueGen, rexBuilder.makeLiteral(true), + ImmutableSet.<CorrelationId>of(), JoinRelType.INNER); // LogicalJoin or LogicalFilter does not change the old input ordering. All // input fields from newLeftInput(i.e. the original input to the old // LogicalFilter) are in the output and in the same position. - mapNewRelToMapOldToNewOutputPos.put(joinRel, childMapOldToNewOutputPos); + register(oldInput, join, frame.oldToNewOutputPos, mapCorVarToOutputPos); } /** @@ -877,7 +848,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param rel the filter rel to rewrite */ - public void decorrelateRel(LogicalFilter rel) { + public Frame decorrelateRel(LogicalFilter rel) { // // Rewrite logic: // @@ -894,53 +865,36 @@ public class RelDecorrelator implements ReflectiveVisitor { // rewrite the filter condition using new input. // - RelNode oldChildRel = rel.getInput(); - - RelNode newChildRel = mapOldToNewRel.get(oldChildRel); - if (newChildRel == null) { - // If child has not been rewritten, do not rewrite this rel. - return; + final RelNode oldInput = rel.getInput(); + Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. 
+ return null; } - Map<Integer, Integer> childMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newChildRel); - assert childMapOldToNewOutputPos != null; - - boolean produceCorVar = - mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel); - // If this LogicalFilter has correlated reference, create value generator // and produce the correlated variables in the new output. if (cm.mapRefRelToCorVar.containsKey(rel)) { decorrelateInputWithValueGenerator(rel); - // The old child should be mapped to the newly created LogicalJoin by + // The old input should be mapped to the newly created LogicalJoin by // rewriteInputWithValueGenerator(). - newChildRel = mapOldToNewRel.get(oldChildRel); - produceCorVar = true; + frame = map.get(oldInput); } // Replace the filter expression to reference output of the join // Map filter to the new filter over join - RelNode newFilterRel = + RelNode newFilter = RelOptUtil.createFilter( - newChildRel, + frame.r, decorrelateExpr(rel.getCondition())); - mapOldToNewRel.put(rel, newFilterRel); - // Filter does not change the input ordering. - mapNewRelToMapOldToNewOutputPos.put( - newFilterRel, - childMapOldToNewOutputPos); - - if (produceCorVar) { - // filter rel does not permute the input all corvars produced by - // filter will have the same output positions in the child rel. - mapNewRelToMapCorVarToOutputPos.put( - newFilterRel, - mapNewRelToMapCorVarToOutputPos.get(newChildRel)); - } + // Filter rel does not permute the input. + // All corvars produced by filter will have the same output positions in the + // input rel. 
+ return register(rel, newFilter, frame.oldToNewOutputPos, + frame.corVarOutputPos); } /** @@ -948,7 +902,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param rel Correlator */ - public void decorrelateRel(LogicalCorrelate rel) { + public Frame decorrelateRel(LogicalCorrelate rel) { // // Rewrite logic: // @@ -959,126 +913,93 @@ public class RelDecorrelator implements ReflectiveVisitor { // // the right input to Correlator should produce correlated variables - RelNode oldLeftRel = rel.getInputs().get(0); - RelNode oldRightRel = rel.getInputs().get(1); + final RelNode oldLeft = rel.getInput(0); + final RelNode oldRight = rel.getInput(1); - RelNode newLeftRel = mapOldToNewRel.get(oldLeftRel); - RelNode newRightRel = mapOldToNewRel.get(oldRightRel); + final Frame leftFrame = getInvoke(oldLeft, rel); + final Frame rightFrame = getInvoke(oldRight, rel); - if ((newLeftRel == null) || (newRightRel == null)) { - // If any child has not been rewritten, do not rewrite this rel. - return; + if (leftFrame == null || rightFrame == null) { + // If any input has not been rewritten, do not rewrite this rel. + return null; } - SortedMap<Correlation, Integer> rightChildMapCorVarToOutputPos = - mapNewRelToMapCorVarToOutputPos.get(newRightRel); - - if (rightChildMapCorVarToOutputPos == null) { - return; + if (rightFrame.corVarOutputPos.isEmpty()) { + return null; } - Map<Integer, Integer> leftChildMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newLeftRel); - assert leftChildMapOldToNewOutputPos != null; - - Map<Integer, Integer> rightChildMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newRightRel); - - assert rightChildMapOldToNewOutputPos != null; - - SortedMap<Correlation, Integer> mapCorVarToOutputPos = - rightChildMapCorVarToOutputPos; - assert rel.getRequiredColumns().cardinality() - <= rightChildMapCorVarToOutputPos.keySet().size(); + <= rightFrame.corVarOutputPos.keySet().size(); // Change correlator rel into a join. 
// Join all the correlated variables produced by this correlator rel // with the values generated and propagated from the right input - RexNode condition = rexBuilder.makeLiteral(true); + final SortedMap<Correlation, Integer> corVarOutputPos = + new TreeMap<>(rightFrame.corVarOutputPos); + final List<RexNode> conditions = new ArrayList<>(); final List<RelDataTypeField> newLeftOutput = - newLeftRel.getRowType().getFieldList(); + leftFrame.r.getRowType().getFieldList(); int newLeftFieldCount = newLeftOutput.size(); final List<RelDataTypeField> newRightOutput = - newRightRel.getRowType().getFieldList(); + rightFrame.r.getRowType().getFieldList(); - int newLeftPos; - int newRightPos; for (Map.Entry<Correlation, Integer> rightOutputPos - : Lists.newArrayList(rightChildMapCorVarToOutputPos.entrySet())) { - Correlation corVar = rightOutputPos.getKey(); + : Lists.newArrayList(corVarOutputPos.entrySet())) { + final Correlation corVar = rightOutputPos.getKey(); if (!corVar.corr.equals(rel.getCorrelationId())) { continue; } - newLeftPos = leftChildMapOldToNewOutputPos.get(corVar.field); - newRightPos = rightChildMapCorVarToOutputPos.get(corVar); - RexNode equi = - rexBuilder.makeCall( - SqlStdOperatorTable.EQUALS, + final int newLeftPos = leftFrame.oldToNewOutputPos.get(corVar.field); + final int newRightPos = rightOutputPos.getValue(); + conditions.add( + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, RexInputRef.of(newLeftPos, newLeftOutput), - new RexInputRef( - newLeftFieldCount + newRightPos, - newRightOutput.get(newRightPos).getType())); - if (condition == rexBuilder.makeLiteral(true)) { - condition = equi; - } else { - condition = - rexBuilder.makeCall( - SqlStdOperatorTable.AND, - condition, - equi); - } + new RexInputRef(newLeftFieldCount + newRightPos, + newRightOutput.get(newRightPos).getType()))); // remove this cor var from output position mapping - mapCorVarToOutputPos.remove(corVar); + corVarOutputPos.remove(corVar); } // Update the output position for the cor 
vars: only pass on the cor // vars that are not used in the join key. - for (Correlation corVar : mapCorVarToOutputPos.keySet()) { - int newPos = mapCorVarToOutputPos.get(corVar) + newLeftFieldCount; - mapCorVarToOutputPos.put(corVar, newPos); + for (Correlation corVar : corVarOutputPos.keySet()) { + int newPos = corVarOutputPos.get(corVar) + newLeftFieldCount; + corVarOutputPos.put(corVar, newPos); } // then add any cor var from the left input. Do not need to change // output positions. - if (mapNewRelToMapCorVarToOutputPos.containsKey(newLeftRel)) { - mapCorVarToOutputPos.putAll( - mapNewRelToMapCorVarToOutputPos.get(newLeftRel)); - } + corVarOutputPos.putAll(leftFrame.corVarOutputPos); // Create the mapping between the output of the old correlation rel // and the new join rel - Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap(); + final Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap(); - int oldLeftFieldCount = oldLeftRel.getRowType().getFieldCount(); + int oldLeftFieldCount = oldLeft.getRowType().getFieldCount(); - int oldRightFieldCount = oldRightRel.getRowType().getFieldCount(); + int oldRightFieldCount = oldRight.getRowType().getFieldCount(); assert rel.getRowType().getFieldCount() == oldLeftFieldCount + oldRightFieldCount; // Left input positions are not changed. - mapOldToNewOutputPos.putAll(leftChildMapOldToNewOutputPos); + mapOldToNewOutputPos.putAll(leftFrame.oldToNewOutputPos); // Right input positions are shifted by newLeftFieldCount. 
for (int i = 0; i < oldRightFieldCount; i++) { mapOldToNewOutputPos.put( i + oldLeftFieldCount, - rightChildMapOldToNewOutputPos.get(i) + newLeftFieldCount); + rightFrame.oldToNewOutputPos.get(i) + newLeftFieldCount); } - final Set<String> variablesStopped = Collections.emptySet(); - RelNode newRel = - LogicalJoin.create(newLeftRel, newRightRel, condition, - rel.getJoinType().toJoinType(), variablesStopped); - - mapOldToNewRel.put(rel, newRel); - mapNewRelToMapOldToNewOutputPos.put(newRel, mapOldToNewOutputPos); + final RexNode condition = + RexUtil.composeConjunction(rexBuilder, conditions, false); + RelNode newJoin = + LogicalJoin.create(leftFrame.r, rightFrame.r, condition, + ImmutableSet.<CorrelationId>of(), rel.getJoinType().toJoinType()); - if (!mapCorVarToOutputPos.isEmpty()) { - mapNewRelToMapCorVarToOutputPos.put(newRel, mapCorVarToOutputPos); - } + return register(rel, newJoin, mapOldToNewOutputPos, corVarOutputPos); } /** @@ -1086,7 +1007,7 @@ public class RelDecorrelator implements ReflectiveVisitor { * * @param rel LogicalJoin */ - public void decorrelateRel(LogicalJoin rel) { + public Frame decorrelateRel(LogicalJoin rel) { // // Rewrite logic: // @@ -1094,77 +1015,52 @@ public class RelDecorrelator implements ReflectiveVisitor { // 2. map output positions and produce cor vars if any. // - RelNode oldLeftRel = rel.getInputs().get(0); - RelNode oldRightRel = rel.getInputs().get(1); + final RelNode oldLeft = rel.getInput(0); + final RelNode oldRight = rel.getInput(1); - RelNode newLeftRel = mapOldToNewRel.get(oldLeftRel); - RelNode newRightRel = mapOldToNewRel.get(oldRightRel); + final Frame leftFrame = getInvoke(oldLeft, rel); + final Frame rightFrame = getInvoke(oldRight, rel); - if ((newLeftRel == null) || (newRightRel == null)) { - // If any child has not been rewritten, do not rewrite this rel. - return; + if (leftFrame == null || rightFrame == null) { + // If any input has not been rewritten, do not rewrite this rel. 
+ return null; } - Map<Integer, Integer> leftChildMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newLeftRel); - assert leftChildMapOldToNewOutputPos != null; - - Map<Integer, Integer> rightChildMapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newRightRel); - assert rightChildMapOldToNewOutputPos != null; - - SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap(); - - final Set<String> variablesStopped = Collections.emptySet(); - RelNode newRel = - LogicalJoin.create(newLeftRel, newRightRel, - decorrelateExpr(rel.getCondition()), rel.getJoinType(), - variablesStopped); + final RelNode newJoin = + LogicalJoin.create(leftFrame.r, rightFrame.r, + decorrelateExpr(rel.getCondition()), + ImmutableSet.<CorrelationId>of(), rel.getJoinType()); // Create the mapping between the output of the old correlation rel // and the new join rel Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap(); - int oldLeftFieldCount = oldLeftRel.getRowType().getFieldCount(); - int newLeftFieldCount = newLeftRel.getRowType().getFieldCount(); + int oldLeftFieldCount = oldLeft.getRowType().getFieldCount(); + int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount(); - int oldRightFieldCount = oldRightRel.getRowType().getFieldCount(); + int oldRightFieldCount = oldRight.getRowType().getFieldCount(); assert rel.getRowType().getFieldCount() == oldLeftFieldCount + oldRightFieldCount; // Left input positions are not changed. - mapOldToNewOutputPos.putAll(leftChildMapOldToNewOutputPos); + mapOldToNewOutputPos.putAll(leftFrame.oldToNewOutputPos); // Right input positions are shifted by newLeftFieldCount. 
for (int i = 0; i < oldRightFieldCount; i++) { - mapOldToNewOutputPos.put( - i + oldLeftFieldCount, - rightChildMapOldToNewOutputPos.get(i) + newLeftFieldCount); + mapOldToNewOutputPos.put(i + oldLeftFieldCount, + rightFrame.oldToNewOutputPos.get(i) + newLeftFieldCount); } - if (mapNewRelToMapCorVarToOutputPos.containsKey(newLeftRel)) { - mapCorVarToOutputPos.putAll( - mapNewRelToMapCorVarToOutputPos.get(newLeftRel)); - } + final SortedMap<Correlation, Integer> mapCorVarToOutputPos = + new TreeMap<>(leftFrame.corVarOutputPos); // Right input positions are shifted by newLeftFieldCount. - int oldRightPos; - if (mapNewRelToMapCorVarToOutputPos.containsKey(newRightRel)) { - SortedMap<Correlation, Integer> rightChildMapCorVarToOutputPos = - mapNewRelToMapCorVarToOutputPos.get(newRightRel); - for (Correlation corVar : rightChildMapCorVarToOutputPos.keySet()) { - oldRightPos = rightChildMapCorVarToOutputPos.get(corVar); - mapCorVarToOutputPos.put( - corVar, - oldRightPos + newLeftFieldCount); - } - } - mapOldToNewRel.put(rel, newRel); - mapNewRelToMapOldToNewOutputPos.put(newRel, mapOldToNewOutputPos); - - if (!mapCorVarToOutputPos.isEmpty()) { - mapNewRelToMapCorVarToOutputPos.put(newRel, mapCorVarToOutputPos); + for (Map.Entry<Correlation, Integer> entry + : rightFrame.corVarOutputPos.entrySet()) { + mapCorVarToOutputPos.put(entry.getKey(), + entry.getValue() + newLeftFieldCount); } + return register(rel, newJoin, mapOldToNewOutputPos, mapCorVarToOutputPos); } private RexInputRef getNewForOldInputRef(RexInputRef oldInputRef) { @@ -1175,61 +1071,57 @@ public class RelDecorrelator implements ReflectiveVisitor { // determine which input rel oldOrdinal references, and adjust // oldOrdinal to be relative to that input rel - List<RelNode> oldInputRels = currentRel.getInputs(); - RelNode oldInputRel = null; + RelNode oldInput = null; - for (RelNode oldInputRel0 : oldInputRels) { - RelDataType oldInputType = oldInputRel0.getRowType(); + for (RelNode oldInput0 : 
currentRel.getInputs()) { + RelDataType oldInputType = oldInput0.getRowType(); int n = oldInputType.getFieldCount(); if (oldOrdinal < n) { - oldInputRel = oldInputRel0; + oldInput = oldInput0; break; } - RelNode newInput = mapOldToNewRel.get(oldInputRel0); + RelNode newInput = map.get(oldInput0).r; newOrdinal += newInput.getRowType().getFieldCount(); oldOrdinal -= n; } - assert oldInputRel != null; + assert oldInput != null; - RelNode newInputRel = mapOldToNewRel.get(oldInputRel); - assert newInputRel != null; + final Frame frame = map.get(oldInput); + assert frame != null; - // now oldOrdinal is relative to oldInputRel + // now oldOrdinal is relative to oldInput int oldLocalOrdinal = oldOrdinal; - // figure out the newLocalOrdinal, relative to the newInputRel. + // figure out the newLocalOrdinal, relative to the newInput. int newLocalOrdinal = oldLocalOrdinal; - Map<Integer, Integer> mapOldToNewOutputPos = - mapNewRelToMapOldToNewOutputPos.get(newInputRel); - - if (mapOldToNewOutputPos != null) { - newLocalOrdinal = mapOldToNewOutputPos.get(oldLocalOrdinal); + if (!frame.oldToNewOutputPos.isEmpty()) { + newLocalOrdinal = frame.oldToNewOutputPos.get(oldLocalOrdinal); } newOrdinal += newLocalOrdinal; return new RexInputRef(newOrdinal, - newInputRel.getRowType().getFieldList().get(newLocalOrdinal).getType()); + frame.r.getRowType().getFieldList().get(newLocalOrdinal).getType()); } /** - * Pull projRel above the join from its RHS input. Enforce nullability + * Pulls project above the join from its RHS input. Enforces nullability * for join output. * * @param join Join - * @param projRel the original projRel as the RHS input of the join. 
+ * @param project Original project as the right-hand input of the join * @param nullIndicatorPos Position of null indicator * @return the subtree with the new LogicalProject at the root */ private RelNode projectJoinOutputWithNullability( LogicalJoin join, - LogicalProject projRel, + LogicalProject project, int nullIndicatorPos) { - RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory(); - RelNode leftInputRel = join.getLeft(); - JoinRelType joinType = join.getJoinType(); + final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory(); + final RelNode left = join.getLeft(); + final JoinRelType joinType = join.getJoinType(); RexInputRef nullIndicator = new RexInputRef( @@ -1245,7 +1137,7 @@ public class RelDecorrelator implements ReflectiveVisitor { // project everything from the LHS and then those from the original // projRel List<RelDataTypeField> leftInputFields = - leftInputRel.getRowType().getFieldList(); + left.getRowType().getFieldList(); for (int i = 0; i < leftInputFields.size(); i++) { newProjExprs.add(RexInputRef.of2(i, leftInputFields)); @@ -1257,7 +1149,7 @@ public class RelDecorrelator implements ReflectiveVisitor { boolean projectPulledAboveLeftCorrelator = joinType.generatesNullsOnRight(); - for (Pair<RexNode, String> pair : projRel.getNamedProjects()) { + for (Pair<RexNode, String> pair : project.getNamedProjects()) { RexNode newProjExpr = removeCorrelationExpr( pair.left, @@ -1267,36 +1159,33 @@ public class RelDecorrelator implements ReflectiveVisitor { newProjExprs.add(Pair.of(newProjExpr, pair.right)); } - RelNode newProjRel = - RelOptUtil.createProject(join, newProjExprs, false); - - return newProjRel; + return RelOptUtil.createProject(join, newProjExprs, false); } /** - * Pulls projRel above the joinRel from its RHS input. Enforces nullability - * for join output. + * Pulls a {@link Project} above a {@link Correlate} from its RHS input. + * Enforces nullability for join output. 
* - * @param corRel Correlator - * @param projRel the original LogicalProject as the RHS input of the join + * @param correlate Correlate + * @param project the original project as the RHS input of the join * @param isCount Positions which are calls to the <code>COUNT</code> * aggregation function * @return the subtree with the new LogicalProject at the root */ private RelNode aggregateCorrelatorOutput( - LogicalCorrelate corRel, - LogicalProject projRel, + Correlate correlate, + LogicalProject project, Set<Integer> isCount) { - RelNode leftInputRel = corRel.getLeft(); - JoinRelType joinType = corRel.getJoinType().toJoinType(); + final RelNode left = correlate.getLeft(); + final JoinRelType joinType = correlate.getJoinType().toJoinType(); // now create the new project - List<Pair<RexNode, String>> newProjects = Lists.newArrayList(); + final List<Pair<RexNode, String>> newProjects = Lists.newArrayList(); - // project everything from the LHS and then those from the original - // projRel - List<RelDataTypeField> leftInputFields = - leftInputRel.getRowType().getFieldList(); + // Project everything from the LHS and then those from the original + // project + final List<RelDataTypeField> leftInputFields = + left.getRowType().getFieldList(); for (int i = 0; i < leftInputFields.size(); i++) { newProjects.add(RexInputRef.of2(i, leftInputFields)); @@ -1308,7 +1197,7 @@ public class RelDecorrelator implements ReflectiveVisitor { boolean projectPulledAboveLeftCorrelator = joinType.generatesNullsOnRight(); - for (Pair<RexNode, String> pair : projRel.getNamedProjects()) { + for (Pair<RexNode, String> pair : project.getNamedProjects()) { RexNode newProjExpr = removeCorrelationExpr( pair.left, @@ -1317,22 +1206,22 @@ public class RelDecorrelator implements ReflectiveVisitor { newProjects.add(Pair.of(newProjExpr, pair.right)); } - return RelOptUtil.createProject(corRel, newProjects, false); + return RelOptUtil.createProject(correlate, newProjects, false); } /** * Checks whether the 
correlations in projRel and filter are related to * the correlated variables provided by corRel. * - * @param corRel Correlator - * @param projRel The original Project as the RHS input of the join + * @param correlate Correlate + * @param project The original Project as the RHS input of the join * @param filter Filter * @param correlatedJoinKeys Correlated join keys * @return true if filter and proj only references corVar provided by corRel */ private boolean checkCorVars( - LogicalCorrelate corRel, - LogicalProject projRel, + LogicalCorrelate correlate, + LogicalProject project, LogicalFilter filter, List<RexFieldAccess> correlatedJoinKeys) { if (filter != null) { @@ -1344,8 +1233,7 @@ public class RelDecorrelator implements ReflectiveVisitor { Sets.newHashSet(cm.mapRefRelToCorVar.get(filter)); for (RexFieldAccess correlatedJoinKey : correlatedJoinKeys) { - corVarInFilter.remove( - cm.mapFieldAccessToCorVar.get(correlatedJoinKey)); + corVarInFilter.remove(cm.mapFieldAccessToCorVar.get(correlatedJoinKey)); } if (!corVarInFilter.isEmpty()) { @@ -1357,18 +1245,18 @@ public class RelDecorrelator implements ReflectiveVisitor { corVarInFilter.addAll(cm.mapRefRelToCorVar.get(filter)); for (Correlation corVar : corVarInFilter) { - if (cm.mapCorVarToCorRel.get(corVar.corr) != corRel) { + if (cm.mapCorVarToCorRel.get(corVar.corr) != correlate) { return false; } } } - // if projRel has any correlated reference, make sure they are also - // provided by the current corRel. They will be projected out of the LHS - // of the corRel. - if ((projRel != null) && cm.mapRefRelToCorVar.containsKey(projRel)) { - for (Correlation corVar : cm.mapRefRelToCorVar.get(projRel)) { - if (cm.mapCorVarToCorRel.get(corVar.corr) != corRel) { + // if project has any correlated reference, make sure they are also + // provided by the current correlate. They will be projected out of the LHS + // of the correlate. 
+ if ((project != null) && cm.mapRefRelToCorVar.containsKey(project)) { + for (Correlation corVar : cm.mapRefRelToCorVar.get(project)) { + if (cm.mapCorVarToCorRel.get(corVar.corr) != correlate) { return false; } } @@ -1380,26 +1268,26 @@ public class RelDecorrelator implements ReflectiveVisitor { /** * Remove correlated variables from the tree at root corRel * - * @param corRel Correlator + * @param correlate Correlator */ - private void removeCorVarFromTree(LogicalCorrelate corRel) { - if (cm.mapCorVarToCorRel.get(corRel.getCorrelationId()) == corRel) { - cm.mapCorVarToCorRel.remove(corRel.getCorrelationId()); + private void removeCorVarFromTree(LogicalCorrelate correlate) { + if (cm.mapCorVarToCorRel.get(correlate.getCorrelationId()) == correlate) { + cm.mapCorVarToCorRel.remove(correlate.getCorrelationId()); } } /** - * Project all childRel output fields plus the additional expressions. + * Projects all {@code input} output fields plus the additional expressions. * - * @param childRel Child relational expression + * @param input Input relational expression * @param additionalExprs Additional expressions and names * @return the new LogicalProject */ private RelNode createProjectWithAdditionalExprs( - RelNode childRel, + RelNode input, List<Pair<RexNode, String>> additionalExprs) { final List<RelDataTypeField> fieldList = - childRel.getRowType().getFieldList(); + input.getRowType().getFieldList(); List<Pair<RexNode, String>> projects = Lists.newArrayList(); for (Ord<RelDataTypeField> field : Ord.zip(fieldList)) { projects.add( @@ -1409,140 +1297,93 @@ public class RelDecorrelator implements ReflectiveVisitor { field.e.getName())); } projects.addAll(additionalExprs); - return RelOptUtil.createProject(childRel, projects, false); + return RelOptUtil.createProject(input, projects, false); } - //~ Inner Classes ---------------------------------------------------------- + /* Returns an immutable map with the identity [0: 0, .., count-1: count-1]. 
*/ + static Map<Integer, Integer> identityMap(int count) { + ImmutableMap.Builder<Integer, Integer> builder = ImmutableMap.builder(); + for (int i = 0; i < count; i++) { + builder.put(i, i); + } + return builder.build(); + } + + /** Registers a relational expression and the relational expression it became + * after decorrelation. */ + Frame register(RelNode rel, RelNode newRel, + Map<Integer, Integer> oldToNewOutputPos, + SortedMap<Correlation, Integer> corVarToOutputPos) { + assert allLessThan(oldToNewOutputPos.keySet(), + newRel.getRowType().getFieldCount(), Litmus.THROW); + final Frame frame = new Frame(newRel, corVarToOutputPos, oldToNewOutputPos); + map.put(rel, frame); + return frame; + } - /** Visitor that decorrelates. */ - private class DecorrelateRelVisitor extends RelVisitor { - private final ReflectiveVisitDispatcher<RelDecorrelator, RelNode> - dispatcher = - ReflectUtil.createDispatcher( - RelDecorrelator.class, - RelNode.class); - - // implement RelVisitor - public void visit(RelNode p, int ordinal, RelNode parent) { - // rewrite children first (from left to right) - super.visit(p, ordinal, parent); - - currentRel = p; - - final String visitMethodName = "decorrelateRel"; - boolean found = - dispatcher.invokeVisitor( - RelDecorrelator.this, - currentRel, - visitMethodName); - setCurrent(null, null); - - if (!found) { - decorrelateRelGeneric(p); + static boolean allLessThan(Collection<Integer> integers, int limit, + Litmus ret) { + for (int value : integers) { + if (value >= limit) { + return ret.fail("out of range; value: " + value + ", limit: " + limit); } - // else no rewrite will occur. This will terminate the bottom-up - // rewrite. If root node of a RelNode tree is not rewritten, the - // original tree will be returned. See decorrelate() method. 
} + return ret.succeed(); } + private static RelNode stripHep(RelNode rel) { + if (rel instanceof HepRelVertex) { + HepRelVertex hepRelVertex = (HepRelVertex) rel; + rel = hepRelVertex.getCurrentRel(); + } + return rel; + } + + //~ Inner Classes ---------------------------------------------------------- + /** Shuttle that decorrelates. */ private class DecorrelateRexShuttle extends RexShuttle { - // override RexShuttle - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { - int newInputRelOutputOffset = 0; - RelNode oldInputRel; - RelNode newInputRel; - Integer newInputPos; - - List<RelNode> inputs = currentRel.getInputs(); - for (int i = 0; i < inputs.size(); i++) { - oldInputRel = inputs.get(i); - newInputRel = mapOldToNewRel.get(oldInputRel); - - if ((newInputRel != null) - && mapNewRelToMapCorVarToOutputPos.containsKey(newInputRel)) { - SortedMap<Correlation, Integer> childMapCorVarToOutputPos = - mapNewRelToMapCorVarToOutputPos.get(newInputRel); - - if (childMapCorVarToOutputPos != null) { - // try to find in this input rel the position of cor var - Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess); - - if (corVar != null) { - newInputPos = childMapCorVarToOutputPos.get(corVar); - if (newInputPos != null) { - // this input rel does produce the cor var - // referenced - newInputPos += newInputRelOutputOffset; - - // fieldAccess is assumed to have the correct - // type info. 
- RexInputRef newInput = - new RexInputRef( - newInputPos, - fieldAccess.getType()); - return newInput; - } + @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + int newInputOutputOffset = 0; + for (RelNode input : currentRel.getInputs()) { + final Frame frame = map.get(input); + + if (frame != null) { + // try to find in this input rel the position of cor var + final Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess); + + if (corVar != null) { + Integer newInputPos = frame.corVarOutputPos.get(corVar); + if (newInputPos != null) { + // This input rel does produce the cor var referenced. + // Assume fieldAccess has the correct type info. + return new RexInputRef(newInputPos + newInputOutputOffset, + fieldAccess.getType()); } } // this input rel does not produce the cor var needed - newInputRelOutputOffset += - newInputRel.getRowType().getFieldCount(); + newInputOutputOffset += frame.r.getRowType().getFieldCount(); } else { // this input rel is not rewritten - newInputRelOutputOffset += - oldInputRel.getRowType().getFieldCount(); + newInputOutputOffset += input.getRowType().getFieldCount(); } } return fieldAccess; } - // override RexShuttle - public RexNode visitInputRef(RexInputRef inputRef) { - RexInputRef newInputRef = getNewForOldInputRef(inputRef); - return newInputRef; + @Override public RexNode visitInputRef(RexInputRef inputRef) { + return getNewForOldInputRef(inputRef); } } /** Shuttle that removes correlations. 
*/ private class RemoveCorrelationRexShuttle extends RexShuttle { - RexBuilder rexBuilder; - RelDataTypeFactory typeFactory; - boolean projectPulledAboveLeftCorrelator; - RexInputRef nullIndicator; - Set<Integer> isCount; - - public RemoveCorrelationRexShuttle( - RexBuilder rexBuilder, - boolean projectPulledAboveLeftCorrelator) { - this( - rexBuilder, - projectPulledAboveLeftCorrelator, - null, null); - } - - public RemoveCorrelationRexShuttle( - RexBuilder rexBuilder, - boolean projectPulledAboveLeftCorrelator, - RexInputRef nullIndicator) { - this( - rexBuilder, - projectPulledAboveLeftCorrelator, - nullIndicator, - null); - } - - public RemoveCorrelationRexShuttle( - RexBuilder rexBuilder, - boolean projectPulledAboveLeftCorrelator, - Set<Integer> isCount) { - this( - rexBuilder, - projectPulledAboveLeftCorrelator, - null, isCount); - } + final RexBuilder rexBuilder; + final RelDataTypeFactory typeFactory; + final boolean projectPulledAboveLeftCorrelator; + final RexInputRef nullIndicator; + final ImmutableSet<Integer> isCount; public RemoveCorrelationRexShuttle( RexBuilder rexBuilder, @@ -1551,8 +1392,8 @@ public class RelDecorrelator implements ReflectiveVisitor { Set<Integer> isCount) { this.projectPulledAboveLeftCorrelator = projectPulledAboveLeftCorrelator; - this.nullIndicator = nullIndicator; - this.isCount = isCount; + this.nullIndicator = nullIndicator; // may be null + this.isCount = ImmutableSet.copyOf(isCount); this.rexBuilder = rexBuilder; this.typeFactory = rexBuilder.getTypeFactory(); } @@ -1603,8 +1444,7 @@ public class RelDecorrelator implements ReflectiveVisitor { caseOperands); } - // override RexShuttle - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { if (cm.mapFieldAccessToCorVar.containsKey(fieldAccess)) { // if it is a corVar, change it to be input ref. 
Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess); @@ -1629,15 +1469,14 @@ public class RelDecorrelator implements ReflectiveVisitor { return fieldAccess; } - // override RexShuttle - public RexNode visitInputRef(RexInputRef inputRef) { - if ((currentRel != null) && (currentRel instanceof LogicalCorrelate)) { + @Override public RexNode visitInputRef(RexInputRef inputRef) { + if (currentRel instanceof LogicalCorrelate) { // if this rel references corVar // and now it needs to be rewritten // it must have been pulled above the Correlator // replace the input ref to account for the LHS of the // Correlator - int leftInputFieldCount = + final int leftInputFieldCount = ((LogicalCorrelate) currentRel).getLeft().getRowType() .getFieldCount(); RelDataType newType = inputRef.getType(); @@ -1663,8 +1502,7 @@ public class RelDecorrelator implements ReflectiveVisitor { return inputRef; } - // override RexLiteral - public RexNode visitLiteral(RexLiteral literal) { + @Override public RexNode visitLiteral(RexLiteral literal) { // Use nullIndicator to decide whether to project null. // Do nothing if the literal is null. 
if (!RexUtil.isNull(literal) @@ -1678,7 +1516,7 @@ public class RelDecorrelator implements ReflectiveVisitor { return literal; } - public RexNode visitCall(final RexCall call) { + @Override public RexNode visitCall(final RexCall call) { RexNode newCall; boolean[] update = {false}; @@ -1752,14 +1590,14 @@ public class RelDecorrelator implements ReflectiveVisitor { } public void onMatch(RelOptRuleCall call) { - LogicalAggregate singleAggRel = call.rel(0); - LogicalProject projRel = call.rel(1); - LogicalAggregate aggRel = call.rel(2); + LogicalAggregate singleAggregate = call.rel(0); + LogicalProject project = call.rel(1); + LogicalAggregate aggregate = call.rel(2); // check singleAggRel is single_value agg - if ((!singleAggRel.getGroupSet().isEmpty()) - || (singleAggRel.getAggCallList().size() != 1) - || !(singleAggRel.getAggCallList().get(0).getAggregation() + if ((!singleAggregate.getGroupSet().isEmpty()) + || (singleAggregate.getAggCallList().size() != 1) + || !(singleAggregate.getAggCallList().get(0).getAggregation() instanceof SqlSingleValueAggFunction)) { return; } @@ -1767,21 +1605,21 @@ public class RelDecorrelator implements ReflectiveVisitor { // check projRel only projects one expression // check this project only projects one expression, i.e. scalar // subqueries. - List<RexNode> projExprs = projRel.getProjects(); + List<RexNode> projExprs = project.getProjects(); if (projExprs.size() != 1) { return; } // check the input to projRel is an aggregate on the entire input - if (!aggRel.getGroupSet().isEmpty()) { + if (!aggregate.getGroupSet().isEmpty()) { return; } // singleAggRel produces a nullable type, so create the new // projection that casts proj expr to a nullable type. 
- final RelOptCluster cluster = projRel.getCluster(); - RelNode newProjRel = - RelOptUtil.createProject(aggRel, + final RelOptCluster cluster = project.getCluster(); + RelNode newProject = + RelOptUtil.createProject(aggregate, ImmutableList.of( rexBuilder.makeCast( cluster.getTypeFactory().createTypeWithNullability( @@ -1789,7 +1627,7 @@ public class RelDecorrelator implements ReflectiveVisitor { true), projExprs.get(0))), null); - call.transformTo(newProjRel); + call.transformTo(newProject); } } @@ -1805,14 +1643,14 @@ public class RelDecorrelator implements ReflectiveVisitor { } public void onMatch(RelOptRuleCall call) { - LogicalCorrelate corRel = call.rel(0); - RelNode leftInputRel = call.rel(1); - LogicalAggregate aggRel = call.rel(2); - LogicalProject projRel = call.rel(3); - RelNode rightInputRel = call.rel(4); - RelOptCluster cluster = corRel.getCluster(); + final LogicalCorrelate correlate = call.rel(0); + final RelNode left = call.rel(1); + final LogicalAggregate aggregate = call.rel(2); + final LogicalProject project = call.rel(3); + RelNode right = call.rel(4); + final RelOptCluster cluster = correlate.getCluster(); - setCurrent(call.getPlanner().getRoot(), corRel); + setCurrent(call.getPlanner().getRoot(), correlate); // Check for this pattern. // The pattern matching could be simplified if rules can be applied @@ -1823,7 +1661,7 @@ public class RelDecorrelator implements ReflectiveVisitor { // LogicalAggregate (groupby (0) single_value()) // LogicalProject-A (may reference coVar) // RightInputRel - JoinRelType joinType = corRel.getJoinType().toJoinType(); + final JoinRelType joinType = correlate.getJoinType().toJoinType(); // corRel.getCondition was here, however Correlate was updated so it // never includes a join condition. The code was not modified for brevity. 
@@ -1835,23 +1673,23 @@ public class RelDecorrelator implements ReflectiveVisitor { // check that the agg is of the following type: // doing a single_value() on the entire input - if ((!aggRel.getGroupSet().isEmpty()) - || (aggRel.getAggCallList().size() != 1) - || !(aggRel.getAggCallList().get(0).getAggregation() + if ((!aggregate.getGroupSet().isEmpty()) + || (aggregate.getAggCallList().size() != 1) + || !(aggregate.getAggCallList().get(0).getAggregation() instanceof SqlSingleValueAggFunction)) { return; } // check this project only projects one expression, i.e. scalar // subqueries. - if (projRel.getProjects().size() != 1) { + if (project.getProjects().size() != 1) { return; } int nullIndicatorPos; - if ((rightInputRel instanceof LogicalFilter) - && cm.mapRefRelToCorVar.containsKey(rightInputRel)) { + if ((right instanceof LogicalFilter) + && cm.mapRefRelToCorVar.containsKey(right)) { // rightInputRel has this shape: // // LogicalFilter (references corvar) @@ -1861,14 +1699,14 @@ public class RelDecorrelator implements ReflectiveVisitor { // reference, make sure the correlated keys in the filter // condition forms a unique key of the RHS. 
- LogicalFilter filter = (LogicalFilter) rightInputRel; - rightInputRel = filter.getInput(); + LogicalFilter filter = (LogicalFilter) right; + right = filter.getInput(); - assert rightInputRel instanceof HepRelVertex; - rightInputRel = ((HepRelVertex) rightInputRel).getCurrentRel(); + assert right instanceof HepRelVertex; + right = ((HepRelVertex) right).getCurrentRel(); // check filter input contains no correlation - if (RelOptUtil.getVariablesUsed(rightInputRel).size() > 0) { + if (RelOptUtil.getVariablesUsed(right).size() > 0) { return; } @@ -1889,7 +1727,7 @@ public class RelDecorrelator implements ReflectiveVisitor { // check that the columns referenced in these comparisons form // an unique key of the filterInputRel - List<RexInputRef> rightJoinKeys = new ArrayList<RexInputRef>(); + final List<RexInputRef> rightJoinKeys = new ArrayList<>(); for (RexNode key : tmpRightJoinKeys) { assert key instanceof RexInputRef; rightJoinKeys.add((RexInputRef) key); @@ -1904,11 +1742,11 @@ public class RelDecorrelator implements ReflectiveVisitor { // The join filters out the nulls. So, it's ok if there are // nulls in the join keys. 
if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered( - rightInputRel, + right, rightJoinKeys)) { SQL2REL_LOGGER.fine(rightJoinKeys.toString() + "are not unique keys for " - + rightInputRel.toString()); + + right.toString()); return; } @@ -1918,7 +1756,7 @@ public class RelDecorrelator implements ReflectiveVisitor { List<RexFieldAccess> correlatedKeyList = visitor.getFieldAccessList(); - if (!checkCorVars(corRel, projRel, filter, correlatedKeyList)) { + if (!checkCorVars(correlate, project, filter, correlatedKeyList)) { return; } @@ -1935,15 +1773,15 @@ public class RelDecorrelator implements ReflectiveVisitor { removeCorrelationExpr(filter.getCondition(), false); nullIndicatorPos = - leftInputRel.getRowType().getFieldCount() + left.getRowType().getFieldCount() + rightJoinKeys.get(0).getIndex(); - } else if (cm.mapRefRelToCorVar.containsKey(projRel)) { + } else if (cm.mapRefRelToCorVar.containsKey(project)) { // check filter input contains no correlation - if (RelOptUtil.getVariablesUsed(rightInputRel).size() > 0) { + if (RelOptUtil.getVariablesUsed(right).size() > 0) { return; } - if (!checkCorVars(corRel, projRel, null, null)) { + if (!checkCorVars(correlate, project, null, null)) { return; } @@ -1957,37 +1795,37 @@ public class RelDecorrelator implements ReflectiveVisitor { // ProjInputRel // make the new projRel to provide a null indicator - rightInputRel = - createProjectWithAdditionalExprs(rightInputRel, + right = + createProjectWithAdditionalExprs(right, ImmutableList.of( Pair.<RexNode, String>of( rexBuilder.makeLiteral(true), "nullIndicator"))); // make the new aggRel - rightInputRel = - RelOptUtil.createSingleValueAggRel(cluster, rightInputRel); + right = + RelOptUtil.createSingleValueAggRel(cluster, right); // The last field: // single_value(true) // is the nullIndicator nullIndicatorPos = - leftInputRel.getRowType().getFieldCount() - + rightInputRel.getRowType().getFieldCount() - 1; + left.getRowType().getFieldCount() + + 
right.getRowType().getFieldCount() - 1; } else { return; } // make the new join rel LogicalJoin join = - LogicalJoin.create(leftInputRel, rightInputRel, joinCond, joinType, - ImmutableSet.<String>of()); + LogicalJoin.create(left, right, joinCond, + ImmutableSet.<CorrelationId>of(), joinType); - RelNode newProjRel = - projectJoinOutputWithNullability(join, projRel, nullIndicatorPos); + RelNode newProject = + projectJoinOutputWithNullability(join, project, nullIndicatorPos); - call.transformTo(newProjRel); + call.transformTo(newProject); - removeCorVarFromTree(corRel); + removeCorVarFromTree(correlate); } } @@ -2005,15 +1843,15 @@ public class RelDecorrelator implements ReflectiveVisitor { } public void onMatch(RelOptRuleCall call) { - LogicalCorrelate corRel = call.rel(0); - RelNode leftInputRel = call.rel(1); - LogicalProject aggOutputProjRel = call.rel(2); - LogicalAggregate aggRel = call.rel(3); - LogicalProject aggInputProjRel = call.rel(4); - RelNode rightInputRel = call.rel(5); - RelOptCluster cluster = corRel.getCluster(); + final LogicalCorrelate correlate = call.rel(0); + final RelNode left = call.rel(1); + final LogicalProject aggOutputProject = call.rel(2); + final LogicalAggregate aggregate = call.rel(3); + final LogicalProject aggInputProject = call.rel(4); + RelNode right = call.rel(5); + final RelOptCluster cluster = correlate.getCluster(); - setCurrent(call.getPlanner().getRoot(), corRel); + setCurrent(call.getPlanner().getRoot(), correlate); // check for this pattern // The pattern matching could be simplified if rules can be applied @@ -2026,13 +1864,13 @@ public class RelDecorrelator implements ReflectiveVisitor { // LogicalProject-B (references coVar) // rightInputRel - // check aggOutputProj projects only one expression - List<RexNode> aggOutputProjExprs = aggOutputProjRel.getProjects(); - if (aggOutputProjExprs.size() != 1) { + // check aggOutputProject projects only one expression + final List<RexNode> aggOutputProjects = 
aggOutputProject.getProjects(); + if (aggOutputProjects.size() != 1) { return; } - JoinRelType joinType = corRel.getJoinType().toJoinType(); + final JoinRelType joinType = correlate.getJoinType().toJoinType(); // corRel.getCondition was here, however Correlate was updated so it // never includes a join condition. The code was not modified for brevity. RexNode joinCond = rexBuilder.makeLiteral(true); @@ -2042,14 +1880,14 @@ public class RelDecorrelator implements ReflectiveVisitor { } // check that the agg is on the entire input - if (!aggRel.getGroupSet().isEmpty()) { + if (!aggregate.getGroupSet().isEmpty()) { return; } - List<RexNode> aggInputProjExprs = aggInputProjRel.getProjects(); + final List<RexNode> aggInputProjects = aggInputProject.getProjects(); - List<AggregateCall> aggCalls = aggRel.getAggCallList(); - Set<Integer> isCountStar = Sets.newHashSet(); + final List<AggregateCall> aggCalls = aggregate.getAggCallList(); + final Set<Integer> isCountStar = Sets.newHashSet(); // mark if agg produces count(*) which needs to reference the // nullIndicator after the transformation. 
@@ -2062,20 +1900,20 @@ public class RelDecorrelator implements ReflectiveVisitor { } } - if ((rightInputRel instanceof LogicalFilter) - && cm.mapRefRelToCorVar.containsKey(rightInputRel)) { + if ((right instanceof LogicalFilter) + && cm.mapRefRelToCorVar.containsKey(right)) { // rightInputRel has this shape: // // LogicalFilter (references corvar) // FilterInputRel - LogicalFilter filter = (LogicalFilter) rightInputRel; - rightInputRel = filter.getInput(); + LogicalFilter filter = (LogicalFilter) right; + right = filter.getInput(); - assert rightInputRel instanceof HepRelVertex; - rightInputRel = ((HepRelVertex) rightInputRel).getCurrentRel(); + assert right instanceof HepRelVertex; + right = ((HepRelVertex) right).getCurrentRel(); // check filter input contains no correlation - if (RelOptUtil.getVariablesUsed(rightInputRel).size() > 0) { + if (RelOptUtil.getVariablesUsed(right).size() > 0) { return; } @@ -2119,17 +1957,17 @@ public class RelDecorrelator implements ReflectiveVisitor { // The join filters out the nulls. So, it's ok if there are // nulls in the join keys. if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered( - leftInputRel, + left, correlatedInputRefJoinKeys)) { SQL2REL_LOGGER.fine(correlatedJoinKeys.toString() + "are not unique keys for " - + leftInputRel.toString()); + + left.toString()); return; } // check cor var references are valid - if (!checkCorVars(corRel, - aggInputProjRel, + if (!checkCorVars(correlate, + aggInputProject, filter, correlatedJoinKeys)) { return; @@ -2180,27 +2018,27
<TRUNCATED>
