[CALCITE-1456] Change SubstitutionVisitor to use generic RelBuilder instead of Logical instances of the operators when possible
Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/e9d0ca67 Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/e9d0ca67 Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/e9d0ca67 Branch: refs/heads/master Commit: e9d0ca6731b2f5ec33b9270b5ffcaaf9e4eb8537 Parents: 9a691a7 Author: Jesus Camacho Rodriguez <[email protected]> Authored: Wed Apr 26 19:09:19 2017 +0100 Committer: Jesus Camacho Rodriguez <[email protected]> Committed: Wed Apr 26 20:05:03 2017 +0100 ---------------------------------------------------------------------- .../MaterializedViewSubstitutionVisitor.java | 6 ++ .../calcite/plan/SubstitutionVisitor.java | 32 +++++--- .../apache/calcite/rel/mutable/MutableRels.java | 80 +++++++++++--------- 3 files changed, 75 insertions(+), 43 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/e9d0ca67/core/src/main/java/org/apache/calcite/plan/MaterializedViewSubstitutionVisitor.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/plan/MaterializedViewSubstitutionVisitor.java b/core/src/main/java/org/apache/calcite/plan/MaterializedViewSubstitutionVisitor.java index c1e0e37..a2ff5f4 100644 --- a/core/src/main/java/org/apache/calcite/plan/MaterializedViewSubstitutionVisitor.java +++ b/core/src/main/java/org/apache/calcite/plan/MaterializedViewSubstitutionVisitor.java @@ -25,6 +25,7 @@ import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.tools.RelBuilderFactory; import com.google.common.collect.ImmutableList; @@ -46,6 +47,11 @@ public class MaterializedViewSubstitutionVisitor extends SubstitutionVisitor { super(target_, query_, EXTENDED_RULES); } + public MaterializedViewSubstitutionVisitor(RelNode target_, RelNode query_, + RelBuilderFactory relBuilderFactory) { + super(target_, query_, EXTENDED_RULES, relBuilderFactory); + } + public List<RelNode> go(RelNode replacement_) { return super.go(replacement_); } http://git-wip-us.apache.org/repos/asf/calcite/blob/e9d0ca67/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java b/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java index f7bf106..fca0c56 100644 --- a/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java +++ b/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java @@ -21,6 +21,7 @@ import org.apache.calcite.prepare.CalcitePrepareImpl; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.mutable.Holder; @@ -46,6 +47,8 @@ import org.apache.calcite.rex.RexUtil; import org.apache.calcite.runtime.PredicateImpl; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.Bug; import org.apache.calcite.util.ControlFlowException; import org.apache.calcite.util.ImmutableBitSet; @@ -127,6 +130,11 @@ public class SubstitutionVisitor { AggregateToAggregateUnifyRule.INSTANCE, AggregateOnProjectToAggregateUnifyRule.INSTANCE); + /** + * Factory for a builder for relational expressions. + */ + protected final RelBuilder relBuilder; + private final ImmutableList<UnifyRule> rules; private final Map<Pair<Class, Class>, List<UnifyRule>> ruleMap = new HashMap<>(); @@ -157,12 +165,17 @@ public class SubstitutionVisitor { /** Creates a SubstitutionVisitor with the default rule set. */ public SubstitutionVisitor(RelNode target_, RelNode query_) { - this(target_, query_, DEFAULT_RULES); + this(target_, query_, DEFAULT_RULES, RelFactories.LOGICAL_BUILDER); } - /** Creates a SubstitutionVisitor. */ + /** Creates a SubstitutionVisitor with the default logical builder. */ public SubstitutionVisitor(RelNode target_, RelNode query_, ImmutableList<UnifyRule> rules) { + this(target_, query_, rules, RelFactories.LOGICAL_BUILDER); + } + + public SubstitutionVisitor(RelNode target_, RelNode query_, + ImmutableList<UnifyRule> rules, RelBuilderFactory relBuilderFactory) { this.cluster = target_.getCluster(); final RexExecutor executor = Util.first(cluster.getPlanner().getExecutor(), RexUtil.EXECUTOR); @@ -170,6 +183,7 @@ public class SubstitutionVisitor { this.rules = rules; this.query = Holder.of(MutableRels.toMutable(query_)); this.target = MutableRels.toMutable(target_); + this.relBuilder = relBuilderFactory.create(cluster, null); final Set<MutableRel> parents = Sets.newIdentityHashSet(); final List<MutableRel> allNodes = new ArrayList<>(); final MutableRelVisitor visitor = @@ -395,7 +409,7 @@ public class SubstitutionVisitor { + "\nnode:\n" + node.deep()); } - return MutableRels.fromMutable(node); + return MutableRels.fromMutable(node, relBuilder); } /** @@ -412,8 +426,8 @@ public class SubstitutionVisitor { return ImmutableList.of(); } List<RelNode> sub = Lists.newArrayList(); - sub.add(MutableRels.fromMutable(query.getInput())); - reverseSubstitute(query, matches, sub, 0, matches.size()); + sub.add(MutableRels.fromMutable(query.getInput(), relBuilder)); + reverseSubstitute(relBuilder, query, matches, sub, 0, matches.size()); return sub; } @@ -594,19 +608,19 @@ public class SubstitutionVisitor { } } - private static void reverseSubstitute(Holder query, + private static void reverseSubstitute(RelBuilder relBuilder, Holder query, List<List<Replacement>> matches, List<RelNode> sub, int replaceCount, int maxCount) { if (matches.isEmpty()) { return; } final List<List<Replacement>> rem = matches.subList(1, matches.size()); - reverseSubstitute(query, rem, sub, replaceCount, maxCount); + reverseSubstitute(relBuilder, query, rem, sub, replaceCount, maxCount); undoReplacement(matches.get(0)); if (++replaceCount < maxCount) { - sub.add(MutableRels.fromMutable(query.getInput())); + sub.add(MutableRels.fromMutable(query.getInput(), relBuilder)); } - reverseSubstitute(query, rem, sub, replaceCount, maxCount); + reverseSubstitute(relBuilder, query, rem, sub, replaceCount, maxCount); redoReplacement(matches.get(0)); } http://git-wip-us.apache.org/repos/asf/calcite/blob/e9d0ca67/core/src/main/java/org/apache/calcite/rel/mutable/MutableRels.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rel/mutable/MutableRels.java b/core/src/main/java/org/apache/calcite/rel/mutable/MutableRels.java index ed1f1b1..b07da0e 100644 --- a/core/src/main/java/org/apache/calcite/rel/mutable/MutableRels.java +++ b/core/src/main/java/org/apache/calcite/rel/mutable/MutableRels.java @@ -30,6 +30,7 @@ import org.apache.calcite.rel.core.Intersect; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.Minus; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.core.Sample; import org.apache.calcite.rel.core.SemiJoin; import org.apache.calcite.rel.core.Sort; @@ -40,24 +41,18 @@ import org.apache.calcite.rel.core.Uncollect; import org.apache.calcite.rel.core.Union; import org.apache.calcite.rel.core.Values; import org.apache.calcite.rel.core.Window; -import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalCalc; import org.apache.calcite.rel.logical.LogicalCorrelate; import org.apache.calcite.rel.logical.LogicalExchange; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalIntersect; -import org.apache.calcite.rel.logical.LogicalJoin; -import org.apache.calcite.rel.logical.LogicalMinus; -import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.logical.LogicalSort; import org.apache.calcite.rel.logical.LogicalTableFunctionScan; import org.apache.calcite.rel.logical.LogicalTableModify; -import org.apache.calcite.rel.logical.LogicalUnion; import org.apache.calcite.rel.logical.LogicalWindow; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.Util; import org.apache.calcite.util.mapping.Mappings; @@ -177,97 +172,114 @@ public abstract class MutableRels { } public static RelNode fromMutable(MutableRel node) { + return fromMutable(node, RelFactories.LOGICAL_BUILDER.create(node.cluster, null)); + } + + public static RelNode fromMutable(MutableRel node, RelBuilder relBuilder) { switch (node.type) { case TABLE_SCAN: case VALUES: return ((MutableLeafRel) node).rel; case PROJECT: final MutableProject project = (MutableProject) node; - return LogicalProject.create( - fromMutable(project.input), project.projects, project.rowType); + relBuilder.push(fromMutable(project.input, relBuilder)); + relBuilder.project(project.projects, project.rowType.getFieldNames(), true); + return relBuilder.build(); case FILTER: final MutableFilter filter = (MutableFilter) node; - return LogicalFilter.create(fromMutable(filter.input), - filter.condition); + relBuilder.push(fromMutable(filter.input, relBuilder)); + relBuilder.filter(filter.condition); + return relBuilder.build(); case AGGREGATE: final MutableAggregate aggregate = (MutableAggregate) node; - return LogicalAggregate.create(fromMutable(aggregate.input), - aggregate.indicator, aggregate.groupSet, aggregate.groupSets, + relBuilder.push(fromMutable(aggregate.input, relBuilder)); + relBuilder.aggregate( + relBuilder.groupKey(aggregate.groupSet, aggregate.indicator, aggregate.groupSets), aggregate.aggCalls); + return relBuilder.build(); case SORT: final MutableSort sort = (MutableSort) node; - return LogicalSort.create(fromMutable(sort.input), sort.collation, + return LogicalSort.create(fromMutable(sort.input, relBuilder), sort.collation, sort.offset, sort.fetch); case CALC: final MutableCalc calc = (MutableCalc) node; - return LogicalCalc.create(fromMutable(calc.input), calc.program); + return LogicalCalc.create(fromMutable(calc.input, relBuilder), calc.program); case EXCHANGE: final MutableExchange exchange = (MutableExchange) node; return LogicalExchange.create( - fromMutable(exchange.getInput()), exchange.distribution); + fromMutable(exchange.getInput(), relBuilder), exchange.distribution); case COLLECT: { final MutableCollect collect = (MutableCollect) node; - final RelNode child = fromMutable(collect.getInput()); + final RelNode child = fromMutable(collect.getInput(), relBuilder); return new Collect(collect.cluster, child.getTraitSet(), child, collect.fieldName); } case UNCOLLECT: { final MutableUncollect uncollect = (MutableUncollect) node; - final RelNode child = fromMutable(uncollect.getInput()); + final RelNode child = fromMutable(uncollect.getInput(), relBuilder); return Uncollect.create(child.getTraitSet(), child, uncollect.withOrdinality); } case WINDOW: { final MutableWindow window = (MutableWindow) node; - final RelNode child = fromMutable(window.getInput()); + final RelNode child = fromMutable(window.getInput(), relBuilder); return LogicalWindow.create(child.getTraitSet(), child, window.constants, window.rowType, window.groups); } case TABLE_MODIFY: final MutableTableModify modify = (MutableTableModify) node; return LogicalTableModify.create(modify.table, modify.catalogReader, - fromMutable(modify.getInput()), modify.operation, modify.updateColumnList, + fromMutable(modify.getInput(), relBuilder), modify.operation, modify.updateColumnList, modify.sourceExpressionList, modify.flattened); case SAMPLE: final MutableSample sample = (MutableSample) node; - return new Sample(sample.cluster, fromMutable(sample.getInput()), sample.params); + return new Sample(sample.cluster, fromMutable(sample.getInput(), relBuilder), sample.params); case TABLE_FUNCTION_SCAN: final MutableTableFunctionScan tableFunctionScan = (MutableTableFunctionScan) node; return LogicalTableFunctionScan.create(tableFunctionScan.cluster, - fromMutables(tableFunctionScan.getInputs()), tableFunctionScan.rexCall, + fromMutables(tableFunctionScan.getInputs(), relBuilder), tableFunctionScan.rexCall, tableFunctionScan.elementType, tableFunctionScan.rowType, tableFunctionScan.columnMappings); case JOIN: final MutableJoin join = (MutableJoin) node; - return LogicalJoin.create(fromMutable(join.getLeft()), fromMutable(join.getRight()), - join.condition, join.variablesSet, join.joinType); + relBuilder.push(fromMutable(join.getLeft(), relBuilder)); + relBuilder.push(fromMutable(join.getRight(), relBuilder)); + relBuilder.join(join.joinType, join.condition, join.variablesSet); + return relBuilder.build(); case SEMIJOIN: final MutableSemiJoin semiJoin = (MutableSemiJoin) node; - return SemiJoin.create(fromMutable(semiJoin.getLeft()), - fromMutable(semiJoin.getRight()), semiJoin.condition, - semiJoin.leftKeys, semiJoin.rightKeys); + relBuilder.push(fromMutable(semiJoin.getLeft(), relBuilder)); + relBuilder.push(fromMutable(semiJoin.getRight(), relBuilder)); + relBuilder.semiJoin(semiJoin.condition); + return relBuilder.build(); case CORRELATE: final MutableCorrelate correlate = (MutableCorrelate) node; - return LogicalCorrelate.create(fromMutable(correlate.getLeft()), - fromMutable(correlate.getRight()), correlate.correlationId, + return LogicalCorrelate.create(fromMutable(correlate.getLeft(), relBuilder), + fromMutable(correlate.getRight(), relBuilder), correlate.correlationId, correlate.requiredColumns, correlate.joinType); case UNION: final MutableUnion union = (MutableUnion) node; - return LogicalUnion.create(MutableRels.fromMutables(union.inputs), union.all); + relBuilder.pushAll(MutableRels.fromMutables(union.inputs, relBuilder)); + relBuilder.union(union.all, union.inputs.size()); + return relBuilder.build(); case MINUS: final MutableMinus minus = (MutableMinus) node; - return LogicalMinus.create(MutableRels.fromMutables(minus.inputs), minus.all); + relBuilder.pushAll(MutableRels.fromMutables(minus.inputs, relBuilder)); + relBuilder.minus(minus.all, minus.inputs.size()); + return relBuilder.build(); case INTERSECT: final MutableIntersect intersect = (MutableIntersect) node; - return LogicalIntersect.create(MutableRels.fromMutables(intersect.inputs), intersect.all); + relBuilder.pushAll(MutableRels.fromMutables(intersect.inputs, relBuilder)); + relBuilder.intersect(intersect.all, intersect.inputs.size()); + return relBuilder.build(); default: throw new AssertionError(node.deep()); } } - private static List<RelNode> fromMutables(List<MutableRel> nodes) { + private static List<RelNode> fromMutables(List<MutableRel> nodes, final RelBuilder relBuilder) { return Lists.transform(nodes, new Function<MutableRel, RelNode>() { public RelNode apply(MutableRel mutableRel) { - return fromMutable(mutableRel); + return fromMutable(mutableRel, relBuilder); } }); }
