This is an automated email from the ASF dual-hosted git repository.
zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 92cfb30175 [CALCITE-5733] Simplify "a = ARRAY[1,2] AND a = ARRAY[2,3]"
to "false"
92cfb30175 is described below
commit 92cfb301750d9bf55387e5c6b3650c07407225fb
Author: Zhen Chen <[email protected]>
AuthorDate: Wed Dec 31 17:54:41 2025 +0800
[CALCITE-5733] Simplify "a = ARRAY[1,2] AND a = ARRAY[2,3]" to "false"
---
.../java/org/apache/calcite/rex/RexAnalyzer.java | 4 +
.../java/org/apache/calcite/rex/RexSimplify.java | 96 +++++++++++--
.../org/apache/calcite/rex/RexProgramTest.java | 153 +++++++++++++++++++++
3 files changed, 239 insertions(+), 14 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
b/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
index 5dcad7aa9e..a124168137 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
@@ -96,6 +96,10 @@ private static List<Comparable> getComparables(RexNode
variable) {
values.add(0); // 00:00:00.000
values.add(86_399_000); // 23:59:59.000
break;
+ case ARRAY:
+ case MAP:
+ case MULTISET:
+ break;
default:
throw new AssertionError("don't know values for " + variable
+ " of type " + variable.getType());
diff --git a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
index 7065c25c5c..764100889e 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
@@ -42,6 +42,7 @@
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BoundType;
+import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableRangeSet;
import com.google.common.collect.ImmutableSet;
@@ -65,6 +66,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -1853,7 +1855,7 @@ private <C extends Comparable<C>> RexNode
simplifyAnd2ForUnknownAsFalse(
ArrayListMultimap.create();
final Map<RexNode, Pair<Range<C>, List<RexNode>>> rangeTerms =
new HashMap<>();
- final Map<RexNode, RexLiteral> equalityConstantTerms = new HashMap<>();
+ final Map<RexNode, RexNode> equalityConstantTerms = new HashMap<>();
final Set<RexNode> negatedTerms = new HashSet<>();
final Set<RexNode> nullOperands = new HashSet<>();
final Set<RexNode> notNullOperands = new LinkedHashSet<>();
@@ -1915,14 +1917,15 @@ private <C extends Comparable<C>> RexNode
simplifyAnd2ForUnknownAsFalse(
// is equal to different constants, this condition cannot be satisfied,
// and hence it can be evaluated to FALSE
if (term.getKind() == SqlKind.EQUALS) {
- if (comparison != null) {
- final RexLiteral literal = comparison.literal;
- final RexLiteral prevLiteral =
- equalityConstantTerms.put(comparison.ref, literal);
-
- if (prevLiteral != null
- && literal.getType().equals(prevLiteral.getType())
- && !literal.equals(prevLiteral)) {
+ final Pair<RexNode, RexNode> constantEquality =
constantEquality(call);
+ if (constantEquality != null) {
+ final RexNode constant = constantEquality.right;
+ final RexNode prevConstant =
+ equalityConstantTerms.put(constantEquality.left, constant);
+
+ if (prevConstant != null
+ && constant.getType().equals(prevConstant.getType())
+ && !constantsEquivalent(constant, prevConstant)) {
return rexBuilder.makeLiteral(false);
}
} else if (RexUtil.isReferenceOrAccess(left, true)
@@ -1983,17 +1986,18 @@ private <C extends Comparable<C>> RexNode
simplifyAnd2ForUnknownAsFalse(
// Example #1. x=5 AND y=5 AND x=y : x=5 AND y=5
// Example #2. x=5 AND y=6 AND x=y - not satisfiable
for (RexNode ref1 : equalityTerms.keySet()) {
- final RexLiteral literal1 = equalityConstantTerms.get(ref1);
- if (literal1 == null) {
+ final RexNode constant1 = equalityConstantTerms.get(ref1);
+ if (constant1 == null) {
continue;
}
Collection<Pair<RexNode, RexNode>> references = equalityTerms.get(ref1);
for (Pair<RexNode, RexNode> ref2 : references) {
- final RexLiteral literal2 = equalityConstantTerms.get(ref2.left);
- if (literal2 == null) {
+ final RexNode constant2 = equalityConstantTerms.get(ref2.left);
+ if (constant2 == null) {
continue;
}
- if (literal1.getType().equals(literal2.getType()) &&
!literal1.equals(literal2)) {
+ if (constant1.getType().equals(constant2.getType())
+ && !constantsEquivalent(constant1, constant2)) {
// If an expression is equal to two different constants,
// it is not satisfiable
return rexBuilder.makeLiteral(false);
@@ -3034,6 +3038,70 @@ private static class VariableCollector extends
RexVisitorImpl<Void> {
}
}
+ private static final Set<SqlKind> CONSTANT_VALUE_CONSTRUCTOR_KINDS =
+ EnumSet.of(
+ SqlKind.ARRAY_VALUE_CONSTRUCTOR,
+ SqlKind.MULTISET_VALUE_CONSTRUCTOR);
+
+ private static @Nullable Pair<RexNode, RexNode> constantEquality(RexCall
call) {
+ final RexNode o0 = call.getOperands().get(0);
+ final RexNode o1 = call.getOperands().get(1);
+ if (RexUtil.isReferenceOrAccess(o0, true) && isConstant(o1)) {
+ return Pair.of(o0, o1);
+ }
+ if (RexUtil.isReferenceOrAccess(o1, true) && isConstant(o0)) {
+ return Pair.of(o1, o0);
+ }
+ return null;
+ }
+
+ private static boolean constantsEquivalent(RexNode node1, RexNode node2) {
+ if (Objects.equals(node1, node2)) {
+ return true;
+ }
+ if (!(node1 instanceof RexCall) || !(node2 instanceof RexCall)) {
+ return false;
+ }
+ final RexCall call1 = (RexCall) node1;
+ final RexCall call2 = (RexCall) node2;
+ if (call1.getKind() != call2.getKind()) {
+ return false;
+ }
+ switch (call1.getKind()) {
+ case MULTISET_VALUE_CONSTRUCTOR:
+ return multisetLiteralEquals(call1, call2);
+ default:
+ return false;
+ }
+ }
+
+ private static boolean multisetLiteralEquals(RexCall left, RexCall right) {
+ return
canonicalMultisetLiteral(left).equals(canonicalMultisetLiteral(right));
+ }
+
+ private static HashMultiset<Object> canonicalMultisetLiteral(RexCall call) {
+ final HashMultiset<Object> canonical = HashMultiset.create();
+ for (RexNode operand : call.getOperands()) {
+ canonical.add(canonicalMultisetOperand(operand));
+ }
+ return canonical;
+ }
+
+ private static Object canonicalMultisetOperand(RexNode operand) {
+ if (operand instanceof RexCall
+ && operand.getKind() == SqlKind.MULTISET_VALUE_CONSTRUCTOR) {
+ return canonicalMultisetLiteral((RexCall) operand);
+ }
+ return operand;
+ }
+
+ private static boolean isConstant(RexNode node) {
+ return node instanceof RexLiteral
+ || (node instanceof RexCall
+ && CONSTANT_VALUE_CONSTRUCTOR_KINDS.contains(node.getKind())
+ && RexUtil.isConstant(node));
+ }
+
/** Represents a simple Comparison.
*
* <p>Left hand side is a {@link RexNode}, right hand side is a literal.
diff --git a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
index 3daf4b8962..3c6620d1b7 100644
--- a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
+++ b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
@@ -1876,6 +1876,159 @@ private void checkExponentialCnf(int n) {
checkSimplifyUnchanged(rexBuilder.makeCall(SqlStdOperatorTable.SOME_GT,
operand1, operand2));
}
+  /** Unit test for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-5733">[CALCITE-5733]
+   * Simplify 'a = ARRAY[1,2] AND a = ARRAY[2,3]' to 'false'</a>.
+   *
+   * <p>Arrays are ordered, so two array constants that contain the same
+   * elements in a different order are different values. */
+  @Test void testSimplifyArrayEquality() {
+    final RelDataType arrayType = tArray(tInt());
+    final RexNode aRef = input(arrayType, 0);
+    final RexNode array12 =
+        rexBuilder.makeCall(arrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(1), literal(2)));
+    final RexNode array21 =
+        rexBuilder.makeCall(arrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(2), literal(1)));
+    final RexNode array23 =
+        rexBuilder.makeCall(arrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(2), literal(3)));
+    final RexNode array2Null =
+        rexBuilder.makeCall(arrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(2), nullInt));
+    final RexNode arrayDoubleNull =
+        rexBuilder.makeCall(arrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(nullInt, nullInt));
+
+    // a = ARRAY[1,2] AND a = ARRAY[2,3]
+    final RexNode condition = and(eq(aRef, array12), eq(aRef, array23));
+    checkSimplifyFilter(condition, "false");
+
+    // a = ARRAY[1,2] AND a = ARRAY[2,null]
+    final RexNode condition2 = and(eq(aRef, array12), eq(aRef, array2Null));
+    checkSimplifyFilter(condition2, "false");
+
+    // a = ARRAY[1,2] AND a = ARRAY[null,null]
+    final RexNode condition3 =
+        and(eq(aRef, array12), eq(aRef, arrayDoubleNull));
+    checkSimplifyFilter(condition3, "false");
+
+    // a = ARRAY[null,null] AND a = ARRAY[null,null]: the same constant twice,
+    // so the result keeps a single conjunct rather than becoming false
+    final RexNode condition4 =
+        and(eq(aRef, arrayDoubleNull), eq(aRef, arrayDoubleNull));
+    checkSimplifyFilter(condition4, "=($0, ARRAY(null:INTEGER, null:INTEGER))");
+
+    // a = ARRAY[1,2] AND a = ARRAY[2,1]: arrays are ordered, so these differ
+    final RexNode condition5 = and(eq(aRef, array12), eq(aRef, array21));
+    checkSimplifyFilter(condition5, "false");
+
+    // Nested type for Array; in the comments below, b is input $1
+    final RelDataType nestedArrayType = tArray(arrayType);
+    final RexNode nestedRef = input(nestedArrayType, 1);
+
+    final RexNode nestedArray1212 =
+        rexBuilder.makeCall(nestedArrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(array12, array12));
+    final RexNode nestedArray1221 =
+        rexBuilder.makeCall(nestedArrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(array12, array21));
+    final RexNode nestedArray232Null =
+        rexBuilder.makeCall(nestedArrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(array23, array2Null));
+    final RexNode nestedArrayNulls =
+        rexBuilder.makeCall(nestedArrayType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(arrayDoubleNull, arrayDoubleNull));
+
+    // b = ARRAY[ARRAY[1,2], ARRAY[1,2]] AND b = ARRAY[ARRAY[1,2], ARRAY[2,1]]
+    final RexNode nestedCondition =
+        and(eq(nestedRef, nestedArray1212), eq(nestedRef, nestedArray1221));
+    checkSimplifyFilter(nestedCondition, "false");
+
+    // b = ARRAY[ARRAY[1,2], ARRAY[1,2]]
+    // AND b = ARRAY[ARRAY[2,3], ARRAY[2,null]]
+    final RexNode nestedCondition2 =
+        and(eq(nestedRef, nestedArray1212), eq(nestedRef, nestedArray232Null));
+    checkSimplifyFilter(nestedCondition2, "false");
+
+    // b = ARRAY[ARRAY[1,2], ARRAY[1,2]]
+    // AND b = ARRAY[ARRAY[null,null], ARRAY[null,null]]
+    final RexNode nestedCondition3 =
+        and(eq(nestedRef, nestedArray1212), eq(nestedRef, nestedArrayNulls));
+    checkSimplifyFilter(nestedCondition3, "false");
+  }
+
+  /** Unit test for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-5733">[CALCITE-5733]
+   * Simplify 'a = ARRAY[1,2] AND a = ARRAY[2,3]' to 'false'</a>.
+   *
+   * <p>Multisets are unordered, so two multiset constants that contain the
+   * same elements in a different order are equal, and a conjunction of
+   * equalities to both must not be simplified to false. */
+  @Test void testSimplifyMultisetEquality() {
+    final RelDataType multisetType = typeFactory.createMultisetType(tInt(), -1);
+    final RexNode aRef = input(multisetType, 0);
+    final RexNode multiset12 =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(1), literal(2)));
+    final RexNode multiset21 =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(2), literal(1)));
+    final RexNode multiset23 =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(2), literal(3)));
+    final RexNode multiset2Null =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(2), nullInt));
+    final RexNode multisetDoubleNull =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(nullInt, nullInt));
+
+    // a = MULTISET[1,2] AND a = MULTISET[2,3]
+    final RexNode condition = and(eq(aRef, multiset12), eq(aRef, multiset23));
+    checkSimplifyFilter(condition, "false");
+
+    // a = MULTISET[1,2] AND a = MULTISET[2,null]
+    final RexNode condition2 =
+        and(eq(aRef, multiset12), eq(aRef, multiset2Null));
+    checkSimplifyFilter(condition2, "false");
+
+    // a = MULTISET[1,2] AND a = MULTISET[null,null]
+    final RexNode condition3 =
+        and(eq(aRef, multiset12), eq(aRef, multisetDoubleNull));
+    checkSimplifyFilter(condition3, "false");
+
+    // a = MULTISET[null,null] AND a = MULTISET[null,null]: the same constant
+    // twice, so the result keeps a single conjunct rather than becoming false
+    final RexNode condition4 =
+        and(eq(aRef, multisetDoubleNull), eq(aRef, multisetDoubleNull));
+    checkSimplifyFilter(condition4,
+        "=($0, MULTISET(null:INTEGER, null:INTEGER))");
+
+    // a = MULTISET[1,2] AND a = MULTISET[2,1]: multisets are unordered, so the
+    // two constants are equal and the condition is left unchanged
+    final RexNode condition5 = and(eq(aRef, multiset12), eq(aRef, multiset21));
+    checkSimplifyFilter(condition5,
+        "AND(=($0, MULTISET(1, 2)), =($0, MULTISET(2, 1)))");
+
+    // Nested type for Multiset; in the comments below, b is input $1
+    final RelDataType nestedMultisetType =
+        typeFactory.createMultisetType(multisetType, -1);
+    final RexNode nestedRef = input(nestedMultisetType, 1);
+
+    final RexNode nestedMultiset1212 =
+        rexBuilder.makeCall(nestedMultisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multiset12, multiset12));
+    final RexNode nestedMultiset1221 =
+        rexBuilder.makeCall(nestedMultisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multiset12, multiset21));
+    final RexNode nestedMultiset232Null =
+        rexBuilder.makeCall(nestedMultisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multiset23, multiset2Null));
+    final RexNode nestedMultisetNulls =
+        rexBuilder.makeCall(nestedMultisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multisetDoubleNull, multisetDoubleNull));
+
+    // b = MULTISET[MULTISET[1,2], MULTISET[1,2]]
+    // AND b = MULTISET[MULTISET[1,2], MULTISET[2,1]]
+    // Order is ignored at every level, so these are equal: left unchanged.
+    final RexNode nestedCondition =
+        and(eq(nestedRef, nestedMultiset1212),
+            eq(nestedRef, nestedMultiset1221));
+    checkSimplifyFilter(nestedCondition,
+        "AND(=($1, MULTISET(MULTISET(1, 2), MULTISET(1, 2))),"
+            + " =($1, MULTISET(MULTISET(1, 2), MULTISET(2, 1))))");
+
+    // b = MULTISET[MULTISET[1,2], MULTISET[1,2]]
+    // AND b = MULTISET[MULTISET[2,3], MULTISET[2,null]]
+    final RexNode nestedCondition2 =
+        and(eq(nestedRef, nestedMultiset1212),
+            eq(nestedRef, nestedMultiset232Null));
+    checkSimplifyFilter(nestedCondition2, "false");
+
+    // b = MULTISET[MULTISET[1,2], MULTISET[1,2]]
+    // AND b = MULTISET[MULTISET[null,null], MULTISET[null,null]]
+    final RexNode nestedCondition3 =
+        and(eq(nestedRef, nestedMultiset1212),
+            eq(nestedRef, nestedMultisetNulls));
+    checkSimplifyFilter(nestedCondition3, "false");
+  }
+
@Test void testSimplifyRange() {
final RexNode aRef = input(tInt(), 0);
// ((0 < a and a <= 10) or a >= 15) and a <> 6 and a <> 12