Simplify RexProgram, in particular "(NOT CASE ... END) IS TRUE", which occurs in NOT IN
Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/7837e64c Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/7837e64c Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/7837e64c Branch: refs/heads/master Commit: 7837e64c3294aa776ca38cba3f756550af3494ab Parents: 4762b88 Author: Julian Hyde <[email protected]> Authored: Wed Aug 19 16:25:30 2015 -0700 Committer: Julian Hyde <[email protected]> Committed: Sun Jan 10 00:51:24 2016 -0800 ---------------------------------------------------------------------- .../adapter/enumerable/EnumerableCalc.java | 3 + .../calcite/adapter/enumerable/RexImpTable.java | 2 +- .../calcite/rel/rules/ProjectToWindowRule.java | 4 +- .../rel/rules/ReduceExpressionsRule.java | 2 +- .../java/org/apache/calcite/rex/RexCall.java | 10 ++ .../java/org/apache/calcite/rex/RexProgram.java | 41 ++++--- .../apache/calcite/rex/RexProgramBuilder.java | 111 +++++++++-------- .../java/org/apache/calcite/rex/RexUtil.java | 118 ++++++++++++++++--- .../java/org/apache/calcite/sql/SqlKind.java | 26 ++++ .../java/org/apache/calcite/sql/SqlMerge.java | 9 +- .../main/java/org/apache/calcite/util/Pair.java | 8 +- .../org/apache/calcite/test/RexProgramTest.java | 117 ++++++++++++++++-- 12 files changed, 340 insertions(+), 111 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCalc.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCalc.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCalc.java index 603c7b1..ce1f642 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCalc.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCalc.java @@ -144,6 +144,9 @@ public class EnumerableCalc extends Calc implements EnumerableRel { BuiltInMethod.ENUMERATOR_CURRENT.method), inputJavaType); + final RexProgram program = + this.program.normalize(getCluster().getRexBuilder(), true); + BlockStatement moveNextBody; if (program.getCondition() == null) { moveNextBody = http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 20872b6..511584b 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -1887,7 +1887,7 @@ public class RexImpTable { negate == seek, translator.translate( operands.get(0), - seek ? NullAs.FALSE : NullAs.TRUE)); + negate == seek ? NullAs.TRUE : NullAs.FALSE)); } } } http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/rel/rules/ProjectToWindowRule.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rel/rules/ProjectToWindowRule.java b/core/src/main/java/org/apache/calcite/rel/rules/ProjectToWindowRule.java index fe6334e..ed4f610 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/ProjectToWindowRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/ProjectToWindowRule.java @@ -35,7 +35,6 @@ import org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexOver; import org.apache.calcite.rex.RexProgram; -import org.apache.calcite.rex.RexProgramBuilder; import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.rex.RexWindow; import org.apache.calcite.tools.RelBuilder; @@ -199,8 +198,7 @@ public abstract class ProjectToWindowRule extends RelOptRule { protected RelNode makeRel(RelOptCluster cluster, RelTraitSet traitSet, RelBuilder relBuilder, RelNode input, RexProgram program) { assert !program.containsAggs(); - program = RexProgramBuilder.normalize(cluster.getRexBuilder(), - program); + program = program.normalize(cluster.getRexBuilder(), false); return super.makeRel(cluster, traitSet, relBuilder, input, program); } }, http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/rel/rules/ReduceExpressionsRule.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rel/rules/ReduceExpressionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/ReduceExpressionsRule.java index 17d5cb0..759748c 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/ReduceExpressionsRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/ReduceExpressionsRule.java @@ -663,7 +663,7 @@ public abstract class ReduceExpressionsRule extends RelOptRule { * <p>We have a loose definition of 'predicate': any boolean expression will * do, except CASE. For example '(CASE ...) = 5' or '(CASE ...) IS NULL'. */ - protected static RexCall pushPredicateIntoCase(RexCall call) { + public static RexCall pushPredicateIntoCase(RexCall call) { if (call.getType().getSqlTypeName() != SqlTypeName.BOOLEAN) { return call; } http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/rex/RexCall.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rex/RexCall.java b/core/src/main/java/org/apache/calcite/rex/RexCall.java index 8271005..7cf8255 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexCall.java +++ b/core/src/main/java/org/apache/calcite/rex/RexCall.java @@ -118,6 +118,11 @@ public class RexCall extends RexNode { switch (getKind()) { case IS_NOT_NULL: return !operands.get(0).getType().isNullable(); + case IS_NOT_FALSE: + case NOT: + return operands.get(0).isAlwaysFalse(); + case IS_NOT_TRUE: + case IS_FALSE: case CAST: return operands.get(0).isAlwaysTrue(); default: @@ -129,6 +134,11 @@ public class RexCall extends RexNode { switch (getKind()) { case IS_NULL: return !operands.get(0).getType().isNullable(); + case IS_NOT_TRUE: + case NOT: + return operands.get(0).isAlwaysTrue(); + case IS_NOT_FALSE: + case IS_TRUE: case CAST: return operands.get(0).isAlwaysFalse(); default: http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/rex/RexProgram.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rex/RexProgram.java b/core/src/main/java/org/apache/calcite/rex/RexProgram.java index 58cdc40..664c092 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexProgram.java +++ b/core/src/main/java/org/apache/calcite/rex/RexProgram.java @@ -42,6 +42,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Set; /** * A collection of expressions which read inputs, compute output expressions, @@ -351,8 +352,8 @@ public class RexProgram { "field type mismatch: " + rowType + " vs. " + outputRowType); } final List<RelDataTypeField> fields = rowType.getFieldList(); - final List<RexLocalRef> projectRefs = new ArrayList<RexLocalRef>(); - final List<RexInputRef> refs = new ArrayList<RexInputRef>(); + final List<RexLocalRef> projectRefs = new ArrayList<>(); + final List<RexInputRef> refs = new ArrayList<>(); for (int i = 0; i < fields.size(); i++) { final RexInputRef ref = RexInputRef.of(i, fields); refs.add(ref); @@ -462,14 +463,14 @@ public class RexProgram { return litmus.fail(null); } } - for (int i = 0; i < projects.size(); i++) { - projects.get(i).accept(checker); + for (RexLocalRef project : projects) { + project.accept(checker); if (checker.failCount > 0) { return litmus.fail(null); } } - for (int i = 0; i < exprs.size(); i++) { - exprs.get(i).accept(checker); + for (RexNode expr : exprs) { + expr.accept(checker); if (checker.failCount > 0) { return litmus.fail(null); } @@ -511,10 +512,7 @@ public class RexProgram { * @return expanded form */ public RexNode expandLocalRef(RexLocalRef ref) { - // TODO jvs 19-Apr-2006: assert that ref is part of - // this program - ExpansionShuttle shuttle = new ExpansionShuttle(); - return ref.accept(shuttle); + return ref.accept(new ExpansionShuttle(exprs)); } /** Splits this program into a list of project expressions and a list of @@ -540,7 +538,7 @@ public class RexProgram { * mutable. */ public List<RelCollation> getCollations(List<RelCollation> inputCollations) { - List<RelCollation> outputCollations = new ArrayList<RelCollation>(1); + List<RelCollation> outputCollations = new ArrayList<>(1); deduceCollations( outputCollations, inputRowType.getFieldCount(), projects, @@ -568,8 +566,7 @@ public class RexProgram { } loop: for (RelCollation collation : inputCollations) { - final ArrayList<RelFieldCollation> fieldCollations = - new ArrayList<RelFieldCollation>(0); + final List<RelFieldCollation> fieldCollations = new ArrayList<>(0); for (RelFieldCollation fieldCollation : collation.getFieldCollations()) { final int source = fieldCollation.getFieldIndex(); final int target = targets[source]; @@ -741,8 +738,8 @@ public class RexProgram { * * @return set of correlation variable names */ - public HashSet<String> getCorrelVariableNames() { - final HashSet<String> paramIdSet = new HashSet<String>(); + public Set<String> getCorrelVariableNames() { + final Set<String> paramIdSet = new HashSet<>(); RexUtil.apply( new RexVisitorImpl<Void>(true) { public Void visitCorrelVariable( @@ -790,7 +787,7 @@ public class RexProgram { assert isValid(Litmus.THROW); final RexProgramBuilder builder = RexProgramBuilder.create(rexBuilder, inputRowType, exprs, projects, - condition, outputRowType, simplify); + condition, outputRowType, true, simplify); return builder.getProgram(false); } @@ -840,9 +837,15 @@ public class RexProgram { * A RexShuttle used in the implementation of * {@link RexProgram#expandLocalRef}. */ - private class ExpansionShuttle extends RexShuttle { + static class ExpansionShuttle extends RexShuttle { + private final List<RexNode> exprs; + + public ExpansionShuttle(List<RexNode> exprs) { + this.exprs = exprs; + } + public RexNode visitLocalRef(RexLocalRef localRef) { - RexNode tree = getExprList().get(localRef.getIndex()); + RexNode tree = exprs.get(localRef.getIndex()); return tree.accept(this); } } @@ -886,7 +889,7 @@ public class RexProgram { } public RexNode visitCall(RexCall call) { - final List<RexNode> newOperands = new ArrayList<RexNode>(); + final List<RexNode> newOperands = new ArrayList<>(); for (RexNode operand : call.getOperands()) { newOperands.add(operand.accept(this)); } http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/rex/RexProgramBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rex/RexProgramBuilder.java b/core/src/main/java/org/apache/calcite/rex/RexProgramBuilder.java index c292e56..5f7d3ef 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexProgramBuilder.java +++ b/core/src/main/java/org/apache/calcite/rex/RexProgramBuilder.java @@ -41,13 +41,12 @@ public class RexProgramBuilder { private final RexBuilder rexBuilder; private final RelDataType inputRowType; - private final List<RexNode> exprList = new ArrayList<RexNode>(); + private final List<RexNode> exprList = new ArrayList<>(); private final Map<Pair<String, String>, RexLocalRef> exprMap = - new HashMap<Pair<String, String>, RexLocalRef>(); - private final List<RexLocalRef> localRefList = new ArrayList<RexLocalRef>(); - private final List<RexLocalRef> projectRefList = - new ArrayList<RexLocalRef>(); - private final List<String> projectNameList = new ArrayList<String>(); + new HashMap<>(); + private final List<RexLocalRef> localRefList = new ArrayList<>(); + private final List<RexLocalRef> projectRefList = new ArrayList<>(); + private final List<String> projectNameList = new ArrayList<>(); private RexLocalRef conditionRef = null; private boolean validating; @@ -78,19 +77,21 @@ public class RexProgramBuilder { * @param rexBuilder Rex builder * @param inputRowType Input row type * @param exprList Common expressions - * @param projectRefList Projections - * @param conditionRef Condition, or null + * @param projectList Projections + * @param condition Condition, or null * @param outputRowType Output row type * @param normalize Whether to normalize + * @param simplify Whether to simplify */ private RexProgramBuilder( RexBuilder rexBuilder, final RelDataType inputRowType, final List<RexNode> exprList, - final List<RexLocalRef> projectRefList, - final RexLocalRef conditionRef, + final Iterable<? extends RexNode> projectList, + RexNode condition, final RelDataType outputRowType, - boolean normalize) { + boolean normalize, + boolean simplify) { this(inputRowType, rexBuilder); // Create a shuttle for registering input expressions. @@ -106,24 +107,38 @@ public class RexProgramBuilder { } } + final RexShuttle expander = new RexProgram.ExpansionShuttle(exprList); + // Register project expressions // and create a named project item. final List<RelDataTypeField> fieldList = outputRowType.getFieldList(); - for (Pair<RexLocalRef, RelDataTypeField> pair - : Pair.zip(projectRefList, fieldList)) { - final RexLocalRef projectRef = pair.left; + for (Pair<? extends RexNode, RelDataTypeField> pair + : Pair.zip(projectList, fieldList)) { + final RexNode project; + if (simplify) { + project = RexUtil.simplify(rexBuilder, pair.left.accept(expander)); + } else { + project = pair.left; + } final String name = pair.right.getName(); - final int oldIndex = projectRef.getIndex(); - final RexNode expr = exprList.get(oldIndex); - final RexLocalRef ref = (RexLocalRef) expr.accept(shuttle); + final RexLocalRef ref = (RexLocalRef) project.accept(shuttle); addProject(ref.getIndex(), name); } // Register the condition, if there is one. - if (conditionRef != null) { - final RexNode expr = exprList.get(conditionRef.getIndex()); - final RexLocalRef ref = (RexLocalRef) expr.accept(shuttle); - addCondition(ref); + if (condition != null) { + if (simplify) { + condition = RexUtil.simplify(rexBuilder, + rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, + condition.accept(expander))); + if (condition.isAlwaysTrue()) { + condition = null; + } + } + if (condition != null) { + final RexLocalRef ref = (RexLocalRef) condition.accept(shuttle); + addCondition(ref); + } } } @@ -463,7 +478,8 @@ public class RexProgramBuilder { projectRefs, conditionRef, outputRowType, - normalize); + normalize, + false); } /** @@ -494,28 +510,37 @@ public class RexProgramBuilder { * @param rexBuilder Rex builder * @param inputRowType Input row type * @param exprList Common expressions - * @param projectRefList Projections - * @param conditionRef Condition, or null + * @param projectList Projections + * @param condition Condition, or null * @param outputRowType Output row type * @param normalize Whether to normalize + * @param simplify Whether to simplify expressions * @return A program builder */ public static RexProgramBuilder create( RexBuilder rexBuilder, final RelDataType inputRowType, final List<RexNode> exprList, - final List<RexLocalRef> projectRefList, - final RexLocalRef conditionRef, + final List<? extends RexNode> projectList, + final RexNode condition, + final RelDataType outputRowType, + boolean normalize, + boolean simplify) { + return new RexProgramBuilder(rexBuilder, inputRowType, exprList, + projectList, condition, outputRowType, normalize, simplify); + } + + @Deprecated // to be removed before 2.0 + public static RexProgramBuilder create( + RexBuilder rexBuilder, + final RelDataType inputRowType, + final List<RexNode> exprList, + final List<? extends RexNode> projectList, + final RexNode condition, final RelDataType outputRowType, boolean normalize) { - return new RexProgramBuilder( - rexBuilder, - inputRowType, - exprList, - projectRefList, - conditionRef, - outputRowType, - normalize); + return create(rexBuilder, inputRowType, exprList, projectList, condition, + outputRowType, normalize, false); } /** @@ -557,21 +582,11 @@ public class RexProgramBuilder { return progBuilder; } - /** - * Normalizes a program. - * - * @param rexBuilder Rex builder - * @param program Program - * @return Normalized program - */ + @Deprecated // to be removed before 2.0 public static RexProgram normalize( RexBuilder rexBuilder, RexProgram program) { - // Normalize program by creating program builder from the program, then - // converting to a program. getProgram does not need to normalize - // because the builder was normalized on creation. - return forProgram(program, rexBuilder, true) - .getProgram(false); + return program.normalize(rexBuilder, false); } /** @@ -601,7 +616,7 @@ public class RexProgramBuilder { // register the result. // REVIEW jpham 28-Apr-2006: if the user shuttle rewrites an input // expression, then input references may change - List<RexLocalRef> newRefs = new ArrayList<RexLocalRef>(exprList.size()); + List<RexLocalRef> newRefs = new ArrayList<>(exprList.size()); RexShuttle refShuttle = new UpdateRefShuttle(newRefs); int i = 0; for (RexNode expr : exprList) { @@ -753,7 +768,7 @@ public class RexProgramBuilder { private List<RexLocalRef> registerProjectsAndCondition(RexProgram program) { final List<RexNode> exprList = program.getExprList(); - final List<RexLocalRef> projectRefList = new ArrayList<RexLocalRef>(); + final List<RexLocalRef> projectRefList = new ArrayList<>(); final RexShuttle shuttle = new RegisterOutputShuttle(exprList); // For each project, lookup the expr and expand it so it is in terms of http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/rex/RexUtil.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rex/RexUtil.java b/core/src/main/java/org/apache/calcite/rex/RexUtil.java index 8e8a79c..b2de640 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexUtil.java +++ b/core/src/main/java/org/apache/calcite/rex/RexUtil.java @@ -1087,9 +1087,10 @@ public class RexUtil { /** Converts an expression to disjunctive normal form (DNF). * - * <p>DNF: It is a form of logical formula which is disjunction of conjunctive clauses</p> + * <p>DNF: It is a form of logical formula which is disjunction of conjunctive + * clauses. * - * <p>All logicl formulas can be converted into DNF.</p> + * <p>All logical formulas can be converted into DNF. * * <p>The following expression is in DNF: * @@ -1280,19 +1281,100 @@ public class RexUtil { return simplifyAnd(rexBuilder, (RexCall) e); case OR: return simplifyOr(rexBuilder, (RexCall) e); + case NOT: + return simplifyNot(rexBuilder, (RexCall) e); case CASE: return simplifyCase(rexBuilder, (RexCall) e); + } + switch (e.getKind()) { case IS_NULL: - return ((RexCall) e).getOperands().get(0).getType().isNullable() - ? e : rexBuilder.makeLiteral(false); case IS_NOT_NULL: - return ((RexCall) e).getOperands().get(0).getType().isNullable() - ? e : rexBuilder.makeLiteral(true); + case IS_TRUE: + case IS_NOT_TRUE: + case IS_FALSE: + case IS_NOT_FALSE: + assert e instanceof RexCall; + return simplifyIs(rexBuilder, (RexCall) e); default: return e; } } + private static RexNode simplifyNot(RexBuilder rexBuilder, RexCall call) { + final RexNode a = call.getOperands().get(0); + switch (a.getKind()) { + case NOT: + // NOT NOT x ==> x + return simplify(rexBuilder, ((RexCall) a).getOperands().get(0)); + } + final SqlKind negateKind = a.getKind().negate(); + if (a.getKind() != negateKind) { + return simplify(rexBuilder, + rexBuilder.makeCall(op(negateKind), + ImmutableList.of(((RexCall) a).getOperands().get(0)))); + } + return call; + } + + private static RexNode simplifyIs(RexBuilder rexBuilder, RexCall call) { + final SqlKind kind = call.getKind(); + final RexNode a = call.getOperands().get(0); + if (!a.getType().isNullable()) { + switch (kind) { + case IS_NULL: + case IS_NOT_NULL: + // x IS NULL ==> FALSE (if x is not nullable) + // x IS NOT NULL ==> TRUE (if x is not nullable) + return rexBuilder.makeLiteral(kind == SqlKind.IS_NOT_NULL); + case IS_TRUE: + case IS_NOT_FALSE: + // x IS TRUE ==> x (if x is not nullable) + // x IS NOT FALSE ==> x (if x is not nullable) + return simplify(rexBuilder, a); + case IS_FALSE: + case IS_NOT_TRUE: + // x IS NOT TRUE ==> NOT x (if x is not nullable) + // x IS FALSE ==> NOT x (if x is not nullable) + return simplify(rexBuilder, + rexBuilder.makeCall(SqlStdOperatorTable.NOT, a)); + } + } + switch (a.getKind()) { + case NOT: + // NOT x IS TRUE ==> x IS NOT TRUE + // Similarly for IS NOT TRUE, IS FALSE, etc. + return simplify(rexBuilder, + rexBuilder.makeCall(op(kind.negate()), + ((RexCall) a).getOperands().get(0))); + } + RexNode a2 = simplify(rexBuilder, a); + if (a != a2) { + return rexBuilder.makeCall(op(kind), ImmutableList.of(a2)); + } + return call; + } + + private static SqlOperator op(SqlKind kind) { + switch (kind) { + case IS_FALSE: + return SqlStdOperatorTable.IS_FALSE; + case IS_TRUE: + return SqlStdOperatorTable.IS_TRUE; + case IS_UNKNOWN: + return SqlStdOperatorTable.IS_UNKNOWN; + case IS_NULL: + return SqlStdOperatorTable.IS_NULL; + case IS_NOT_FALSE: + return SqlStdOperatorTable.IS_NOT_FALSE; + case IS_NOT_TRUE: + return SqlStdOperatorTable.IS_NOT_TRUE; + case IS_NOT_NULL: + return SqlStdOperatorTable.IS_NOT_NULL; + default: + throw new AssertionError(kind); + } + } + private static RexNode simplifyCase(RexBuilder rexBuilder, RexCall call) { final List<RexNode> operands = call.getOperands(); final List<RexNode> newOperands = new ArrayList<>(); @@ -1379,11 +1461,13 @@ public class RexUtil { --i; break; case LITERAL: - if (!RexLiteral.booleanValue(term)) { - return term; // false - } else { - terms.remove(i); - --i; + if (!RexLiteral.isNullLiteral(term)) { + if (!RexLiteral.booleanValue(term)) { + return term; // false + } else { + terms.remove(i); + --i; + } } } } @@ -1420,11 +1504,13 @@ public class RexUtil { final RexNode term = terms.get(i); switch (term.getKind()) { case LITERAL: - if (RexLiteral.booleanValue(term)) { - return term; // true - } else { - terms.remove(i); - --i; + if (!RexLiteral.isNullLiteral(term)) { + if (RexLiteral.booleanValue(term)) { + return term; // true + } else { + terms.remove(i); + --i; + } } } } http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/sql/SqlKind.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java index 78a9d7d..37a757d 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java @@ -907,6 +907,32 @@ public enum SqlKind { } } + /** Returns the kind that you get if you apply NOT to this kind. + * + * <p>For example, {@code IS_NOT_NULL.negate()} returns {@link #IS_NULL}. */ + public SqlKind negate() { + switch (this) { + case IS_TRUE: + return IS_NOT_TRUE; + case IS_FALSE: + return IS_NOT_FALSE; + case IS_NULL: + return IS_NOT_NULL; + case IS_NOT_TRUE: + return IS_TRUE; + case IS_NOT_FALSE: + return IS_FALSE; + case IS_NOT_NULL: + return IS_NULL; + case IS_DISTINCT_FROM: + return IS_NOT_DISTINCT_FROM; + case IS_NOT_DISTINCT_FROM: + return IS_DISTINCT_FROM; + default: + return this; + } + } + /** * Returns whether this {@code SqlKind} belongs to a given category. * http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/sql/SqlMerge.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/SqlMerge.java b/core/src/main/java/org/apache/calcite/sql/SqlMerge.java index 77be4e7..af38c39 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlMerge.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlMerge.java @@ -23,7 +23,6 @@ import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.calcite.util.ImmutableNullableList; import org.apache.calcite.util.Pair; -import java.util.Iterator; import java.util.List; /** @@ -80,7 +79,8 @@ public class SqlMerge extends SqlCall { @Override public void setOperand(int i, SqlNode operand) { switch (i) { case 0: - targetTable = (SqlIdentifier) operand; + assert operand instanceof SqlIdentifier; + targetTable = operand; break; case 1: condition = operand; @@ -194,11 +194,6 @@ public class SqlMerge extends SqlCall { "SET", ""); - Iterator targetColumnIter = - updateCall.getTargetColumnList().getList().iterator(); - Iterator sourceExpressionIter = - updateCall.getSourceExpressionList().getList().iterator(); - for (Pair<SqlNode, SqlNode> pair : Pair.zip( updateCall.targetColumnList, updateCall.sourceExpressionList)) { writer.sep(","); http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/main/java/org/apache/calcite/util/Pair.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/util/Pair.java b/core/src/main/java/org/apache/calcite/util/Pair.java index a01e9e7..f4f5b8d 100644 --- a/core/src/main/java/org/apache/calcite/util/Pair.java +++ b/core/src/main/java/org/apache/calcite/util/Pair.java @@ -216,12 +216,12 @@ public class Pair<T1, T2> * @return Iterable over pairs */ public static <K, V> Iterable<Pair<K, V>> zip( - final Iterable<K> ks, - final Iterable<V> vs) { + final Iterable<? extends K> ks, + final Iterable<? extends V> vs) { return new Iterable<Pair<K, V>>() { public Iterator<Pair<K, V>> iterator() { - final Iterator<K> kIterator = ks.iterator(); - final Iterator<V> vIterator = vs.iterator(); + final Iterator<? extends K> kIterator = ks.iterator(); + final Iterator<? extends V> vIterator = vs.iterator(); return new Iterator<Pair<K, V>>() { public boolean hasNext() { http://git-wip-us.apache.org/repos/asf/calcite/blob/7837e64c/core/src/test/java/org/apache/calcite/test/RexProgramTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/calcite/test/RexProgramTest.java b/core/src/test/java/org/apache/calcite/test/RexProgramTest.java index 20162dd..7b50427 100644 --- a/core/src/test/java/org/apache/calcite/test/RexProgramTest.java +++ b/core/src/test/java/org/apache/calcite/test/RexProgramTest.java @@ -153,10 +153,7 @@ public class RexProgramTest { // Normalize the program using the RexProgramBuilder.normalize API. // Note that unused expression '77' is eliminated, input refs (e.g. $0) // become local refs (e.g. $t0), and constants are assigned to locals. - final RexProgram normalizedProgram = - RexProgramBuilder.normalize( - rexBuilder, - program); + final RexProgram normalizedProgram = program.normalize(rexBuilder, false); final String normalizedProgramString = normalizedProgram.toString(); TestUtil.assertEqualsVerbose( "(expr#0..1=[{inputs}], expr#2=[+($t0, $t1)], expr#3=[1], " @@ -201,6 +198,47 @@ public class RexProgramTest { } /** + * Tests how the condition is simplified. + */ + @Test public void testSimplifyCondition() { + final RexProgram program = createProg(3).getProgram(false); + assertThat(program.toString(), + is("(expr#0..1=[{inputs}], expr#2=[+($0, 1)], expr#3=[77], " + + "expr#4=[+($0, $1)], expr#5=[+($0, 1)], expr#6=[+($0, $t5)], " + + "expr#7=[+($t4, $t2)], expr#8=[5], expr#9=[>($t2, $t8)], " + + "expr#10=[true], expr#11=[IS NOT NULL($t5)], expr#12=[false], " + + "expr#13=[null], expr#14=[CASE($t9, $t10, $t11, $t12, $t13)], " + + "expr#15=[NOT($t14)], a=[$t7], b=[$t6], $condition=[$t15])")); + + assertThat(program.normalize(rexBuilder, true).toString(), + is("(expr#0..1=[{inputs}], expr#2=[+($t0, $t1)], expr#3=[1], " + + "expr#4=[+($t0, $t3)], expr#5=[+($t2, $t4)], " + + "expr#6=[+($t0, $t4)], expr#7=[5], expr#8=[>($t4, $t7)], " + + "expr#9=[NOT($t8)], a=[$t5], b=[$t6], $condition=[$t9])")); + } + + /** + * Tests how the condition is simplified. + */ + @Test public void testSimplifyCondition2() { + final RexProgram program = createProg(4).getProgram(false); + assertThat(program.toString(), + is("(expr#0..1=[{inputs}], expr#2=[+($0, 1)], expr#3=[77], " + + "expr#4=[+($0, $1)], expr#5=[+($0, 1)], expr#6=[+($0, $t5)], " + + "expr#7=[+($t4, $t2)], expr#8=[5], expr#9=[>($t2, $t8)], " + + "expr#10=[true], expr#11=[IS NOT NULL($t5)], expr#12=[false], " + + "expr#13=[null], expr#14=[CASE($t9, $t10, $t11, $t12, $t13)], " + + "expr#15=[NOT($t14)], expr#16=[IS TRUE($t15)], a=[$t7], b=[$t6], " + + "$condition=[$t16])")); + + assertThat(program.normalize(rexBuilder, true).toString(), + is("(expr#0..1=[{inputs}], expr#2=[+($t0, $t1)], expr#3=[1], " + + "expr#4=[+($t0, $t3)], expr#5=[+($t2, $t4)], " + + "expr#6=[+($t0, $t4)], expr#7=[5], expr#8=[>($t4, $t7)], " + + "expr#9=[NOT($t8)], a=[$t5], b=[$t6], $condition=[$t9])")); + } + + /** * Checks translation of AND(x, x). */ @Test public void testDuplicateAnd() { @@ -226,10 +264,16 @@ public class RexProgramTest { * from t(x, y)</code> * <li><code>select (x + y) + (x + 1) as a, (x + x) as b from t(x, y) * where ((x + y) > 1) and ((x + y) > 1)</code> + * <li><code>select (x + y) + (x + 1) as a, (x + x) as b from t(x, y) + * where not case + * when x + 1 > 5 then true + * when y is null then null + * else false + * end</code> * </ol> */ private RexProgramBuilder createProg(int variant) { - assert variant == 0 || variant == 1 || variant == 2; + assert variant >= 0 && variant <= 4; List<RelDataType> types = Arrays.asList( typeFactory.createSqlType(SqlTypeName.INTEGER), @@ -243,8 +287,8 @@ public class RexProgramTest { // $t2 = $t0 + 1 (i.e. x + 1) final RexNode i0 = rexBuilder.makeInputRef( types.get(0), 0); - final RexLiteral c1 = rexBuilder.makeExactLiteral( - BigDecimal.ONE); + final RexLiteral c1 = rexBuilder.makeExactLiteral(BigDecimal.ONE); + final RexLiteral c5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(5L)); RexLocalRef t2 = builder.addExpr( rexBuilder.makeCall( @@ -269,6 +313,7 @@ public class RexProgramTest { i0, i1)); RexLocalRef t5; + final RexLocalRef t1; switch (variant) { case 0: case 2: @@ -278,10 +323,13 @@ public class RexProgramTest { SqlStdOperatorTable.PLUS, i0, i0)); + t1 = null; break; case 1: + case 3: + case 4: // $tx = $t0 + 1 - RexLocalRef tx = + t1 = builder.addExpr( rexBuilder.makeCall( SqlStdOperatorTable.PLUS, @@ -293,7 +341,7 @@ public class RexProgramTest { rexBuilder.makeCall( SqlStdOperatorTable.PLUS, i0, - tx)); + t1)); break; default: throw Util.newInternal("unexpected variant " + variant); @@ -308,16 +356,19 @@ public class RexProgramTest { builder.addProject(t6.getIndex(), "a"); builder.addProject(t5.getIndex(), "b"); - if (variant == 2) { + final RexLocalRef t7; + final RexLocalRef t8; + switch (variant) { + case 2: // $t7 = $t4 > $i0 (i.e. (x + y) > 0) - RexLocalRef t7 = + t7 = builder.addExpr( rexBuilder.makeCall( SqlStdOperatorTable.GREATER_THAN, t4, i0)); // $t8 = $t7 AND $t7 - RexLocalRef t8 = + t8 = builder.addExpr( rexBuilder.makeCall( SqlStdOperatorTable.AND, @@ -325,6 +376,48 @@ public class RexProgramTest { t7)); builder.addCondition(t8); builder.addCondition(t7); + break; + case 3: + case 4: + // $t7 = 5 + t7 = builder.addExpr(c5); + // $t8 = $t2 > $t7 (i.e. (x + 1) > 5) + t8 = + builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t2, t7)); + // $t9 = true + final RexLocalRef t9 = + builder.addExpr(rexBuilder.makeLiteral(true)); + // $t10 = $t1 is not null (i.e. y is not null) + assert t1 != null; + final RexLocalRef t10 = + builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, t1)); + // $t11 = false + final RexLocalRef t11 = + builder.addExpr(rexBuilder.makeLiteral(false)); + // $t12 = unknown + final RexLocalRef t12 = + builder.addExpr(rexBuilder.makeNullLiteral(SqlTypeName.BOOLEAN)); + // $t13 = case when $t8 then $t9 when $t10 then $t11 else $t12 end + final RexLocalRef t13 = + builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.CASE, + t8, t9, t10, t11, t12)); + // $t14 = not $t13 (i.e. not case ... end) + final RexLocalRef t14 = + builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.NOT, t13)); + // don't add 't14 is true' - that is implicit + if (variant == 3) { + builder.addCondition(t14); + } else { + // $t15 = $14 is true + final RexLocalRef t15 = + builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, t14)); + builder.addCondition(t15); + } } return builder; }
