This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new b8102d596c Spark 3.4: Support NOT_EQ for V2 filters (#7898)
b8102d596c is described below

commit b8102d596c75e8b5c805c1ee6ad3c38bb37a0d0c
Author: Xianyang Liu <[email protected]>
AuthorDate: Wed Jul 5 05:03:21 2023 +0800

    Spark 3.4: Support NOT_EQ for V2 filters (#7898)
    
    Co-authored-by: xianyangliu <[email protected]>
---
 .../org/apache/iceberg/spark/SparkV2Filters.java   |  78 +++++++++++----
 .../apache/iceberg/spark/TestSparkV2Filters.java   | 109 +++++++++++++++++++++
 2 files changed, 167 insertions(+), 20 deletions(-)

diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
index 6d564bbd62..cbedc4d568 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
@@ -28,9 +28,12 @@ import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
 import static org.apache.iceberg.expressions.Expressions.not;
+import static org.apache.iceberg.expressions.Expressions.notEqual;
 import static org.apache.iceberg.expressions.Expressions.notIn;
+import static org.apache.iceberg.expressions.Expressions.notNaN;
 import static org.apache.iceberg.expressions.Expressions.notNull;
 import static org.apache.iceberg.expressions.Expressions.or;
+import static org.apache.iceberg.expressions.Expressions.ref;
 import static org.apache.iceberg.expressions.Expressions.startsWith;
 
 import java.util.Arrays;
@@ -40,9 +43,12 @@ import java.util.stream.Collectors;
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expression.Operation;
 import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.expressions.UnboundPredicate;
+import org.apache.iceberg.expressions.UnboundTerm;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.util.NaNUtil;
+import org.apache.iceberg.util.Pair;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.expressions.filter.And;
@@ -57,6 +63,7 @@ public class SparkV2Filters {
   private static final String FALSE = "ALWAYS_FALSE";
   private static final String EQ = "=";
   private static final String EQ_NULL_SAFE = "<=>";
+  private static final String NOT_EQ = "<>";
   private static final String GT = ">";
   private static final String GT_EQ = ">=";
   private static final String LT = "<";
@@ -75,6 +82,7 @@ public class SparkV2Filters {
           .put(FALSE, Operation.FALSE)
           .put(EQ, Operation.EQ)
           .put(EQ_NULL_SAFE, Operation.EQ)
+          .put(NOT_EQ, Operation.NOT_EQ)
           .put(GT, Operation.GT)
           .put(GT_EQ, Operation.GT_EQ)
           .put(LT, Operation.LT)
@@ -152,31 +160,35 @@ public class SparkV2Filters {
           }
 
         case EQ: // used for both eq and null-safe-eq
-          Object value;
-          String columnName;
-          if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) 
{
-            columnName = SparkUtil.toColumnName(leftChild(predicate));
-            value = convertLiteral(rightChild(predicate));
-          } else if (isRef(rightChild(predicate)) && 
isLiteral(leftChild(predicate))) {
-            columnName = SparkUtil.toColumnName(rightChild(predicate));
-            value = convertLiteral(leftChild(predicate));
-          } else {
+          Pair<UnboundTerm<Object>, Object> eqChildren = 
predicateChildren(predicate);
+          if (eqChildren == null) {
             return null;
           }
 
           if (predicate.name().equals(EQ)) {
             // comparison with null in normal equality is always null. this is 
probably a mistake.
             Preconditions.checkNotNull(
-                value, "Expression is always false (eq is not null-safe): %s", 
predicate);
-            return handleEqual(columnName, value);
-          } else { // "<=>"
-            if (value == null) {
-              return isNull(columnName);
-            } else {
-              return handleEqual(columnName, value);
-            }
+                eqChildren.second(),
+                "Expression is always false (eq is not null-safe): %s",
+                predicate);
+          }
+
+          return handleEqual(eqChildren.first(), eqChildren.second());
+
+        case NOT_EQ:
+          Pair<UnboundTerm<Object>, Object> notEqChildren = 
predicateChildren(predicate);
+          if (notEqChildren == null) {
+            return null;
           }
 
+          // comparison with null in normal equality is always null. this is 
probably a mistake.
+          Preconditions.checkNotNull(
+              notEqChildren.second(),
+              "Expression is always false (notEq is not null-safe): %s",
+              predicate);
+
+          return handleNotEqual(notEqChildren.first(), notEqChildren.second());
+
         case IN:
           if (isSupportedInPredicate(predicate)) {
             return in(
@@ -245,6 +257,22 @@ public class SparkV2Filters {
     return null;
   }
 
+  private static Pair<UnboundTerm<Object>, Object> predicateChildren(Predicate 
predicate) {
+    if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+      UnboundTerm<Object> term = 
ref(SparkUtil.toColumnName(leftChild(predicate)));
+      Object value = convertLiteral(rightChild(predicate));
+      return Pair.of(term, value);
+
+    } else if (isRef(rightChild(predicate)) && 
isLiteral(leftChild(predicate))) {
+      UnboundTerm<Object> term = 
ref(SparkUtil.toColumnName(rightChild(predicate)));
+      Object value = convertLiteral(leftChild(predicate));
+      return Pair.of(term, value);
+
+    } else {
+      return null;
+    }
+  }
+
   @SuppressWarnings("unchecked")
   private static <T> T child(Predicate predicate) {
     org.apache.spark.sql.connector.expressions.Expression[] children = 
predicate.children();
@@ -289,11 +317,21 @@ public class SparkV2Filters {
     return literal.value();
   }
 
-  private static Expression handleEqual(String attribute, Object value) {
+  private static UnboundPredicate<Object> handleEqual(UnboundTerm<Object> 
term, Object value) {
+    if (value == null) {
+      return isNull(term);
+    } else if (NaNUtil.isNaN(value)) {
+      return isNaN(term);
+    } else {
+      return equal(term, value);
+    }
+  }
+
+  private static UnboundPredicate<Object> handleNotEqual(UnboundTerm<Object> 
term, Object value) {
     if (NaNUtil.isNaN(value)) {
-      return isNaN(attribute);
+      return notNaN(term);
     } else {
-      return equal(attribute, value);
+      return notEqual(term, value);
     }
   }
 
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java
index 4c8a32fa41..dabd5d991b 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java
@@ -34,11 +34,13 @@ import org.apache.spark.sql.connector.expressions.filter.Or;
 import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.unsafe.types.UTF8String;
+import org.assertj.core.api.Assertions;
 import org.junit.Assert;
 import org.junit.Test;
 
 public class TestSparkV2Filters {
 
+  @SuppressWarnings("checkstyle:MethodLength")
   @Test
   public void testV2Filters() {
     Map<String, String> attrMap = Maps.newHashMap();
@@ -128,6 +130,18 @@ public class TestSparkV2Filters {
           Expression actualEq2 = SparkV2Filters.convert(eq2);
           Assert.assertEquals("EqualTo must match", expectedEq2.toString(), 
actualEq2.toString());
 
+          Predicate notEq1 = new Predicate("<>", attrAndValue);
+          Expression expectedNotEq1 = Expressions.notEqual(unquoted, 1);
+          Expression actualNotEq1 = SparkV2Filters.convert(notEq1);
+          Assert.assertEquals(
+              "NotEqualTo must match", expectedNotEq1.toString(), 
actualNotEq1.toString());
+
+          Predicate notEq2 = new Predicate("<>", valueAndAttr);
+          Expression expectedNotEq2 = Expressions.notEqual(unquoted, 1);
+          Expression actualNotEq2 = SparkV2Filters.convert(notEq2);
+          Assert.assertEquals(
+              "NotEqualTo must match", expectedNotEq2.toString(), 
actualNotEq2.toString());
+
           Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue);
           Expression expectedEqNullSafe1 = Expressions.equal(unquoted, 1);
           Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1);
@@ -189,6 +203,101 @@ public class TestSparkV2Filters {
         });
   }
 
+  @Test
+  public void testEqualToNull() {
+    String col = "col";
+    NamedReference namedReference = FieldReference.apply(col);
+    LiteralValue value = new LiteralValue(null, DataTypes.IntegerType);
+
+    org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+        new org.apache.spark.sql.connector.expressions.Expression[] 
{namedReference, value};
+    org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+        new org.apache.spark.sql.connector.expressions.Expression[] {value, 
namedReference};
+
+    Predicate eq1 = new Predicate("=", attrAndValue);
+    Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(eq1))
+        .isInstanceOf(NullPointerException.class)
+        .hasMessageContaining("Expression is always false");
+
+    Predicate eq2 = new Predicate("=", valueAndAttr);
+    Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(eq2))
+        .isInstanceOf(NullPointerException.class)
+        .hasMessageContaining("Expression is always false");
+
+    Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue);
+    Expression expectedEqNullSafe = Expressions.isNull(col);
+    Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1);
+    
Assertions.assertThat(actualEqNullSafe1.toString()).isEqualTo(expectedEqNullSafe.toString());
+
+    Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr);
+    Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2);
+    
Assertions.assertThat(actualEqNullSafe2.toString()).isEqualTo(expectedEqNullSafe.toString());
+  }
+
+  @Test
+  public void testEqualToNaN() {
+    String col = "col";
+    NamedReference namedReference = FieldReference.apply(col);
+    LiteralValue value = new LiteralValue(Float.NaN, DataTypes.FloatType);
+
+    org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+        new org.apache.spark.sql.connector.expressions.Expression[] 
{namedReference, value};
+    org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+        new org.apache.spark.sql.connector.expressions.Expression[] {value, 
namedReference};
+
+    Predicate eqNaN1 = new Predicate("=", attrAndValue);
+    Expression expectedEqNaN = Expressions.isNaN(col);
+    Expression actualEqNaN1 = SparkV2Filters.convert(eqNaN1);
+    
Assertions.assertThat(actualEqNaN1.toString()).isEqualTo(expectedEqNaN.toString());
+
+    Predicate eqNaN2 = new Predicate("=", valueAndAttr);
+    Expression actualEqNaN2 = SparkV2Filters.convert(eqNaN2);
+    
Assertions.assertThat(actualEqNaN2.toString()).isEqualTo(expectedEqNaN.toString());
+  }
+
+  @Test
+  public void testNotEqualToNull() {
+    String col = "col";
+    NamedReference namedReference = FieldReference.apply(col);
+    LiteralValue value = new LiteralValue(null, DataTypes.IntegerType);
+
+    org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+        new org.apache.spark.sql.connector.expressions.Expression[] 
{namedReference, value};
+    org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+        new org.apache.spark.sql.connector.expressions.Expression[] {value, 
namedReference};
+
+    Predicate notEq1 = new Predicate("<>", attrAndValue);
+    Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(notEq1))
+        .isInstanceOf(NullPointerException.class)
+        .hasMessageContaining("Expression is always false");
+
+    Predicate notEq2 = new Predicate("<>", valueAndAttr);
+    Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(notEq2))
+        .isInstanceOf(NullPointerException.class)
+        .hasMessageContaining("Expression is always false");
+  }
+
+  @Test
+  public void testNotEqualToNaN() {
+    String col = "col";
+    NamedReference namedReference = FieldReference.apply(col);
+    LiteralValue value = new LiteralValue(Float.NaN, DataTypes.FloatType);
+
+    org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+        new org.apache.spark.sql.connector.expressions.Expression[] 
{namedReference, value};
+    org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+        new org.apache.spark.sql.connector.expressions.Expression[] {value, 
namedReference};
+
+    Predicate notEqNaN1 = new Predicate("<>", attrAndValue);
+    Expression expectedNotEqNaN = Expressions.notNaN(col);
+    Expression actualNotEqNaN1 = SparkV2Filters.convert(notEqNaN1);
+    
Assertions.assertThat(actualNotEqNaN1.toString()).isEqualTo(expectedNotEqNaN.toString());
+
+    Predicate notEqNaN2 = new Predicate("<>", valueAndAttr);
+    Expression actualNotEqNaN2 = SparkV2Filters.convert(notEqNaN2);
+    
Assertions.assertThat(actualNotEqNaN2.toString()).isEqualTo(expectedNotEqNaN.toString());
+  }
+
   @Test
   public void testTimestampFilterConversion() {
     Instant instant = Instant.parse("2018-10-18T00:00:57.907Z");

Reply via email to