This is an automated email from the ASF dual-hosted git repository.

gabor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/parquet-java.git


The following commit(s) were added to refs/heads/master by this push:
     new 9275d594c PARQUET-34: Extend Contains support to all ColumnFilterPredicate types (#1370)
9275d594c is described below

commit 9275d594c97b7cf4b0e6cf169d005b38cd5a3f24
Author: Claire McGinty <[email protected]>
AuthorDate: Fri Jun 14 08:31:16 2024 +0200

    PARQUET-34: Extend Contains support to all ColumnFilterPredicate types (#1370)
---
 .../parquet/filter2/predicate/FilterApi.java       |   3 +-
 .../parquet/filter2/predicate/Operators.java       |  27 +++--
 ...crementallyUpdatedFilterPredicateGenerator.java | 114 ++++++++++++---------
 .../recordlevel/TestRecordLevelFilters.java        |  31 ++++++
 4 files changed, 118 insertions(+), 57 deletions(-)

diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java
index 4126b73e5..3c5168066 100644
--- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java
+++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java
@@ -39,6 +39,7 @@ import org.apache.parquet.filter2.predicate.Operators.Not;
 import org.apache.parquet.filter2.predicate.Operators.NotEq;
 import org.apache.parquet.filter2.predicate.Operators.NotIn;
 import org.apache.parquet.filter2.predicate.Operators.Or;
+import 
org.apache.parquet.filter2.predicate.Operators.SingleColumnFilterPredicate;
 import org.apache.parquet.filter2.predicate.Operators.SupportsEqNotEq;
 import org.apache.parquet.filter2.predicate.Operators.SupportsLtGt;
 import org.apache.parquet.filter2.predicate.Operators.UserDefined;
@@ -258,7 +259,7 @@ public final class FilterApi {
     return new NotIn<>(column, values);
   }
 
-  public static <T extends Comparable<T>> Contains<T> contains(Eq<T> pred) {
+  public static <T extends Comparable<T>, P extends 
SingleColumnFilterPredicate<T>> Contains<T> contains(P pred) {
     return Contains.of(pred);
   }
 
diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java
index b86a5ef09..474862d02 100644
--- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java
+++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java
@@ -85,8 +85,6 @@ public final class Operators {
   public static interface SupportsLtGt
       extends SupportsEqNotEq {} // marker for columns that can be used with 
lt(), ltEq(), gt(), gtEq()
 
-  public static interface SupportsContains {}
-
   public static final class IntColumn extends Column<Integer> implements 
SupportsLtGt {
     IntColumn(ColumnPath columnPath) {
       super(columnPath, Integer.class);
@@ -123,8 +121,13 @@ public final class Operators {
     }
   }
 
+  abstract static class SingleColumnFilterPredicate<T extends Comparable<T>>
+      implements FilterPredicate, Serializable {
+    abstract Column<T> getColumn();
+  }
+
   // base class for Eq, NotEq, Lt, Gt, LtEq, GtEq
-  abstract static class ColumnFilterPredicate<T extends Comparable<T>> 
implements FilterPredicate, Serializable {
+  abstract static class ColumnFilterPredicate<T extends Comparable<T>> extends 
SingleColumnFilterPredicate<T> {
     private final Column<T> column;
     private final T value;
 
@@ -136,6 +139,7 @@ public final class Operators {
       this.value = value;
     }
 
+    @Override
     public Column<T> getColumn() {
       return column;
     }
@@ -172,7 +176,7 @@ public final class Operators {
     }
   }
 
-  public static final class Eq<T extends Comparable<T>> extends 
ColumnFilterPredicate<T> implements SupportsContains {
+  public static final class Eq<T extends Comparable<T>> extends 
ColumnFilterPredicate<T> {
 
     // value can be null
     public Eq(Column<T> column, T value) {
@@ -255,7 +259,7 @@ public final class Operators {
    * {@link NotIn} is used to filter data that are not in the list of values.
    */
   public abstract static class SetColumnFilterPredicate<T extends 
Comparable<T>>
-      implements FilterPredicate, Serializable {
+      extends SingleColumnFilterPredicate<T> {
     private final Column<T> column;
     private final Set<T> values;
 
@@ -265,6 +269,7 @@ public final class Operators {
       checkArgument(!values.isEmpty(), "values in SetColumnFilterPredicate 
shouldn't be empty!");
     }
 
+    @Override
     public Column<T> getColumn() {
       return column;
     }
@@ -325,7 +330,7 @@ public final class Operators {
       this.column = Objects.requireNonNull(column, "column cannot be null");
     }
 
-    static <ColumnT extends Comparable<ColumnT>, C extends 
ColumnFilterPredicate<ColumnT> & SupportsContains>
+    static <ColumnT extends Comparable<ColumnT>, C extends 
SingleColumnFilterPredicate<ColumnT>>
         Contains<ColumnT> of(C pred) {
       return new ContainsColumnPredicate<>(pred);
     }
@@ -415,14 +420,18 @@ public final class Operators {
     }
   }
 
-  private static class ContainsColumnPredicate<T extends Comparable<T>, U 
extends ColumnFilterPredicate<T>>
+  private static class ContainsColumnPredicate<T extends Comparable<T>, U 
extends SingleColumnFilterPredicate<T>>
       extends Contains<T> {
     private final U underlying;
 
     ContainsColumnPredicate(U underlying) {
       super(underlying.getColumn());
-      if (underlying.getValue() == null) {
-        throw new IllegalArgumentException("Contains predicate does not 
support null element value");
+      if ((underlying instanceof ColumnFilterPredicate && 
((ColumnFilterPredicate) underlying).getValue() == null)
+          || (underlying instanceof SetColumnFilterPredicate
+              && ((SetColumnFilterPredicate) underlying)
+                  .getValues()
+                  .contains(null))) {
+        throw new IllegalArgumentException("Contains predicate does not 
support null element value(s)");
       }
       this.underlying = underlying;
     }
diff --git a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java
index 1a2f5e54e..b356c0ba9 100644
--- a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java
+++ b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java
@@ -115,13 +115,13 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
 
     addVisitBegin("In");
     for (TypeInfo info : TYPES) {
-      addInNotInCase(info, true);
+      addInNotInCase(info, true, false);
     }
     addVisitEnd();
 
     addVisitBegin("NotIn");
     for (TypeInfo info : TYPES) {
-      addInNotInCase(info, false);
+      addInNotInCase(info, false, false);
     }
     addVisitEnd();
 
@@ -133,25 +133,25 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
 
     addVisitBegin("Lt");
     for (TypeInfo info : TYPES) {
-      addInequalityCase(info, "<");
+      addInequalityCase(info, "<", false);
     }
     addVisitEnd();
 
     addVisitBegin("LtEq");
     for (TypeInfo info : TYPES) {
-      addInequalityCase(info, "<=");
+      addInequalityCase(info, "<=", false);
     }
     addVisitEnd();
 
     addVisitBegin("Gt");
     for (TypeInfo info : TYPES) {
-      addInequalityCase(info, ">");
+      addInequalityCase(info, ">", false);
     }
     addVisitEnd();
 
     addVisitBegin("GtEq");
     for (TypeInfo info : TYPES) {
-      addInequalityCase(info, ">=");
+      addInequalityCase(info, ">=", false);
     }
     addVisitEnd();
 
@@ -245,7 +245,7 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
     add("    }\n\n");
   }
 
-  private void addInequalityCase(TypeInfo info, String op) throws IOException {
+  private void addInequalityCase(TypeInfo info, String op, boolean 
expectMultipleResults) throws IOException {
     if (!info.supportsInequality) {
       add("    if (clazz.equals(" + info.className + ".class)) {\n");
       add("      throw new IllegalArgumentException(\"Operator " + op + " not 
supported for " + info.className
@@ -268,12 +268,17 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
         + "        public void update("
         + info.primitiveName + " value) {\n");
 
-    add("          setResult(comparator.compare(value, target) " + op + " 
0);\n");
+    if (!expectMultipleResults) {
+      add("          setResult(comparator.compare(value, target) " + op + " 
0);\n");
+    } else {
+      add("            if (!isKnown() && comparator.compare(value, target) " + 
op + " 0)"
+          + " { setResult(true); }\n");
+    }
 
     add("        }\n" + "      };\n" + "    }\n\n");
   }
 
-  private void addInNotInCase(TypeInfo info, boolean isEq) throws IOException {
+  private void addInNotInCase(TypeInfo info, boolean isEq, boolean 
expectMultipleResults) throws IOException {
     add("    if (clazz.equals(" + info.className + ".class)) {\n" + "      if 
(pred.getValues().contains(null)) {\n"
         + "        valueInspector = new ValueInspector() {\n"
         + "          @Override\n"
@@ -299,22 +304,23 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
         + "\n"
         + "          @Override\n"
         + "          public void update("
-        + info.primitiveName + " value) {\n" + "            boolean set = 
false;\n");
+        + info.primitiveName + " value) {\n");
 
+    if (expectMultipleResults) {
+      add("            if (isKnown()) return;\n");
+    }
     add("            for (" + info.primitiveName + " i : target) {\n");
 
     add("              if(" + compareEquality("value", "i", isEq) + ") {\n");
 
-    add("                 setResult(true);\n");
-
-    add("                 set = true;\n");
-
-    add("                 break;\n");
+    add("                 setResult(true);\n                 return;\n");
 
     add("               }\n");
 
     add("             }\n");
-    add("             if (!set) setResult(false);\n");
+    if (!expectMultipleResults) {
+      add("             setResult(false);\n");
+    }
     add("           }\n");
 
     add("         };\n" + "       }\n" + "    }\n\n");
@@ -338,33 +344,45 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
     add("      checkSatisfied();\n" + "    }\n");
   }
 
-  private void addContainsInspectorVisitor(String op, boolean isSupported) 
throws IOException {
-    if (isSupported) {
-      add("    @Override\n"
-          + "    public <T extends Comparable<T>> ValueInspector visit(" + op 
+ "<T> pred) {\n"
-          + "      ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
-          + "      Class<T> clazz = pred.getColumn().getColumnType();\n"
-          + "      ValueInspector valueInspector = null;\n");
-
-      for (TypeInfo info : TYPES) {
-        switch (op) {
-          case "Eq":
-            addEqNotEqCase(info, true, true);
-            break;
-          default:
-            throw new UnsupportedOperationException("Op " + op + " not 
implemented for Contains filter");
-        }
-      }
+  private void addContainsInspectorVisitor(String op) throws IOException {
+    add("    @Override\n"
+        + "    public <T extends Comparable<T>> ValueInspector visit(" + op + 
"<T> pred) {\n"
+        + "      ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
+        + "      Class<T> clazz = pred.getColumn().getColumnType();\n"
+        + "      ValueInspector valueInspector = null;\n");
 
-      add("      return valueInspector;" + "    }\n");
-    } else {
-      add("    @Override\n"
-          + "    public <T extends Comparable<T>> ValueInspector visit(" + op 
+ "<T> pred) {\n"
-          + "      throw new UnsupportedOperationException(\"" + op
-          + " not supported for Contains predicate\");\n"
-          + "    }\n"
-          + "\n");
+    for (TypeInfo info : TYPES) {
+      switch (op) {
+        case "Eq":
+          addEqNotEqCase(info, true, true);
+          break;
+        case "NotEq":
+          addEqNotEqCase(info, false, true);
+          break;
+        case "Lt":
+          addInequalityCase(info, "<", true);
+          break;
+        case "LtEq":
+          addInequalityCase(info, "<=", true);
+          break;
+        case "Gt":
+          addInequalityCase(info, ">", true);
+          break;
+        case "GtEq":
+          addInequalityCase(info, ">=", true);
+          break;
+        case "In":
+          addInNotInCase(info, true, true);
+          break;
+        case "NotIn":
+          addInNotInCase(info, false, true);
+          break;
+        default:
+          throw new UnsupportedOperationException("Op " + op + " not 
implemented for Contains filter");
+      }
     }
+
+    add("      return valueInspector;" + "    }\n");
   }
 
   private void addContainsBegin() throws IOException {
@@ -476,12 +494,14 @@ public class IncrementallyUpdatedFilterPredicateGenerator {
         + "      );\n"
         + "    }\n");
 
-    addContainsInspectorVisitor("Eq", true);
-    addContainsInspectorVisitor("NotEq", false);
-    addContainsInspectorVisitor("Lt", false);
-    addContainsInspectorVisitor("LtEq", false);
-    addContainsInspectorVisitor("Gt", false);
-    addContainsInspectorVisitor("GtEq", false);
+    addContainsInspectorVisitor("Eq");
+    addContainsInspectorVisitor("NotEq");
+    addContainsInspectorVisitor("Lt");
+    addContainsInspectorVisitor("LtEq");
+    addContainsInspectorVisitor("Gt");
+    addContainsInspectorVisitor("GtEq");
+    addContainsInspectorVisitor("In");
+    addContainsInspectorVisitor("NotIn");
 
     add("    @Override\n"
         + "    public ValueInspector visit(Operators.And pred) {\n"
diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java
index dedec409c..888f2d052 100644
--- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java
+++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java
@@ -24,15 +24,20 @@ import static org.apache.parquet.filter2.predicate.FilterApi.contains;
 import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn;
 import static org.apache.parquet.filter2.predicate.FilterApi.eq;
 import static org.apache.parquet.filter2.predicate.FilterApi.gt;
+import static org.apache.parquet.filter2.predicate.FilterApi.gtEq;
 import static org.apache.parquet.filter2.predicate.FilterApi.in;
 import static org.apache.parquet.filter2.predicate.FilterApi.longColumn;
+import static org.apache.parquet.filter2.predicate.FilterApi.lt;
+import static org.apache.parquet.filter2.predicate.FilterApi.ltEq;
 import static org.apache.parquet.filter2.predicate.FilterApi.not;
 import static org.apache.parquet.filter2.predicate.FilterApi.notEq;
+import static org.apache.parquet.filter2.predicate.FilterApi.notIn;
 import static org.apache.parquet.filter2.predicate.FilterApi.or;
 import static org.apache.parquet.filter2.predicate.FilterApi.userDefined;
 import static org.junit.Assert.assertEquals;
 
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import java.io.File;
 import java.io.IOException;
 import java.io.Serializable;
@@ -215,6 +220,32 @@ public class TestRecordLevelFilters {
   public void testArrayContains() throws Exception {
     assertPredicate(
         contains(eq(binaryColumn("phoneNumbers.phone.kind"), 
Binary.fromString("home"))), 27L, 28L, 30L);
+
+    assertPredicate(
+        contains(notEq(binaryColumn("phoneNumbers.phone.kind"), 
Binary.fromString("cell"))), 27L, 28L, 30L);
+
+    assertPredicate(contains(gt(longColumn("phoneNumbers.phone.number"), 
1111111111L)), 20L, 27L, 28L);
+
+    assertPredicate(contains(gtEq(longColumn("phoneNumbers.phone.number"), 
1111111111L)), 20L, 27L, 28L, 30L);
+
+    assertPredicate(contains(lt(longColumn("phoneNumbers.phone.number"), 
105L)), 100L, 101L, 102L, 103L, 104L);
+
+    assertPredicate(
+        contains(ltEq(longColumn("phoneNumbers.phone.number"), 105L)), 100L, 
101L, 102L, 103L, 104L, 105L);
+
+    assertPredicate(
+        contains(in(
+            binaryColumn("phoneNumbers.phone.kind"),
+            ImmutableSet.of(Binary.fromString("apartment"), 
Binary.fromString("home")))),
+        27L,
+        28L,
+        30L);
+
+    assertPredicate(
+        contains(notIn(binaryColumn("phoneNumbers.phone.kind"), 
ImmutableSet.of(Binary.fromString("cell")))),
+        27L,
+        28L,
+        30L);
   }
 
   @Test

Reply via email to