This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a7806e159 [VL] Support fallback processing of velox_bloom_filter_agg (#5477)
a7806e159 is described below

commit a7806e159f75315e6b8127fce10caaac6a21d25e
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Apr 23 10:01:46 2024 +0800

    [VL] Support fallback processing of velox_bloom_filter_agg (#5477)
---
 .../apache/spark/util/sketch/VeloxBloomFilter.java |  38 +++++-
 .../util/sketch/VeloxBloomFilterJniWrapper.java    |   8 ++
 .../BloomFilterMightContainJointRewriteRule.scala  |   5 +-
 .../aggregate/VeloxBloomFilterAggregate.scala      |  63 +++++++---
 .../spark/util/sketch/VeloxBloomFilterTest.java    | 136 ++++++++++++++++++++-
 cpp/velox/jni/VeloxJniWrapper.cc                   |  73 ++++++++++-
 .../java/org/apache/gluten/exec/RuntimeAware.java  |   4 +
 .../sql/GlutenBloomFilterAggregateQuerySuite.scala |   2 +-
 .../sql/GlutenBloomFilterAggregateQuerySuite.scala |  35 +++++-
 .../sql/GlutenBloomFilterAggregateQuerySuite.scala |   4 +-
 10 files changed, 339 insertions(+), 29 deletions(-)

diff --git 
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
 
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
index 71bbbb051..59716ed79 100644
--- 
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
+++ 
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
@@ -19,6 +19,7 @@ package org.apache.spark.util.sketch;
 import org.apache.commons.io.IOUtils;
 
 import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -33,6 +34,15 @@ public class VeloxBloomFilter extends BloomFilter {
     handle = jni.init(data);
   }
 
+  private VeloxBloomFilter(int capacity) {
+    jni = VeloxBloomFilterJniWrapper.create();
+    handle = jni.empty(capacity);
+  }
+
+  public static VeloxBloomFilter empty(int capacity) {
+    return new VeloxBloomFilter(capacity);
+  }
+
   public static VeloxBloomFilter readFrom(InputStream in) {
     try {
       byte[] all = IOUtils.toByteArray(in);
@@ -50,6 +60,15 @@ public class VeloxBloomFilter extends BloomFilter {
     }
   }
 
+  public byte[] serialize() {
+    try (ByteArrayOutputStream o = new ByteArrayOutputStream()) {
+      writeTo(o);
+      return o.toByteArray();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
   @Override
   public double expectedFpp() {
     throw new UnsupportedOperationException("Not yet implemented");
@@ -72,7 +91,8 @@ public class VeloxBloomFilter extends BloomFilter {
 
   @Override
   public boolean putLong(long item) {
-    throw new UnsupportedOperationException("Not yet implemented");
+    jni.insertLong(handle, item);
+    return true;
   }
 
   @Override
@@ -87,7 +107,18 @@ public class VeloxBloomFilter extends BloomFilter {
 
   @Override
   public BloomFilter mergeInPlace(BloomFilter other) throws 
IncompatibleMergeException {
-    throw new UnsupportedOperationException("Not yet implemented");
+    if (!(other instanceof VeloxBloomFilter)) {
+      throw new IncompatibleMergeException(
+          "Cannot merge Velox bloom-filter with non-Velox bloom-filter");
+    }
+    final VeloxBloomFilter from = (VeloxBloomFilter) other;
+
+    if (!jni.isCompatibleWith(from.jni)) {
+      throw new IncompatibleMergeException(
+          "Cannot merge Velox bloom-filters with different Velox runtimes");
+    }
+    jni.mergeFrom(handle, from.handle);
+    return this;
   }
 
   @Override
@@ -117,6 +148,7 @@ public class VeloxBloomFilter extends BloomFilter {
 
   @Override
   public void writeTo(OutputStream out) throws IOException {
-    throw new UnsupportedOperationException("Not yet implemented");
+    byte[] data = jni.serialize(handle);
+    out.write(data);
   }
 }
diff --git 
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
 
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
index a369c8a30..572e2c7ac 100644
--- 
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
+++ 
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
@@ -36,7 +36,15 @@ public class VeloxBloomFilterJniWrapper implements 
RuntimeAware {
     return runtime.getHandle();
   }
 
+  public native long empty(int capacity);
+
   public native long init(byte[] data);
 
+  public native void insertLong(long handle, long item);
+
   public native boolean mightContainLong(long handle, long item);
+
+  public native void mergeFrom(long handle, long other);
+
+  public native byte[] serialize(long handle);
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
index bbec3ee01..7d15e32b3 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
@@ -27,10 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 
 case class BloomFilterMightContainJointRewriteRule(spark: SparkSession) 
extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (
-      !(GlutenConfig.getConf.enableNativeBloomFilter &&
-        GlutenConfig.getConf.enableColumnarHashAgg)
-    ) {
+    if (!(GlutenConfig.getConf.enableNativeBloomFilter)) {
       return plan
     }
     val out = plan.transformWithSubqueries {
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
index 720261fd2..da545aa47 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
@@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.sketch.BloomFilter
+import org.apache.spark.util.TaskResources
+import org.apache.spark.util.sketch.{BloomFilter, VeloxBloomFilter}
 
 /**
  * Velox's bloom-filter implementation uses different algorithms internally 
comparing to vanilla
@@ -48,6 +50,15 @@ case class VeloxBloomFilterAggregate(
 
   override def prettyName: String = "velox_bloom_filter_agg"
 
+  // Mark as lazy so that `estimatedNumItems` is not evaluated during tree 
transformation.
+  private lazy val estimatedNumItems: Long =
+    Math.min(
+      estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
+      SQLConf.get
+        .getConfString("spark.sql.optimizer.runtime.bloomFilter.maxNumItems", 
"4000000")
+        .toLong
+    )
+
   override def first: Expression = child
 
   override def second: Expression = estimatedNumItemsExpression
@@ -70,26 +81,48 @@ case class VeloxBloomFilterAggregate(
       numBitsExpression = newNumBitsExpression)
   }
 
-  override def createAggregationBuffer(): BloomFilter = throw new 
UnsupportedOperationException()
+  override def createAggregationBuffer(): BloomFilter = {
+    if (!TaskResources.inSparkTask()) {
+      throw new UnsupportedOperationException("velox_bloom_filter_agg is not 
evaluable on Driver")
+    }
+    VeloxBloomFilter.empty(Math.toIntExact(estimatedNumItems))
+  }
 
-  override def update(buffer: BloomFilter, input: InternalRow): BloomFilter =
-    throw new UnsupportedOperationException()
+  override def update(buffer: BloomFilter, input: InternalRow): BloomFilter = {
+    assert(buffer.isInstanceOf[VeloxBloomFilter])
+    val value = child.eval(input)
+    // Ignore null values.
+    if (value == null) {
+      return buffer
+    }
+    buffer.putLong(value.asInstanceOf[Long])
+    buffer
+  }
 
-  override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter =
-    throw new UnsupportedOperationException()
+  override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter = {
+    assert(buffer.isInstanceOf[VeloxBloomFilter])
+    assert(input.isInstanceOf[VeloxBloomFilter])
+    buffer.asInstanceOf[VeloxBloomFilter].mergeInPlace(input)
+  }
 
-  override def eval(buffer: BloomFilter): Any = throw new 
UnsupportedOperationException()
+  override def eval(buffer: BloomFilter): Any = {
+    assert(buffer.isInstanceOf[VeloxBloomFilter])
+    serialize(buffer)
+  }
 
-  override def serialize(buffer: BloomFilter): Array[Byte] =
-    throw new UnsupportedOperationException()
+  override def serialize(buffer: BloomFilter): Array[Byte] = {
+    assert(buffer.isInstanceOf[VeloxBloomFilter])
+    buffer.asInstanceOf[VeloxBloomFilter].serialize()
+  }
 
-  override def deserialize(storageFormat: Array[Byte]): BloomFilter =
-    throw new UnsupportedOperationException()
+  override def deserialize(bytes: Array[Byte]): BloomFilter = {
+    VeloxBloomFilter.readFrom(bytes)
+  }
 
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    throw new UnsupportedOperationException()
+  override def withNewMutableAggBufferOffset(newOffset: Int): 
VeloxBloomFilterAggregate =
+    copy(mutableAggBufferOffset = newOffset)
 
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    throw new UnsupportedOperationException()
+  override def withNewInputAggBufferOffset(newOffset: Int): 
VeloxBloomFilterAggregate =
+    copy(inputAggBufferOffset = newOffset)
 
 }
diff --git 
a/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
 
b/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
index 81213fab2..ce439b864 100644
--- 
a/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
+++ 
b/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
@@ -25,6 +25,7 @@ import org.apache.spark.util.TaskResources$;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
 
 import java.nio.ByteBuffer;
 
@@ -39,16 +40,149 @@ public class VeloxBloomFilterTest {
 
   @Test
   public void testEmpty() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final BloomFilter filter = VeloxBloomFilter.empty(10000);
+          for (int i = 0; i < 1000; i++) {
+            Assert.assertFalse(filter.mightContainLong(i));
+          }
+          return null;
+        });
+  }
+
+  @Test
+  public void testMalformed() {
     final ByteBuffer buf = ByteBuffer.allocate(5);
     buf.put((byte) 1); // kBloomFilterV1
     buf.putInt(0); // size
     TaskResources$.MODULE$.runUnsafe(
         () -> {
           final BloomFilter filter = VeloxBloomFilter.readFrom(buf.array());
-          for (int i = 0; i < 1000; i++) {
+          Assert.assertThrows(
+              "Bloom-filter is not initialized",
+              RuntimeException.class,
+              new ThrowingRunnable() {
+                @Override
+                public void run() throws Throwable {
+                  filter.mightContainLong(0);
+                }
+              });
+          return null;
+        });
+  }
+
+  @Test
+  public void testSanity() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final BloomFilter filter = VeloxBloomFilter.empty(10000);
+          final int numItems = 2000;
+          final int halfNumItems = numItems / 2;
+          for (int i = -halfNumItems; i < halfNumItems; i++) {
             Assert.assertFalse(filter.mightContainLong(i));
           }
+          for (int i = -halfNumItems; i < halfNumItems; i++) {
+            filter.putLong(i);
+            Assert.assertTrue(filter.mightContainLong(i));
+          }
+          for (int i = -halfNumItems; i < halfNumItems; i++) {
+            Assert.assertTrue(filter.mightContainLong(i));
+          }
+
+          // Check false positives.
+          checkFalsePositives(filter, halfNumItems);
+
           return null;
         });
   }
+
+  @Test
+  public void testMerge() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final BloomFilter filter1 = VeloxBloomFilter.empty(10000);
+          final int start1 = 0;
+          final int end1 = 2000;
+          for (int i = start1; i < end1; i++) {
+            Assert.assertFalse(filter1.mightContainLong(i));
+          }
+          for (int i = start1; i < end1; i++) {
+            filter1.putLong(i);
+            Assert.assertTrue(filter1.mightContainLong(i));
+          }
+          for (int i = start1; i < end1; i++) {
+            Assert.assertTrue(filter1.mightContainLong(i));
+          }
+
+          final BloomFilter filter2 = VeloxBloomFilter.empty(10000);
+          final int start2 = 1000;
+          final int end2 = 3000;
+          for (int i = start2; i < end2; i++) {
+            Assert.assertFalse(filter2.mightContainLong(i));
+          }
+          for (int i = start2; i < end2; i++) {
+            filter2.putLong(i);
+            Assert.assertTrue(filter2.mightContainLong(i));
+          }
+          for (int i = start2; i < end2; i++) {
+            Assert.assertTrue(filter2.mightContainLong(i));
+          }
+
+          try {
+            filter1.mergeInPlace(filter2);
+          } catch (IncompatibleMergeException e) {
+            throw new RuntimeException(e);
+          }
+
+          for (int i = start1; i < end2; i++) {
+            Assert.assertTrue(filter1.mightContainLong(i));
+          }
+
+          // Check false positives.
+          checkFalsePositives(filter1, end2);
+
+          return null;
+        });
+  }
+
+  @Test
+  public void testSerde() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final VeloxBloomFilter filter = VeloxBloomFilter.empty(10000);
+          for (int i = 0; i < 1000; i++) {
+            filter.putLong(i);
+          }
+
+          byte[] data1 = filter.serialize();
+
+          final VeloxBloomFilter filter2 = VeloxBloomFilter.readFrom(data1);
+          byte[] data2 = filter2.serialize();
+
+          Assert.assertArrayEquals(data2, data1);
+          return null;
+        });
+  }
+
+  private static void checkFalsePositives(BloomFilter filter, int start) {
+    final int attemptStart = start;
+    final int attemptCount = 5000000;
+
+    int falsePositives = 0;
+    int negativeFalsePositives = 0;
+
+    for (int i = attemptStart; i < attemptStart + attemptCount; i++) {
+      if (filter.mightContainLong(i)) {
+        falsePositives++;
+      }
+      if (filter.mightContainLong(-i)) {
+        negativeFalsePositives++;
+      }
+    }
+
+    Assert.assertTrue(falsePositives > 0);
+    Assert.assertTrue(falsePositives < attemptCount);
+    Assert.assertTrue(negativeFalsePositives > 0);
+    Assert.assertTrue(negativeFalsePositives < attemptCount);
+  }
 }
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index 1d8c09fa4..a3c51f64a 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -141,12 +141,24 @@ 
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail
   JNI_METHOD_END(nullptr)
 }
 
+JNIEXPORT jlong JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_empty( // NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jint capacity) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto filter = 
std::make_shared<velox::BloomFilter<std::allocator<uint64_t>>>();
+  filter->reset(capacity);
+  GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized");
+  return ctx->objectStore()->save(filter);
+  JNI_METHOD_END(gluten::kInvalidResourceHandle)
+}
+
 JNIEXPORT jlong JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_init( // NOLINT
     JNIEnv* env,
     jobject wrapper,
     jbyteArray data) {
   JNI_METHOD_START
-
   auto len = env->GetArrayLength(data);
   auto safeArray = gluten::getByteArrayElementsSafe(env, data);
   auto ctx = gluten::getRuntime(env, wrapper);
@@ -157,6 +169,19 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWra
   JNI_METHOD_END(gluten::kInvalidResourceHandle)
 }
 
+JNIEXPORT void JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_insertLong( // 
NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlong handle,
+    jlong item) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto filter = 
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
+  GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized");
+  filter->insert(folly::hasher<int64_t>()(item));
+  JNI_METHOD_END()
+}
+
 JNIEXPORT jboolean JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_mightContainLong( 
// NOLINT
     JNIEnv* env,
     jobject wrapper,
@@ -165,11 +190,55 @@ JNIEXPORT jboolean JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJni
   JNI_METHOD_START
   auto ctx = gluten::getRuntime(env, wrapper);
   auto filter = 
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
-  bool out = filter->isSet() && 
filter->mayContain(folly::hasher<int64_t>()(item));
+  GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized");
+  bool out = filter->mayContain(folly::hasher<int64_t>()(item));
   return out;
   JNI_METHOD_END(false)
 }
 
+namespace {
+static std::vector<char> serialize(BloomFilter<std::allocator<uint64_t>>* bf) {
+  uint32_t size = bf->serializedSize();
+  std::vector<char> buffer;
+  buffer.reserve(size);
+  char* data = buffer.data();
+  bf->serialize(data);
+  return buffer;
+}
+} // namespace
+
+JNIEXPORT void JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_mergeFrom( // 
NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlong handle,
+    jlong other) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto to = 
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
+  auto from = 
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(other);
+  GLUTEN_CHECK(to->isSet(), "Bloom-filter is not initialized");
+  GLUTEN_CHECK(from->isSet(), "Bloom-filter is not initialized");
+  std::vector<char> serialized = serialize(from.get());
+  to->merge(serialized.data());
+  JNI_METHOD_END()
+}
+
+JNIEXPORT jbyteArray JNICALL 
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_serialize( // 
NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlong handle) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto filter = 
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
+  GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized");
+  std::vector<char> buffer = serialize(filter.get());
+  auto size = buffer.capacity();
+  jbyteArray out = env->NewByteArray(size);
+  env->SetByteArrayRegion(out, 0, size, 
reinterpret_cast<jbyte*>(buffer.data()));
+  return out;
+  JNI_METHOD_END(nullptr)
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java 
b/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
index 8a3f15b40..ca96ace64 100644
--- a/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
+++ b/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
@@ -21,5 +21,9 @@ package org.apache.gluten.exec;
  * for further native processing.
  */
 public interface RuntimeAware {
+  default boolean isCompatibleWith(RuntimeAware other) {
+    return handle() == other.handle();
+  }
+
   long handle();
 }
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index eb691e9e7..ddd4cf1d4 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -115,7 +115,7 @@ class GlutenBloomFilterAggregateQuerySuite
     }
   }
 
-  testGluten("Test bloom_filter_agg fallback with might_contain offloaded") {
+  testGluten("Test bloom_filter_agg agg fallback") {
     val table = "bloom_filter_test"
     val numEstimatedItems = 5000000L
     val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
diff --git 
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
 
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index 89302fa37..0c75db830 100644
--- 
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++ 
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -68,7 +68,7 @@ class GlutenBloomFilterAggregateQuerySuite
       Row(null))
   }
 
-  testGluten("Test bloom_filter_agg fallback") {
+  testGluten("Test bloom_filter_agg filter fallback") {
     val table = "bloom_filter_test"
     val numEstimatedItems = 5000000L
     val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
@@ -113,4 +113,37 @@ class GlutenBloomFilterAggregateQuerySuite
       }
     }
   }
+
+  testGluten("Test bloom_filter_agg agg fallback") {
+    val table = "bloom_filter_test"
+    val numEstimatedItems = 5000000L
+    val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
+    val sqlString = s"""
+                       |SELECT col positive_membership_test
+                       |FROM $table
+                       |WHERE might_contain(
+                       |            (SELECT bloom_filter_agg(col,
+                       |              cast($numEstimatedItems as long),
+                       |              cast($numBits as long))
+                       |             FROM $table), col)
+                      """.stripMargin
+
+    withTempView(table) {
+      (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
+        .toDF("col")
+        .createOrReplaceTempView(table)
+      withSQLConf(
+        GlutenConfig.COLUMNAR_HASHAGG_ENABLED.key -> "false"
+      ) {
+        val df = spark.sql(sqlString)
+        df.collect
+        assert(
+          collectWithSubqueries(df.queryExecution.executedPlan) {
+            case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
+          }.isEmpty,
+          df.queryExecution.executedPlan
+        )
+      }
+    }
+  }
 }
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index 730ce6062..ddd4cf1d4 100644
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++ 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -69,7 +69,7 @@ class GlutenBloomFilterAggregateQuerySuite
       Row(null))
   }
 
-  testGluten("Test bloom_filter_agg fallback") {
+  testGluten("Test bloom_filter_agg filter fallback") {
     val table = "bloom_filter_test"
     val numEstimatedItems = 5000000L
     val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
@@ -115,7 +115,7 @@ class GlutenBloomFilterAggregateQuerySuite
     }
   }
 
-  testGluten("Test bloom_filter_agg fallback with might_contain offloaded") {
+  testGluten("Test bloom_filter_agg agg fallback") {
     val table = "bloom_filter_test"
     val numEstimatedItems = 5000000L
     val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to