This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 6324646db7e4e1d777f7cec5a7f0a5a71c6c609f
Author: Julian Hyde <[email protected]>
AuthorDate: Thu Mar 30 18:06:29 2023 -0500

    In ImmutableBitSet, specialize forEach, and add forEachInt, anyMatch, allMatch
---
 .../enumerable/EnumerableSortedAggregate.java      |   3 +-
 .../apache/calcite/plan/SubstitutionVisitor.java   |   2 +-
 .../calcite/rel/rel2sql/RelToSqlConverter.java     |   2 +-
 .../calcite/rel/rules/AggregateMergeRule.java      |   2 +-
 .../org/apache/calcite/util/ImmutableBitSet.java   |  74 +++++++++--
 .../org/apache/calcite/util/ImmutableIntList.java  |  17 +++
 .../apache/calcite/util/ImmutableBitSetTest.java   | 136 +++++++++++++++++++--
 .../java/org/apache/calcite/util/UtilTest.java     |  31 +++++
 8 files changed, 238 insertions(+), 29 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
index f28a6bbd87..73a4525112 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
@@ -93,7 +93,8 @@ public class EnumerableSortedAggregate extends EnumerableAggregateBase implement
     } else if (groupKeys.contains(requiredKeys)) {
       // group by a,b,c order by c,b
       List<RelFieldCollation> list = new ArrayList<>(collation.getFieldCollations());
-      groupKeys.except(requiredKeys).forEach(k -> list.add(new RelFieldCollation(k)));
+      groupKeys.except(requiredKeys).forEachInt(k ->
+          list.add(new RelFieldCollation(k)));
       RelCollation aggCollation = RelCollations.of(list);
       RelCollation inputCollation = RexUtil.apply(mapping, aggCollation);
       return Pair.of(traitSet.replace(aggCollation),
diff --git a/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java b/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
index 9e7cddbd71..2315c5d672 100644
--- a/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
+++ b/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
@@ -1951,7 +1951,7 @@ public class SubstitutionVisitor {
     } else if (target.getGroupType() == Aggregate.Group.SIMPLE) {
       // Query is coarser level of aggregation. Generate an aggregate.
       final Map<Integer, Integer> map = new HashMap<>();
-      target.groupSet.forEach(k -> map.put(k, map.size()));
+      target.groupSet.forEachInt(k -> map.put(k, map.size()));
       for (int c : query.groupSet) {
         if (!map.containsKey(c)) {
           return null;
diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
index a1b2f8def4..35dd660ec9 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
@@ -580,7 +580,7 @@ public class RelToSqlConverter extends SqlImplementor
       // it out using a "HAVING GROUPING(groupSets) <> 0".
       // We want to generate the
       final SqlNodeList groupingList = new SqlNodeList(POS);
-      e.getGroupSet().forEach(g ->
+      e.getGroupSet().forEachInt(g ->
           groupingList.add(builder.context.field(g)));
       builder.setHaving(
           SqlStdOperatorTable.NOT_EQUALS.createCall(POS,
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateMergeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateMergeRule.java
index b601f1abb9..36a3bf54d5 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateMergeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateMergeRule.java
@@ -87,7 +87,7 @@ public class AggregateMergeRule
 
     final ImmutableBitSet bottomGroupSet = bottomAgg.getGroupSet();
     final Map<Integer, Integer> map = new HashMap<>();
-    bottomGroupSet.forEach(v -> map.put(map.size(), v));
+    bottomGroupSet.forEachInt(v -> map.put(map.size(), v));
     for (int k : topAgg.getGroupSet()) {
       if (!map.containsKey(k)) {
         return;
diff --git a/core/src/main/java/org/apache/calcite/util/ImmutableBitSet.java b/core/src/main/java/org/apache/calcite/util/ImmutableBitSet.java
index c3ae4ca42e..0a46f13748 100644
--- a/core/src/main/java/org/apache/calcite/util/ImmutableBitSet.java
+++ b/core/src/main/java/org/apache/calcite/util/ImmutableBitSet.java
@@ -44,6 +44,9 @@ import java.util.Map;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
+import java.util.function.Consumer;
+import java.util.function.IntConsumer;
+import java.util.function.IntPredicate;
 import java.util.stream.Collector;
 
 import static org.apache.calcite.linq4j.Nullness.castNonNull;
@@ -108,9 +111,24 @@ public class ImmutableBitSet
     return EMPTY;
   }
 
+  /** Creates an ImmutableBitSet with the given bit set. */
+  public static ImmutableBitSet of(int bit) {
+    if (bit < 0) {
+      throw new IndexOutOfBoundsException("bit < 0: " + bit);
+    }
+    long[] words = new long[wordIndex(bit) + 1];
+    int wordIndex = wordIndex(bit);
+    words[wordIndex] |= 1L << bit;
+    return new ImmutableBitSet(words);
+  }
+
+  /** Creates an ImmutableBitSet with the given bits set. */
   public static ImmutableBitSet of(int... bits) {
     int max = -1;
     for (int bit : bits) {
+      if (bit < 0) {
+        throw new IndexOutOfBoundsException("bit < 0: " + bit);
+      }
       max = Math.max(bit, max);
     }
     if (max == -1) {
@@ -130,6 +148,9 @@ public class ImmutableBitSet
     }
     int max = -1;
     for (int bit : bits) {
+      if (bit < 0) {
+        throw new IndexOutOfBoundsException("bit < 0: " + bit);
+      }
       max = Math.max(bit, max);
     }
     if (max == -1) {
@@ -621,6 +642,18 @@ public class ImmutableBitSet
     return list;
   }
 
+  @Override public void forEach(Consumer<? super Integer> action) {
+    forEachInt(action::accept);
+  }
+
+  /** As {@link #forEach(Consumer)} but on primitive {@code int} values. */
+  public void forEachInt(IntConsumer action) {
+    requireNonNull(action, "action");
+    for (int i = nextSetBit(0); i >= 0; i = nextSetBit(i + 1)) {
+      action.accept(i);
+    }
+  }
+
   /** Creates a view onto this bit set as a list of integers.
    *
    * <p>The {@code cardinality} and {@code get} methods are both O(n), but
@@ -945,6 +978,30 @@ public class ImmutableBitSet
     return true;
   }
 
+  /** Returns whether a given predicate evaluates to true for all bits in this
+   * set. */
+  public boolean allMatch(IntPredicate predicate) {
+    requireNonNull(predicate, "predicate");
+    for (int i = nextSetBit(0); i >= 0; i = nextSetBit(i + 1)) {
+      if (!predicate.test(i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /** Returns whether a given predicate evaluates to true for any bit in this
+   * set. */
+  public boolean anyMatch(IntPredicate predicate) {
+    requireNonNull(predicate, "predicate");
+    for (int i = nextSetBit(0); i >= 0; i = nextSetBit(i + 1)) {
+      if (predicate.test(i)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   /**
    * Setup equivalence Sets for each position. If i and j are equivalent then
    * they will have the same equivalence Set. The algorithm computes the
@@ -1148,34 +1205,25 @@ public class ImmutableBitSet
 
     /** Sets all bits in a given bit set. */
     public Builder addAll(ImmutableBitSet bitSet) {
-      for (Integer bit : bitSet) {
-        set(bit);
-      }
+      bitSet.forEachInt(this::set);
       return this;
     }
 
     /** Sets all bits in a given list of bits. */
     public Builder addAll(Iterable<Integer> integers) {
-      for (Integer integer : integers) {
-        set(integer);
-      }
+      integers.forEach(this::set);
       return this;
     }
 
     /** Sets all bits in a given list of {@code int}s. */
     public Builder addAll(ImmutableIntList integers) {
-      //noinspection ForLoopReplaceableByForEach
-      for (int i = 0; i < integers.size(); i++) {
-        set(integers.get(i));
-      }
+      integers.forEachInt(this::set);
       return this;
     }
 
     /** Clears all bits in a given bit set. */
     public Builder removeAll(ImmutableBitSet bitSet) {
-      for (Integer bit : bitSet) {
-        clear(bit);
-      }
+      bitSet.forEachInt(this::clear);
       return this;
     }
 
diff --git a/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java b/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java
index 7dc305b5f7..cb34fe3d53 100644
--- a/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java
+++ b/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java
@@ -36,6 +36,8 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.NoSuchElementException;
+import java.util.function.Consumer;
+import java.util.function.IntConsumer;
 
 import static org.apache.calcite.linq4j.Nullness.castNonNull;
 
@@ -139,6 +141,21 @@ public class ImmutableIntList extends FlatLists.AbstractFlatList<Integer> {
     return ints.length;
   }
 
+  @Override public void forEach(Consumer<? super Integer> action) {
+    requireNonNull(action, "action");
+    for (int i : ints) {
+      action.accept(i);
+    }
+  }
+
+  /** As {@link #forEach(Consumer)} but on primitive {@code int} values. */
+  public void forEachInt(IntConsumer action) {
+    requireNonNull(action, "action");
+    for (int i : ints) {
+      action.accept(i);
+    }
+  }
+
   @Override public Object[] toArray() {
     final Object[] objects = new Object[ints.length];
     for (int i = 0; i < objects.length; i++) {
diff --git a/core/src/test/java/org/apache/calcite/util/ImmutableBitSetTest.java b/core/src/test/java/org/apache/calcite/util/ImmutableBitSetTest.java
index e9632406fe..fb0332bdb7 100644
--- a/core/src/test/java/org/apache/calcite/util/ImmutableBitSetTest.java
+++ b/core/src/test/java/org/apache/calcite/util/ImmutableBitSetTest.java
@@ -35,6 +35,10 @@ import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.TreeSet;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.IntConsumer;
+import java.util.function.IntPredicate;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.is;
@@ -42,6 +46,7 @@ import static org.hamcrest.CoreMatchers.sameInstance;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
@@ -55,6 +60,49 @@ class ImmutableBitSetTest {
     assertToIterBitSet("0", ImmutableBitSet.of(0));
     assertToIterBitSet("0, 1", ImmutableBitSet.of(0, 1));
     assertToIterBitSet("10", ImmutableBitSet.of(10));
+
+    check((bitSet, list) -> {
+      final List<Integer> list2 = new ArrayList<>();
+      for (Integer integer : bitSet) {
+        list2.add(integer);
+      }
+      assertThat(list2, equalTo(list));
+    });
+  }
+
+  /** Tests the method {@link ImmutableBitSet#of(int)}. */
+  @Test void testSingletonConstructor() {
+    IntConsumer c = i -> {
+      final ImmutableBitSet s0 = ImmutableBitSet.of(i);
+      final ImmutableBitSet s1 = ImmutableBitSet.of(ImmutableIntList.of(i));
+      final ImmutableBitSet s2 = ImmutableBitSet.of(Collections.singleton(i));
+      final ImmutableBitSet s3 =
+          ImmutableBitSet.of(99, 100).set(i).clear(100).clear(99);
+      assertThat(s0.cardinality(), is(1));
+      assertThat(s0, is(s1));
+      assertThat(s0, is(s2));
+      assertThat(s0, is(s3));
+      assertThat(s1, is(s2));
+      assertThat(s1, is(s3));
+      assertThat(s2, is(s3));
+    };
+    c.accept(0);
+    c.accept(1);
+    c.accept(63);
+    c.accept(64);
+  }
+
+  @Test void testNegative() {
+    assertThrows(IndexOutOfBoundsException.class,
+        () -> ImmutableBitSet.of(-1));
+    assertThrows(IndexOutOfBoundsException.class,
+        () -> ImmutableBitSet.of(-2));
+    assertThrows(IndexOutOfBoundsException.class,
+        () -> ImmutableBitSet.of(1, 10, -1, 63));
+    assertThrows(IndexOutOfBoundsException.class,
+        () -> ImmutableBitSet.of(-1, 10));
+    assertThrows(IndexOutOfBoundsException.class,
+        () -> ImmutableBitSet.of(Collections.singleton(-2)));
   }
 
   /**
@@ -80,18 +128,42 @@ class ImmutableBitSetTest {
    * {@link org.apache.calcite.util.ImmutableBitSet#toList()}.
    */
   @Test void testToList() {
-    assertThat(ImmutableBitSet.of().toList(),
-        equalTo(Collections.<Integer>emptyList()));
-    assertThat(ImmutableBitSet.of(5).toList(), equalTo(Arrays.asList(5)));
-    assertThat(ImmutableBitSet.of(3, 5).toList(), equalTo(Arrays.asList(3, 5)));
-    assertThat(ImmutableBitSet.of(63).toList(), equalTo(Arrays.asList(63)));
-    assertThat(ImmutableBitSet.of(64).toList(), equalTo(Arrays.asList(64)));
-    assertThat(ImmutableBitSet.of(3, 63).toList(),
-        equalTo(Arrays.asList(3, 63)));
-    assertThat(ImmutableBitSet.of(3, 64).toList(),
-        equalTo(Arrays.asList(3, 64)));
-    assertThat(ImmutableBitSet.of(0, 4, 2).toList(),
-        equalTo(Arrays.asList(0, 2, 4)));
+    check((bitSet, list) -> assertThat(bitSet.toList(), equalTo(list)));
+  }
+
+  /**
+   * Tests the method
+   * {@link org.apache.calcite.util.ImmutableBitSet#forEachInt}.
+   */
+  @Test void testForEachInt() {
+    check((bitSet, list) -> {
+      final List<Integer> list2 = new ArrayList<>();
+      bitSet.forEachInt(list2::add);
+      assertThat(list2, equalTo(list));
+    });
+  }
+
+  /**
+   * Tests the method
+   * {@link org.apache.calcite.util.ImmutableBitSet#forEach}.
+   */
+  @Test void testForEachInteger() {
+    check((bitSet, list) -> {
+      final List<Integer> list2 = new ArrayList<>();
+      bitSet.forEach(list2::add);
+      assertThat(list2, equalTo(list));
+    });
+  }
+
+  private void check(BiConsumer<ImmutableBitSet, List<Integer>> consumer) {
+    consumer.accept(ImmutableBitSet.of(), Collections.emptyList());
+    consumer.accept(ImmutableBitSet.of(5), Collections.singletonList(5));
+    consumer.accept(ImmutableBitSet.of(3, 5), Arrays.asList(3, 5));
+    consumer.accept(ImmutableBitSet.of(63), Collections.singletonList(63));
+    consumer.accept(ImmutableBitSet.of(64), Collections.singletonList(64));
+    consumer.accept(ImmutableBitSet.of(3, 63), Arrays.asList(3, 63));
+    consumer.accept(ImmutableBitSet.of(3, 64), Arrays.asList(3, 64));
+    consumer.accept(ImmutableBitSet.of(0, 4, 2), Arrays.asList(0, 2, 4));
   }
 
   /**
@@ -593,6 +665,46 @@ class ImmutableBitSetTest {
     assertFalse(ImmutableBitSet.allContain(collection2, 4));
   }
 
+  /**
+   * Test case for {@link ImmutableBitSet#anyMatch(IntPredicate)}
+   * and {@link ImmutableBitSet#allMatch(IntPredicate)}.
+   *
+   * <p>Checks a variety of predicates (is even, is zero, always true,
+   * always false) and their negations on a variety of bit sets.
+   */
+  @Test void testAnyMatch() {
+    BiConsumer<ImmutableBitSet, IntPredicate> c = (bitSet, predicate) -> {
+      final Set<Integer> integerSet = new HashSet<>(bitSet.asList());
+      assertThat(bitSet.anyMatch(predicate),
+          is(integerSet.stream().anyMatch(predicate::test)));
+      assertThat(bitSet.allMatch(predicate),
+          is(integerSet.stream().allMatch(predicate::test)));
+    };
+
+    BiConsumer<ImmutableBitSet, IntPredicate> c2 = (bitSet, predicate) -> {
+      c.accept(bitSet, predicate);
+      c.accept(bitSet, predicate.negate());
+    };
+
+    final ImmutableBitSet set0 = ImmutableBitSet.of();
+    final ImmutableBitSet set1 = ImmutableBitSet.of(0, 1, 2, 3);
+    final ImmutableBitSet set2 = ImmutableBitSet.of(0, 2, 4, 8);
+    Consumer<IntPredicate> c3 = predicate -> {
+      c2.accept(set0, predicate);
+      c2.accept(set1, predicate);
+      c2.accept(set2, predicate);
+    };
+
+    final IntPredicate isZero = i -> i == 0;
+    final IntPredicate isEven = i -> i % 2 == 0;
+    final IntPredicate alwaysTrue = i -> true;
+    final IntPredicate alwaysFalse = i -> false;
+    c3.accept(isZero);
+    c3.accept(isEven);
+    c3.accept(alwaysTrue);
+    c3.accept(alwaysFalse);
+  }
+
   /** Test case for
    * {@link org.apache.calcite.util.ImmutableBitSet#toImmutableBitSet()}. */
   @Test void testCollector() {
diff --git a/core/src/test/java/org/apache/calcite/util/UtilTest.java b/core/src/test/java/org/apache/calcite/util/UtilTest.java
index 176a601f29..c337eca078 100644
--- a/core/src/test/java/org/apache/calcite/util/UtilTest.java
+++ b/core/src/test/java/org/apache/calcite/util/UtilTest.java
@@ -97,6 +97,7 @@ import java.util.TreeSet;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.ObjIntConsumer;
 import java.util.function.Predicate;
@@ -1233,6 +1234,32 @@ class UtilTest {
   }
 
   @Test void testImmutableIntList() {
+    final BiConsumer<ImmutableIntList, List<Integer>> c2 = (intList, list) -> {
+      assertThat(list.size(), is(intList.size()));
+      assertThat(list, is(intList));
+      assertThat(list.toString(), is(intList.toString()));
+      assertThat(list.hashCode(), is(intList.hashCode()));
+    };
+
+    final Consumer<ImmutableIntList> c = list -> {
+      final List<Integer> arrayList = new ArrayList<>(list);
+      c2.accept(list, arrayList);
+
+      final List<Integer> arrayList2 = new ArrayList<>();
+      //noinspection CollectionAddAllCanBeReplacedWithConstructor
+      arrayList2.addAll(list);
+      c2.accept(list, arrayList2);
+
+      final List<Integer> arrayList3 = new ArrayList<>();
+      //noinspection UseBulkOperation
+      list.forEach(arrayList3::add);
+      c2.accept(list, arrayList3);
+
+      final List<Integer> arrayList4 = new ArrayList<>();
+      list.forEachInt(arrayList4::add);
+      c2.accept(list, arrayList4);
+    };
+
     final ImmutableIntList list = ImmutableIntList.of();
     assertEquals(0, list.size());
     assertEquals(list, Collections.<Integer>emptyList());
@@ -1263,6 +1290,10 @@ class UtilTest {
     assertThat(
         Arrays.toString(ImmutableIntList.of(1).toArray(new Integer[]{5, 6, 7})),
         is("[1, null, 7]"));
+
+    c.accept(list);
+    c.accept(list2);
+    c.accept(ImmutableIntList.of(-2, 10, 1, -2));
   }
 
   /** Unit test for {@link IdPair}. */

Reply via email to