This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new b8102d596c Spark 3.4: Support NOT_EQ for V2 filters (#7898)
b8102d596c is described below
commit b8102d596c75e8b5c805c1ee6ad3c38bb37a0d0c
Author: Xianyang Liu <[email protected]>
AuthorDate: Wed Jul 5 05:03:21 2023 +0800
Spark 3.4: Support NOT_EQ for V2 filters (#7898)
Co-authored-by: xianyangliu <[email protected]>
---
.../org/apache/iceberg/spark/SparkV2Filters.java | 78 +++++++++++----
.../apache/iceberg/spark/TestSparkV2Filters.java | 109 +++++++++++++++++++++
2 files changed, 167 insertions(+), 20 deletions(-)
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
index 6d564bbd62..cbedc4d568 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
@@ -28,9 +28,12 @@ import static org.apache.iceberg.expressions.Expressions.isNull;
import static org.apache.iceberg.expressions.Expressions.lessThan;
import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.not;
+import static org.apache.iceberg.expressions.Expressions.notEqual;
import static org.apache.iceberg.expressions.Expressions.notIn;
+import static org.apache.iceberg.expressions.Expressions.notNaN;
import static org.apache.iceberg.expressions.Expressions.notNull;
import static org.apache.iceberg.expressions.Expressions.or;
+import static org.apache.iceberg.expressions.Expressions.ref;
import static org.apache.iceberg.expressions.Expressions.startsWith;
import java.util.Arrays;
@@ -40,9 +43,12 @@ import java.util.stream.Collectors;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expression.Operation;
import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.expressions.UnboundPredicate;
+import org.apache.iceberg.expressions.UnboundTerm;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.util.NaNUtil;
+import org.apache.iceberg.util.Pair;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.And;
@@ -57,6 +63,7 @@ public class SparkV2Filters {
private static final String FALSE = "ALWAYS_FALSE";
private static final String EQ = "=";
private static final String EQ_NULL_SAFE = "<=>";
+ private static final String NOT_EQ = "<>";
private static final String GT = ">";
private static final String GT_EQ = ">=";
private static final String LT = "<";
@@ -75,6 +82,7 @@ public class SparkV2Filters {
.put(FALSE, Operation.FALSE)
.put(EQ, Operation.EQ)
.put(EQ_NULL_SAFE, Operation.EQ)
+ .put(NOT_EQ, Operation.NOT_EQ)
.put(GT, Operation.GT)
.put(GT_EQ, Operation.GT_EQ)
.put(LT, Operation.LT)
@@ -152,31 +160,35 @@ public class SparkV2Filters {
}
case EQ: // used for both eq and null-safe-eq
- Object value;
- String columnName;
- if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate)))
{
- columnName = SparkUtil.toColumnName(leftChild(predicate));
- value = convertLiteral(rightChild(predicate));
- } else if (isRef(rightChild(predicate)) &&
isLiteral(leftChild(predicate))) {
- columnName = SparkUtil.toColumnName(rightChild(predicate));
- value = convertLiteral(leftChild(predicate));
- } else {
+ Pair<UnboundTerm<Object>, Object> eqChildren =
predicateChildren(predicate);
+ if (eqChildren == null) {
return null;
}
if (predicate.name().equals(EQ)) {
// comparison with null in normal equality is always null. this is probably a mistake.
Preconditions.checkNotNull(
- value, "Expression is always false (eq is not null-safe): %s",
predicate);
- return handleEqual(columnName, value);
- } else { // "<=>"
- if (value == null) {
- return isNull(columnName);
- } else {
- return handleEqual(columnName, value);
- }
+ eqChildren.second(),
+ "Expression is always false (eq is not null-safe): %s",
+ predicate);
+ }
+
+ return handleEqual(eqChildren.first(), eqChildren.second());
+
+ case NOT_EQ:
+ Pair<UnboundTerm<Object>, Object> notEqChildren =
predicateChildren(predicate);
+ if (notEqChildren == null) {
+ return null;
}
+ // comparison with null in normal equality is always null. this is probably a mistake.
+ Preconditions.checkNotNull(
+ notEqChildren.second(),
+ "Expression is always false (notEq is not null-safe): %s",
+ predicate);
+
+ return handleNotEqual(notEqChildren.first(), notEqChildren.second());
+
case IN:
if (isSupportedInPredicate(predicate)) {
return in(
@@ -245,6 +257,22 @@ public class SparkV2Filters {
return null;
}
+ private static Pair<UnboundTerm<Object>, Object> predicateChildren(Predicate
predicate) {
+ if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+ UnboundTerm<Object> term =
ref(SparkUtil.toColumnName(leftChild(predicate)));
+ Object value = convertLiteral(rightChild(predicate));
+ return Pair.of(term, value);
+
+ } else if (isRef(rightChild(predicate)) &&
isLiteral(leftChild(predicate))) {
+ UnboundTerm<Object> term =
ref(SparkUtil.toColumnName(rightChild(predicate)));
+ Object value = convertLiteral(leftChild(predicate));
+ return Pair.of(term, value);
+
+ } else {
+ return null;
+ }
+ }
+
@SuppressWarnings("unchecked")
private static <T> T child(Predicate predicate) {
org.apache.spark.sql.connector.expressions.Expression[] children =
predicate.children();
@@ -289,11 +317,21 @@ public class SparkV2Filters {
return literal.value();
}
- private static Expression handleEqual(String attribute, Object value) {
+ private static UnboundPredicate<Object> handleEqual(UnboundTerm<Object>
term, Object value) {
+ if (value == null) {
+ return isNull(term);
+ } else if (NaNUtil.isNaN(value)) {
+ return isNaN(term);
+ } else {
+ return equal(term, value);
+ }
+ }
+
+ private static UnboundPredicate<Object> handleNotEqual(UnboundTerm<Object>
term, Object value) {
if (NaNUtil.isNaN(value)) {
- return isNaN(attribute);
+ return notNaN(term);
} else {
- return equal(attribute, value);
+ return notEqual(term, value);
}
}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java
index 4c8a32fa41..dabd5d991b 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java
@@ -34,11 +34,13 @@ import org.apache.spark.sql.connector.expressions.filter.Or;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;
+import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Test;
public class TestSparkV2Filters {
+ @SuppressWarnings("checkstyle:MethodLength")
@Test
public void testV2Filters() {
Map<String, String> attrMap = Maps.newHashMap();
@@ -128,6 +130,18 @@ public class TestSparkV2Filters {
Expression actualEq2 = SparkV2Filters.convert(eq2);
Assert.assertEquals("EqualTo must match", expectedEq2.toString(),
actualEq2.toString());
+ Predicate notEq1 = new Predicate("<>", attrAndValue);
+ Expression expectedNotEq1 = Expressions.notEqual(unquoted, 1);
+ Expression actualNotEq1 = SparkV2Filters.convert(notEq1);
+ Assert.assertEquals(
+ "NotEqualTo must match", expectedNotEq1.toString(),
actualNotEq1.toString());
+
+ Predicate notEq2 = new Predicate("<>", valueAndAttr);
+ Expression expectedNotEq2 = Expressions.notEqual(unquoted, 1);
+ Expression actualNotEq2 = SparkV2Filters.convert(notEq2);
+ Assert.assertEquals(
+ "NotEqualTo must match", expectedNotEq2.toString(),
actualNotEq2.toString());
+
Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue);
Expression expectedEqNullSafe1 = Expressions.equal(unquoted, 1);
Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1);
@@ -189,6 +203,101 @@ public class TestSparkV2Filters {
});
}
+ @Test
+ public void testEqualToNull() {
+ String col = "col";
+ NamedReference namedReference = FieldReference.apply(col);
+ LiteralValue value = new LiteralValue(null, DataTypes.IntegerType);
+
+ org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+ new org.apache.spark.sql.connector.expressions.Expression[]
{namedReference, value};
+ org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+ new org.apache.spark.sql.connector.expressions.Expression[] {value,
namedReference};
+
+ Predicate eq1 = new Predicate("=", attrAndValue);
+ Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(eq1))
+ .isInstanceOf(NullPointerException.class)
+ .hasMessageContaining("Expression is always false");
+
+ Predicate eq2 = new Predicate("=", valueAndAttr);
+ Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(eq2))
+ .isInstanceOf(NullPointerException.class)
+ .hasMessageContaining("Expression is always false");
+
+ Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue);
+ Expression expectedEqNullSafe = Expressions.isNull(col);
+ Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1);
+ Assertions.assertThat(actualEqNullSafe1.toString()).isEqualTo(expectedEqNullSafe.toString());
+
+ Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr);
+ Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2);
+ Assertions.assertThat(actualEqNullSafe2.toString()).isEqualTo(expectedEqNullSafe.toString());
+ }
+
+ @Test
+ public void testEqualToNaN() {
+ String col = "col";
+ NamedReference namedReference = FieldReference.apply(col);
+ LiteralValue value = new LiteralValue(Float.NaN, DataTypes.FloatType);
+
+ org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+ new org.apache.spark.sql.connector.expressions.Expression[]
{namedReference, value};
+ org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+ new org.apache.spark.sql.connector.expressions.Expression[] {value,
namedReference};
+
+ Predicate eqNaN1 = new Predicate("=", attrAndValue);
+ Expression expectedEqNaN = Expressions.isNaN(col);
+ Expression actualEqNaN1 = SparkV2Filters.convert(eqNaN1);
+ Assertions.assertThat(actualEqNaN1.toString()).isEqualTo(expectedEqNaN.toString());
+
+ Predicate eqNaN2 = new Predicate("=", valueAndAttr);
+ Expression actualEqNaN2 = SparkV2Filters.convert(eqNaN2);
+ Assertions.assertThat(actualEqNaN2.toString()).isEqualTo(expectedEqNaN.toString());
+ }
+
+ @Test
+ public void testNotEqualToNull() {
+ String col = "col";
+ NamedReference namedReference = FieldReference.apply(col);
+ LiteralValue value = new LiteralValue(null, DataTypes.IntegerType);
+
+ org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+ new org.apache.spark.sql.connector.expressions.Expression[]
{namedReference, value};
+ org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+ new org.apache.spark.sql.connector.expressions.Expression[] {value,
namedReference};
+
+ Predicate notEq1 = new Predicate("<>", attrAndValue);
+ Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(notEq1))
+ .isInstanceOf(NullPointerException.class)
+ .hasMessageContaining("Expression is always false");
+
+ Predicate notEq2 = new Predicate("<>", valueAndAttr);
+ Assertions.assertThatThrownBy(() -> SparkV2Filters.convert(notEq2))
+ .isInstanceOf(NullPointerException.class)
+ .hasMessageContaining("Expression is always false");
+ }
+
+ @Test
+ public void testNotEqualToNaN() {
+ String col = "col";
+ NamedReference namedReference = FieldReference.apply(col);
+ LiteralValue value = new LiteralValue(Float.NaN, DataTypes.FloatType);
+
+ org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
+ new org.apache.spark.sql.connector.expressions.Expression[]
{namedReference, value};
+ org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
+ new org.apache.spark.sql.connector.expressions.Expression[] {value,
namedReference};
+
+ Predicate notEqNaN1 = new Predicate("<>", attrAndValue);
+ Expression expectedNotEqNaN = Expressions.notNaN(col);
+ Expression actualNotEqNaN1 = SparkV2Filters.convert(notEqNaN1);
+ Assertions.assertThat(actualNotEqNaN1.toString()).isEqualTo(expectedNotEqNaN.toString());
+
+ Predicate notEqNaN2 = new Predicate("<>", valueAndAttr);
+ Expression actualNotEqNaN2 = SparkV2Filters.convert(notEqNaN2);
+ Assertions.assertThat(actualNotEqNaN2.toString()).isEqualTo(expectedNotEqNaN.toString());
+ }
+
@Test
public void testTimestampFilterConversion() {
Instant instant = Instant.parse("2018-10-18T00:00:57.907Z");