This is an automated email from the ASF dual-hosted git repository.

zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 92cfb30175 [CALCITE-5733] Simplify "a = ARRAY[1,2] AND a = ARRAY[2,3]" 
to "false"
92cfb30175 is described below

commit 92cfb301750d9bf55387e5c6b3650c07407225fb
Author: Zhen Chen <[email protected]>
AuthorDate: Wed Dec 31 17:54:41 2025 +0800

    [CALCITE-5733] Simplify "a = ARRAY[1,2] AND a = ARRAY[2,3]" to "false"
---
 .../java/org/apache/calcite/rex/RexAnalyzer.java   |   4 +
 .../java/org/apache/calcite/rex/RexSimplify.java   |  96 +++++++++++--
 .../org/apache/calcite/rex/RexProgramTest.java     | 153 +++++++++++++++++++++
 3 files changed, 239 insertions(+), 14 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java 
b/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
index 5dcad7aa9e..a124168137 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexAnalyzer.java
@@ -96,6 +96,10 @@ private static List<Comparable> getComparables(RexNode 
variable) {
       values.add(0); // 00:00:00.000
       values.add(86_399_000); // 23:59:59.000
       break;
+    case ARRAY:
+    case MAP:
+    case MULTISET:
+      break;
     default:
       throw new AssertionError("don't know values for " + variable
           + " of type " + variable.getType());
diff --git a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java 
b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
index 7065c25c5c..764100889e 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
@@ -42,6 +42,7 @@
 
 import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.BoundType;
+import com.google.common.collect.HashMultiset;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableRangeSet;
 import com.google.common.collect.ImmutableSet;
@@ -65,6 +66,7 @@
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -1853,7 +1855,7 @@ private <C extends Comparable<C>> RexNode 
simplifyAnd2ForUnknownAsFalse(
         ArrayListMultimap.create();
     final Map<RexNode, Pair<Range<C>, List<RexNode>>> rangeTerms =
         new HashMap<>();
-    final Map<RexNode, RexLiteral> equalityConstantTerms = new HashMap<>();
+    final Map<RexNode, RexNode> equalityConstantTerms = new HashMap<>();
     final Set<RexNode> negatedTerms = new HashSet<>();
     final Set<RexNode> nullOperands = new HashSet<>();
     final Set<RexNode> notNullOperands = new LinkedHashSet<>();
@@ -1915,14 +1917,15 @@ private <C extends Comparable<C>> RexNode 
simplifyAnd2ForUnknownAsFalse(
         // is equal to different constants, this condition cannot be satisfied,
         // and hence it can be evaluated to FALSE
         if (term.getKind() == SqlKind.EQUALS) {
-          if (comparison != null) {
-            final RexLiteral literal = comparison.literal;
-            final RexLiteral prevLiteral =
-                equalityConstantTerms.put(comparison.ref, literal);
-
-            if (prevLiteral != null
-                && literal.getType().equals(prevLiteral.getType())
-                && !literal.equals(prevLiteral)) {
+          final Pair<RexNode, RexNode> constantEquality = 
constantEquality(call);
+          if (constantEquality != null) {
+            final RexNode constant = constantEquality.right;
+            final RexNode prevConstant =
+                equalityConstantTerms.put(constantEquality.left, constant);
+
+            if (prevConstant != null
+                && constant.getType().equals(prevConstant.getType())
+                && !constantsEquivalent(constant, prevConstant)) {
               return rexBuilder.makeLiteral(false);
             }
           } else if (RexUtil.isReferenceOrAccess(left, true)
@@ -1983,17 +1986,18 @@ private <C extends Comparable<C>> RexNode 
simplifyAnd2ForUnknownAsFalse(
     // Example #1. x=5 AND y=5 AND x=y : x=5 AND y=5
     // Example #2. x=5 AND y=6 AND x=y - not satisfiable
     for (RexNode ref1 : equalityTerms.keySet()) {
-      final RexLiteral literal1 = equalityConstantTerms.get(ref1);
-      if (literal1 == null) {
+      final RexNode constant1 = equalityConstantTerms.get(ref1);
+      if (constant1 == null) {
         continue;
       }
       Collection<Pair<RexNode, RexNode>> references = equalityTerms.get(ref1);
       for (Pair<RexNode, RexNode> ref2 : references) {
-        final RexLiteral literal2 = equalityConstantTerms.get(ref2.left);
-        if (literal2 == null) {
+        final RexNode constant2 = equalityConstantTerms.get(ref2.left);
+        if (constant2 == null) {
           continue;
         }
-        if (literal1.getType().equals(literal2.getType()) && 
!literal1.equals(literal2)) {
+        if (constant1.getType().equals(constant2.getType())
+            && !constantsEquivalent(constant1, constant2)) {
           // If an expression is equal to two different constants,
           // it is not satisfiable
           return rexBuilder.makeLiteral(false);
@@ -3034,6 +3038,70 @@ private static class VariableCollector extends 
RexVisitorImpl<Void> {
     }
   }
 
+  private static final Set<SqlKind> CONSTANT_VALUE_CONSTRUCTOR_KINDS =
+      EnumSet.of(
+          SqlKind.ARRAY_VALUE_CONSTRUCTOR,
+          SqlKind.MULTISET_VALUE_CONSTRUCTOR);
+
+  private static @Nullable Pair<RexNode, RexNode> constantEquality(RexCall 
call) {
+    final RexNode o0 = call.getOperands().get(0);
+    final RexNode o1 = call.getOperands().get(1);
+    if (RexUtil.isReferenceOrAccess(o0, true) && isConstant(o1)) {
+      return Pair.of(o0, o1);
+    }
+    if (RexUtil.isReferenceOrAccess(o1, true) && isConstant(o0)) {
+      return Pair.of(o1, o0);
+    }
+    return null;
+  }
+
+  private static boolean constantsEquivalent(RexNode node1, RexNode node2) {
+    if (Objects.equals(node1, node2)) {
+      return true;
+    }
+    if (!(node1 instanceof RexCall) || !(node2 instanceof RexCall)) {
+      return false;
+    }
+    final RexCall call1 = (RexCall) node1;
+    final RexCall call2 = (RexCall) node2;
+    if (call1.getKind() != call2.getKind()) {
+      return false;
+    }
+    switch (call1.getKind()) {
+    case MULTISET_VALUE_CONSTRUCTOR:
+      return multisetLiteralEquals(call1, call2);
+    default:
+      return false;
+    }
+  }
+
+  private static boolean multisetLiteralEquals(RexCall left, RexCall right) {
+    return 
canonicalMultisetLiteral(left).equals(canonicalMultisetLiteral(right));
+  }
+
+  private static HashMultiset<Object> canonicalMultisetLiteral(RexCall call) {
+    final HashMultiset<Object> canonical = HashMultiset.create();
+    for (RexNode operand : call.getOperands()) {
+      canonical.add(canonicalMultisetOperand(operand));
+    }
+    return canonical;
+  }
+
+  private static Object canonicalMultisetOperand(RexNode operand) {
+    if (operand instanceof RexCall
+        && operand.getKind() == SqlKind.MULTISET_VALUE_CONSTRUCTOR) {
+      return canonicalMultisetLiteral((RexCall) operand);
+    }
+    return operand;
+  }
+
+  private static boolean isConstant(RexNode node) {
+    return node instanceof RexLiteral
+        || (node instanceof RexCall
+            && CONSTANT_VALUE_CONSTRUCTOR_KINDS.contains(node.getKind())
+            && RexUtil.isConstant(node));
+  }
+
   /** Represents a simple Comparison.
    *
    * <p>Left hand side is a {@link RexNode}, right hand side is a literal.
diff --git a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java 
b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
index 3daf4b8962..3c6620d1b7 100644
--- a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
+++ b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
@@ -1876,6 +1876,159 @@ private void checkExponentialCnf(int n) {
     checkSimplifyUnchanged(rexBuilder.makeCall(SqlStdOperatorTable.SOME_GT, 
operand1, operand2));
   }
 
+  /** Unit test for
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-5733">[CALCITE-5733]
+   * Simplify 'a = ARRAY[1,2] AND a = ARRAY[2,3]' to 'false'</a>. */
+  @Test void testSimplifyArrayEquality() {
+    final RelDataType arrayType = tArray(tInt());
+    final RexNode aRef = input(arrayType, 0);
+    final RexNode array12 =
+        rexBuilder.makeCall(arrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(1), literal(2)));
+    final RexNode array21 =
+        rexBuilder.makeCall(arrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(2), literal(1)));
+    final RexNode array23 =
+        rexBuilder.makeCall(arrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(2), literal(3)));
+    final RexNode array2Null =
+        rexBuilder.makeCall(arrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(literal(2), nullInt));
+    final RexNode arrayDoubleNull =
+        rexBuilder.makeCall(arrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(nullInt, nullInt));
+
+    // a = ARRAY[1,2] AND a = ARRAY[2,3]
+    final RexNode condition = and(eq(aRef, array12), eq(aRef, array23));
+    checkSimplifyFilter(condition, "false");
+
+    // a = ARRAY[1,2] AND a = ARRAY[2,null]
+    final RexNode condition2 = and(eq(aRef, array12), eq(aRef, array2Null));
+    checkSimplifyFilter(condition2, "false");
+
+    // a = ARRAY[1,2] AND a = ARRAY[null,null]
+    final RexNode condition3 = and(eq(aRef, array12), eq(aRef, 
arrayDoubleNull));
+    checkSimplifyFilter(condition3, "false");
+
+    // a = ARRAY[null,null] AND a = ARRAY[null,null]
+    final RexNode condition4 = and(eq(aRef, arrayDoubleNull), eq(aRef, 
arrayDoubleNull));
+    checkSimplifyFilter(condition4, "=($0, ARRAY(null:INTEGER, 
null:INTEGER))");
+
+    // a = ARRAY[1,2] AND a = ARRAY[2,1]
+    final RexNode condition5 = and(eq(aRef, array12), eq(aRef, array21));
+    checkSimplifyFilter(condition5, "false");
+
+    // Nested type for Array
+    final RelDataType nestedArrayType = tArray(arrayType);
+    final RexNode nestedRef = input(nestedArrayType, 1);
+
+    final RexNode nestedArray1212 =
+        rexBuilder.makeCall(nestedArrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(array12, array12));
+    final RexNode nestedArray1221 =
+        rexBuilder.makeCall(nestedArrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(array12, array21));
+    final RexNode nestedArray232Null =
+        rexBuilder.makeCall(nestedArrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(array23, array2Null));
+    final RexNode nestedArrayNulls =
+        rexBuilder.makeCall(nestedArrayType, 
SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+            ImmutableList.of(arrayDoubleNull, arrayDoubleNull));
+
+    // a = ARRAY[ARRAY[1,2], ARRAY[1,2]] and a = ARRAY[ARRAY[1,2], ARRAY[2,1]]
+    final RexNode nestedCondition =
+        and(eq(nestedRef, nestedArray1212), eq(nestedRef, nestedArray1221));
+    checkSimplifyFilter(nestedCondition, "false");
+
+    // a = ARRAY[ARRAY[1,2], ARRAY[1,2]] and a = ARRAY[ARRAY[2,3], 
ARRAY[2,null]]
+    final RexNode nestedCondition2 =
+        and(eq(nestedRef, nestedArray1212), eq(nestedRef, nestedArray232Null));
+    checkSimplifyFilter(nestedCondition2, "false");
+
+    // a = ARRAY[ARRAY[1,2], ARRAY[1,2]] and a = ARRAY[ARRAY[null,null], 
ARRAY[null,null]]
+    final RexNode nestedCondition3 =
+        and(eq(nestedRef, nestedArray1212), eq(nestedRef, nestedArrayNulls));
+    checkSimplifyFilter(nestedCondition3, "false");
+  }
+
+  /** Unit test for
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-5733">[CALCITE-5733]
+   * Simplify 'a = ARRAY[1,2] AND a = ARRAY[2,3]' to 'false'</a>. */
+  @Test void testSimplifyMultisetEquality() {
+    final RelDataType multisetType = typeFactory.createMultisetType(tInt(), 
-1);
+    final RexNode aRef = input(multisetType, 0);
+    final RexNode multiset12 =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(1), literal(2)));
+    final RexNode multiset21 =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(2), literal(1)));
+    final RexNode multiset23 =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(2), literal(3)));
+    final RexNode multiset2Null =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(literal(2), nullInt));
+    final RexNode multisetDoubleNull =
+        rexBuilder.makeCall(multisetType, SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(nullInt, nullInt));
+
+    // a = MULTISET[1,2] AND a = MULTISET[2,3]
+    final RexNode condition = and(eq(aRef, multiset12), eq(aRef, multiset23));
+    checkSimplifyFilter(condition, "false");
+
+    // a = MULTISET[1,2] AND a = MULTISET[2,null]
+    final RexNode condition2 = and(eq(aRef, multiset12), eq(aRef, 
multiset2Null));
+    checkSimplifyFilter(condition2, "false");
+
+    // a = MULTISET[1,2] AND a = MULTISET[null,null]
+    final RexNode condition3 = and(eq(aRef, multiset12), eq(aRef, 
multisetDoubleNull));
+    checkSimplifyFilter(condition3, "false");
+
+    // a = MULTISET[null,null] AND a = MULTISET[null,null]
+    final RexNode condition4 = and(eq(aRef, multisetDoubleNull), eq(aRef, 
multisetDoubleNull));
+    checkSimplifyFilter(condition4, "=($0, MULTISET(null:INTEGER, 
null:INTEGER))");
+
+    // a = MULTISET[1,2] AND a = MULTISET[2,1]
+    final RexNode condition5 = and(eq(aRef, multiset12), eq(aRef, multiset21));
+    checkSimplifyFilter(condition5, "AND(=($0, MULTISET(1, 2)), =($0, 
MULTISET(2, 1)))");
+
+    // Nested type for Multiset
+    final RelDataType nestedMultisetType = 
typeFactory.createMultisetType(multisetType, -1);
+    final RexNode nestedRef = input(nestedMultisetType, 1);
+
+    final RexNode nestedMultiset1212 =
+        rexBuilder.makeCall(nestedMultisetType, 
SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multiset12, multiset12));
+    final RexNode nestedMultiset1221 =
+        rexBuilder.makeCall(nestedMultisetType, 
SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multiset12, multiset21));
+    final RexNode nestedMultiset232Null =
+        rexBuilder.makeCall(nestedMultisetType, 
SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multiset23, multiset2Null));
+    final RexNode nestedMultisetNulls =
+        rexBuilder.makeCall(nestedMultisetType, 
SqlStdOperatorTable.MULTISET_VALUE,
+            ImmutableList.of(multisetDoubleNull, multisetDoubleNull));
+
+    // a = MULTISET[MULTISET[1,2], MULTISET[1,2]] and a = 
MULTISET[MULTISET[1,2], MULTISET[2,1]]
+    final RexNode nestedCondition =
+        and(eq(nestedRef, nestedMultiset1212), eq(nestedRef, 
nestedMultiset1221));
+    checkSimplifyFilter(nestedCondition,
+        "AND(=($1, MULTISET(MULTISET(1, 2), MULTISET(1, 2))),"
+            + " =($1, MULTISET(MULTISET(1, 2), MULTISET(2, 1))))");
+
+    // a = MULTISET[MULTISET[1,2], MULTISET[1,2]] and a = 
MULTISET[MULTISET[2,3], MULTISET[2,null]]
+    final RexNode nestedCondition2 =
+        and(eq(nestedRef, nestedMultiset1212), eq(nestedRef, 
nestedMultiset232Null));
+    checkSimplifyFilter(nestedCondition2, "false");
+
+    // a = MULTISET[MULTISET[1,2], MULTISET[1,2]]
+    // and a = MULTISET[MULTISET[null,null], MULTISET[null,null]]
+    final RexNode nestedCondition3 =
+        and(eq(nestedRef, nestedMultiset1212), eq(nestedRef, 
nestedMultisetNulls));
+    checkSimplifyFilter(nestedCondition3, "false");
+  }
+
   @Test void testSimplifyRange() {
     final RexNode aRef = input(tInt(), 0);
     // ((0 < a and a <= 10) or a >= 15) and a <> 6 and a <> 12

Reply via email to