This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a7806e159 [VL] Support fallback processing of velox_bloom_filter_agg
(#5477)
a7806e159 is described below
commit a7806e159f75315e6b8127fce10caaac6a21d25e
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Apr 23 10:01:46 2024 +0800
[VL] Support fallback processing of velox_bloom_filter_agg (#5477)
---
.../apache/spark/util/sketch/VeloxBloomFilter.java | 38 +++++-
.../util/sketch/VeloxBloomFilterJniWrapper.java | 8 ++
.../BloomFilterMightContainJointRewriteRule.scala | 5 +-
.../aggregate/VeloxBloomFilterAggregate.scala | 63 +++++++---
.../spark/util/sketch/VeloxBloomFilterTest.java | 136 ++++++++++++++++++++-
cpp/velox/jni/VeloxJniWrapper.cc | 73 ++++++++++-
.../java/org/apache/gluten/exec/RuntimeAware.java | 4 +
.../sql/GlutenBloomFilterAggregateQuerySuite.scala | 2 +-
.../sql/GlutenBloomFilterAggregateQuerySuite.scala | 35 +++++-
.../sql/GlutenBloomFilterAggregateQuerySuite.scala | 4 +-
10 files changed, 339 insertions(+), 29 deletions(-)
diff --git
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
index 71bbbb051..59716ed79 100644
---
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
+++
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilter.java
@@ -19,6 +19,7 @@ package org.apache.spark.util.sketch;
import org.apache.commons.io.IOUtils;
import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -33,6 +34,15 @@ public class VeloxBloomFilter extends BloomFilter {
handle = jni.init(data);
}
+ private VeloxBloomFilter(int capacity) { // Builds a brand-new native filter rather than loading serialized bytes.
+ jni = VeloxBloomFilterJniWrapper.create();
+ handle = jni.empty(capacity); // Handle of the native filter created by the JNI `empty` call.
+ }
+
+ public static VeloxBloomFilter empty(int capacity) { // Factory for a filter that has had no items inserted.
+ return new VeloxBloomFilter(capacity); // `capacity` is forwarded unchanged to the native side.
+ }
+
public static VeloxBloomFilter readFrom(InputStream in) {
try {
byte[] all = IOUtils.toByteArray(in);
@@ -50,6 +60,15 @@ public class VeloxBloomFilter extends BloomFilter {
}
}
+  /**
+   * Returns the bytes that {@link #writeTo(OutputStream)} would emit for this filter.
+   *
+   * @return a fresh array holding the serialized filter
+   * @throws RuntimeException wrapping any {@link IOException} raised while writing
+   */
+  public byte[] serialize() {
+    try (ByteArrayOutputStream sink = new ByteArrayOutputStream()) {
+      writeTo(sink);
+      final byte[] bytes = sink.toByteArray();
+      return bytes;
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
+    }
+  }
+
@Override
public double expectedFpp() {
throw new UnsupportedOperationException("Not yet implemented");
@@ -72,7 +91,8 @@ public class VeloxBloomFilter extends BloomFilter {
@Override
public boolean putLong(long item) {
- throw new UnsupportedOperationException("Not yet implemented");
+ jni.insertLong(handle, item);
+ return true; // Always true: the native insertLong call is void and reports no "bits changed" flag.
}
@Override
@@ -87,7 +107,18 @@ public class VeloxBloomFilter extends BloomFilter {
@Override
public BloomFilter mergeInPlace(BloomFilter other) throws
IncompatibleMergeException {
- throw new UnsupportedOperationException("Not yet implemented");
+    // Only another Velox-backed filter can be merged; anything else (null included) is rejected.
+    final boolean veloxBacked = other instanceof VeloxBloomFilter;
+    if (!veloxBacked) {
+      throw new IncompatibleMergeException(
+          "Cannot merge Velox bloom-filter with non-Velox bloom-filter");
+    }
+    final VeloxBloomFilter source = (VeloxBloomFilter) other;
+
+    // Both native handles are resolved through one JNI wrapper, so the wrappers must be
+    // compatible with each other before the native merge is attempted.
+    final boolean sameRuntime = jni.isCompatibleWith(source.jni);
+    if (!sameRuntime) {
+      throw new IncompatibleMergeException(
+          "Cannot merge Velox bloom-filters with different Velox runtimes");
+    }
+    jni.mergeFrom(handle, source.handle);
+    return this;
}
@Override
@@ -117,6 +148,7 @@ public class VeloxBloomFilter extends BloomFilter {
@Override
public void writeTo(OutputStream out) throws IOException {
- throw new UnsupportedOperationException("Not yet implemented");
+    // Emit the native serialized form of the filter as-is.
+    out.write(jni.serialize(handle));
}
}
diff --git
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
index a369c8a30..572e2c7ac 100644
---
a/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
+++
b/backends-velox/src/main/java/org/apache/spark/util/sketch/VeloxBloomFilterJniWrapper.java
@@ -36,7 +36,15 @@ public class VeloxBloomFilterJniWrapper implements
RuntimeAware {
return runtime.getHandle();
}
+ public native long empty(int capacity); // Creates a native filter and returns its handle.
+
public native long init(byte[] data);
+ public native void insertLong(long handle, long item); // Adds `item` to the filter behind `handle`.
+
public native boolean mightContainLong(long handle, long item);
+
+ public native void mergeFrom(long handle, long other); // Merges the filter behind `other` into the one behind `handle`.
+
+ public native byte[] serialize(long handle); // Returns the native serialized bytes of the filter.
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
index bbec3ee01..7d15e32b3 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/BloomFilterMightContainJointRewriteRule.scala
@@ -27,10 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
case class BloomFilterMightContainJointRewriteRule(spark: SparkSession)
extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
- if (
- !(GlutenConfig.getConf.enableNativeBloomFilter &&
- GlutenConfig.getConf.enableColumnarHashAgg)
- ) {
+ if (!(GlutenConfig.getConf.enableNativeBloomFilter)) {
return plan
}
val out = plan.transformWithSubqueries {
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
index 720261fd2..da545aa47 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VeloxBloomFilterAggregate.scala
@@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.sketch.BloomFilter
+import org.apache.spark.util.TaskResources
+import org.apache.spark.util.sketch.{BloomFilter, VeloxBloomFilter}
/**
* Velox's bloom-filter implementation uses different algorithms internally
comparing to vanilla
@@ -48,6 +50,15 @@ case class VeloxBloomFilterAggregate(
override def prettyName: String = "velox_bloom_filter_agg"
+  // Lazy on purpose: the expression must not be evaluated while the plan tree is still being
+  // transformed.
+  private lazy val estimatedNumItems: Long = {
+    val requested = estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue
+    val upperBound = SQLConf.get
+      .getConfString("spark.sql.optimizer.runtime.bloomFilter.maxNumItems", "4000000")
+      .toLong
+    // Never size the filter beyond the configured maximum.
+    Math.min(requested, upperBound)
+  }
+
override def first: Expression = child
override def second: Expression = estimatedNumItemsExpression
@@ -70,26 +81,48 @@ case class VeloxBloomFilterAggregate(
numBitsExpression = newNumBitsExpression)
}
- override def createAggregationBuffer(): BloomFilter = throw new
UnsupportedOperationException()
+  // Creates the per-task aggregation buffer: an empty native filter sized from the estimated
+  // item count. Creation is refused when not running inside a Spark task.
+  override def createAggregationBuffer(): BloomFilter = {
+    if (TaskResources.inSparkTask()) {
+      VeloxBloomFilter.empty(Math.toIntExact(estimatedNumItems))
+    } else {
+      throw new UnsupportedOperationException(
+        "velox_bloom_filter_agg is not evaluable on Driver")
+    }
+  }
- override def update(buffer: BloomFilter, input: InternalRow): BloomFilter =
- throw new UnsupportedOperationException()
+  // Folds one input row into the buffer and hands the same buffer back.
+  override def update(buffer: BloomFilter, input: InternalRow): BloomFilter = {
+    assert(buffer.isInstanceOf[VeloxBloomFilter])
+    val item = child.eval(input)
+    // Null inputs contribute nothing to the filter.
+    if (item != null) {
+      buffer.putLong(item.asInstanceOf[Long])
+    }
+    buffer
+  }
- override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter =
- throw new UnsupportedOperationException()
+  // Merges `input` into `buffer` in place; both sides must be Velox-backed filters.
+  override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter = {
+    assert(buffer.isInstanceOf[VeloxBloomFilter])
+    assert(input.isInstanceOf[VeloxBloomFilter])
+    val target = buffer.asInstanceOf[VeloxBloomFilter]
+    target.mergeInPlace(input)
+  }
- override def eval(buffer: BloomFilter): Any = throw new
UnsupportedOperationException()
+ override def eval(buffer: BloomFilter): Any = {
+ assert(buffer.isInstanceOf[VeloxBloomFilter])
+ serialize(buffer) // The aggregate's result is the filter's serialized byte array.
+ }
- override def serialize(buffer: BloomFilter): Array[Byte] =
- throw new UnsupportedOperationException()
+ override def serialize(buffer: BloomFilter): Array[Byte] = {
+ assert(buffer.isInstanceOf[VeloxBloomFilter])
+ buffer.asInstanceOf[VeloxBloomFilter].serialize() // Delegates to the Velox filter's own byte encoding.
+ }
- override def deserialize(storageFormat: Array[Byte]): BloomFilter =
- throw new UnsupportedOperationException()
+ override def deserialize(bytes: Array[Byte]): BloomFilter = {
+ VeloxBloomFilter.readFrom(bytes) // Rebuilds a Velox-backed filter from previously serialized bytes.
+ }
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int):
ImperativeAggregate =
- throw new UnsupportedOperationException()
+ override def withNewMutableAggBufferOffset(newOffset: Int):
VeloxBloomFilterAggregate =
+ copy(mutableAggBufferOffset = newOffset) // Copy of this aggregate with only the mutable-buffer offset replaced.
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int):
ImperativeAggregate =
- throw new UnsupportedOperationException()
+ override def withNewInputAggBufferOffset(newOffset: Int):
VeloxBloomFilterAggregate =
+ copy(inputAggBufferOffset = newOffset) // Copy of this aggregate with only the input-buffer offset replaced.
}
diff --git
a/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
b/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
index 81213fab2..ce439b864 100644
---
a/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
+++
b/backends-velox/src/test/java/org/apache/spark/util/sketch/VeloxBloomFilterTest.java
@@ -25,6 +25,7 @@ import org.apache.spark.util.TaskResources$;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
import java.nio.ByteBuffer;
@@ -39,16 +40,149 @@ public class VeloxBloomFilterTest {
@Test
public void testEmpty() {
+ TaskResources$.MODULE$.runUnsafe(
+ () -> {
+ final BloomFilter filter = VeloxBloomFilter.empty(10000);
+ for (int i = 0; i < 1000; i++) {
+ Assert.assertFalse(filter.mightContainLong(i)); // Nothing was inserted, so every lookup must miss.
+ }
+ return null;
+ });
+ }
+
+ @Test
+ public void testMalformed() {
final ByteBuffer buf = ByteBuffer.allocate(5);
buf.put((byte) 1); // kBloomFilterV1
buf.putInt(0); // size
TaskResources$.MODULE$.runUnsafe(
() -> {
final BloomFilter filter = VeloxBloomFilter.readFrom(buf.array());
- for (int i = 0; i < 1000; i++) {
+ Assert.assertThrows(
+ "Bloom-filter is not initialized", // JUnit failure message; not matched against the thrown exception's text.
+ RuntimeException.class,
+ new ThrowingRunnable() {
+ @Override
+ public void run() throws Throwable {
+ filter.mightContainLong(0); // Reading the zero-size filter succeeds above; the first lookup is what must fail.
+ }
+ });
+ return null;
+ });
+ }
+
+ @Test
+ public void testSanity() {
+ TaskResources$.MODULE$.runUnsafe(
+ () -> {
+ final BloomFilter filter = VeloxBloomFilter.empty(10000);
+ final int numItems = 2000;
+ final int halfNumItems = numItems / 2; // Inserted values span [-halfNumItems, halfNumItems).
+ for (int i = -halfNumItems; i < halfNumItems; i++) {
Assert.assertFalse(filter.mightContainLong(i));
}
+ for (int i = -halfNumItems; i < halfNumItems; i++) {
+ filter.putLong(i);
+ Assert.assertTrue(filter.mightContainLong(i)); // An item must be visible immediately after insertion.
+ }
+ for (int i = -halfNumItems; i < halfNumItems; i++) {
+ Assert.assertTrue(filter.mightContainLong(i)); // Later insertions must not hide earlier items.
+ }
+
+ // Check false positives.
+ checkFalsePositives(filter, halfNumItems);
+
return null;
});
}
+
+ @Test
+ public void testMerge() {
+ TaskResources$.MODULE$.runUnsafe(
+ () -> {
+ final BloomFilter filter1 = VeloxBloomFilter.empty(10000);
+ final int start1 = 0;
+ final int end1 = 2000; // filter1 receives [0, 2000).
+ for (int i = start1; i < end1; i++) {
+ Assert.assertFalse(filter1.mightContainLong(i));
+ }
+ for (int i = start1; i < end1; i++) {
+ filter1.putLong(i);
+ Assert.assertTrue(filter1.mightContainLong(i));
+ }
+ for (int i = start1; i < end1; i++) {
+ Assert.assertTrue(filter1.mightContainLong(i));
+ }
+
+ final BloomFilter filter2 = VeloxBloomFilter.empty(10000);
+ final int start2 = 1000;
+ final int end2 = 3000; // filter2 receives [1000, 3000), overlapping filter1 on [1000, 2000).
+ for (int i = start2; i < end2; i++) {
+ Assert.assertFalse(filter2.mightContainLong(i));
+ }
+ for (int i = start2; i < end2; i++) {
+ filter2.putLong(i);
+ Assert.assertTrue(filter2.mightContainLong(i));
+ }
+ for (int i = start2; i < end2; i++) {
+ Assert.assertTrue(filter2.mightContainLong(i));
+ }
+
+ try {
+ filter1.mergeInPlace(filter2);
+ } catch (IncompatibleMergeException e) {
+ throw new RuntimeException(e); // Rethrown unchecked so the failure surfaces out of the lambda.
+ }
+
+ for (int i = start1; i < end2; i++) {
+ Assert.assertTrue(filter1.mightContainLong(i)); // After the merge filter1 must cover the union [0, 3000).
+ }
+
+ // Check false positives.
+ checkFalsePositives(filter1, end2);
+
+ return null;
+ });
+ }
+
+  @Test
+  public void testSerde() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final VeloxBloomFilter original = VeloxBloomFilter.empty(10000);
+          for (int item = 0; item < 1000; item++) {
+            original.putLong(item);
+          }
+          final byte[] firstPass = original.serialize();
+
+          // Round-trip: reading the bytes back and serializing again must reproduce them exactly.
+          final VeloxBloomFilter restored = VeloxBloomFilter.readFrom(firstPass);
+          final byte[] secondPass = restored.serialize();
+
+          Assert.assertArrayEquals(secondPass, firstPass);
+          return null;
+        });
+  }
+
+  /**
+   * Probes values the caller never inserted, on both sides of zero, and checks that the filter
+   * produces some false positives but does not answer "maybe" for everything.
+   *
+   * @param filter the populated filter under test
+   * @param start exclusive upper bound of the inserted values; the caller must not have inserted
+   *     any value {@code >= start} or {@code < -start}
+   */
+  private static void checkFalsePositives(BloomFilter filter, int start) {
+    final int attemptCount = 5000000;
+
+    int falsePositives = 0;
+    int negativeFalsePositives = 0;
+
+    for (int i = start; i < start + attemptCount; i++) {
+      if (filter.mightContainLong(i)) {
+        falsePositives++;
+      }
+      // Probe -i - 1 rather than -i: a caller that inserted [-start, start) really did insert
+      // -start, and counting that genuine member would make the "> 0" assertion below vacuous.
+      if (filter.mightContainLong(-(long) i - 1)) {
+        negativeFalsePositives++;
+      }
+    }
+
+    Assert.assertTrue(falsePositives > 0);
+    Assert.assertTrue(falsePositives < attemptCount);
+    Assert.assertTrue(negativeFalsePositives > 0);
+    Assert.assertTrue(negativeFalsePositives < attemptCount);
+  }
}
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index 1d8c09fa4..a3c51f64a 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -141,12 +141,24 @@
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail
JNI_METHOD_END(nullptr)
}
+JNIEXPORT jlong JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_empty( // NOLINT
+ JNIEnv* env,
+ jobject wrapper,
+ jint capacity) {
+ JNI_METHOD_START
+ auto ctx = gluten::getRuntime(env, wrapper);
+ auto filter =
std::make_shared<velox::BloomFilter<std::allocator<uint64_t>>>();
+ filter->reset(capacity); // Sizes the filter from the requested capacity.
+ GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized"); // Fail here if reset() left the filter unset.
+ return ctx->objectStore()->save(filter); // The object-store handle is what the Java side keeps.
+ JNI_METHOD_END(gluten::kInvalidResourceHandle)
+}
+
JNIEXPORT jlong JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_init( // NOLINT
JNIEnv* env,
jobject wrapper,
jbyteArray data) {
JNI_METHOD_START
-
auto len = env->GetArrayLength(data);
auto safeArray = gluten::getByteArrayElementsSafe(env, data);
auto ctx = gluten::getRuntime(env, wrapper);
@@ -157,6 +169,19 @@ JNIEXPORT jlong JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWra
JNI_METHOD_END(gluten::kInvalidResourceHandle)
}
+JNIEXPORT void JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_insertLong( //
NOLINT
+ JNIEnv* env,
+ jobject wrapper,
+ jlong handle,
+ jlong item) {
+ JNI_METHOD_START
+ auto ctx = gluten::getRuntime(env, wrapper);
+ auto filter =
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
+ GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized"); // Refuse to insert into an unset filter.
+ filter->insert(folly::hasher<int64_t>()(item)); // Same hasher as mightContainLong, so lookups see inserted items.
+ JNI_METHOD_END()
+}
+
JNIEXPORT jboolean JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_mightContainLong(
// NOLINT
JNIEnv* env,
jobject wrapper,
@@ -165,11 +190,55 @@ JNIEXPORT jboolean JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJni
JNI_METHOD_START
auto ctx = gluten::getRuntime(env, wrapper);
auto filter =
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
- bool out = filter->isSet() &&
filter->mayContain(folly::hasher<int64_t>()(item));
+ GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized");
+ bool out = filter->mayContain(folly::hasher<int64_t>()(item));
return out;
JNI_METHOD_END(false)
}
+namespace {
+// Serializes `bf` into a newly allocated byte vector.
+//
+// The returned vector holds exactly bf->serializedSize() elements, so both size() and data()
+// describe the payload.
+std::vector<char> serialize(BloomFilter<std::allocator<uint64_t>>* bf) {
+  const uint32_t size = bf->serializedSize();
+  // Construct the vector with `size` elements instead of calling reserve(): reserve() only
+  // allocates storage and leaves size() at 0, so writing through data() would touch memory
+  // outside the vector's elements and the caller would receive a vector that claims to be empty.
+  std::vector<char> buffer(size);
+  bf->serialize(buffer.data());
+  return buffer;
+}
+} // namespace
+
+JNIEXPORT void JNICALL
Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_mergeFrom( //
NOLINT
+ JNIEnv* env,
+ jobject wrapper,
+ jlong handle,
+ jlong other) {
+ JNI_METHOD_START
+ auto ctx = gluten::getRuntime(env, wrapper);
+ auto to =
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
+ auto from =
ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(other);
+ GLUTEN_CHECK(to->isSet(), "Bloom-filter is not initialized");
+ GLUTEN_CHECK(from->isSet(), "Bloom-filter is not initialized");
+ std::vector<char> serialized = serialize(from.get()); // merge() below takes serialized bytes, so encode `from` first.
+ to->merge(serialized.data()); // Folds the source filter into `to` in place.
+ JNI_METHOD_END()
+}
+
+// Returns the serialized bytes of the bloom filter stored under `handle` as a new Java byte[].
+// Returns nullptr when the Java array cannot be allocated, leaving the pending Java exception
+// for the caller to see.
+JNIEXPORT jbyteArray JNICALL Java_org_apache_spark_util_sketch_VeloxBloomFilterJniWrapper_serialize( // NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlong handle) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto filter =
+      ctx->objectStore()->retrieve<velox::BloomFilter<std::allocator<uint64_t>>>(handle);
+  GLUTEN_CHECK(filter->isSet(), "Bloom-filter is not initialized");
+  std::vector<char> buffer = serialize(filter.get());
+  // Take the length from the filter itself. vector::capacity() is only guaranteed to be at
+  // least the reserved amount, so using it could copy bytes past the serialized payload.
+  const auto size = static_cast<jsize>(filter->serializedSize());
+  jbyteArray out = env->NewByteArray(size);
+  if (out == nullptr) {
+    // Allocation failed and a Java exception is already pending; do not touch the array.
+    return nullptr;
+  }
+  env->SetByteArrayRegion(out, 0, size, reinterpret_cast<jbyte*>(buffer.data()));
+  return out;
+  JNI_METHOD_END(nullptr)
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
b/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
index 8a3f15b40..ca96ace64 100644
--- a/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
+++ b/gluten-data/src/main/java/org/apache/gluten/exec/RuntimeAware.java
@@ -21,5 +21,9 @@ package org.apache.gluten.exec;
* for further native processing.
*/
public interface RuntimeAware {
+ default boolean isCompatibleWith(RuntimeAware other) { // Compatible means both objects report the same handle().
+ return handle() == other.handle();
+ }
+
long handle();
}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index eb691e9e7..ddd4cf1d4 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -115,7 +115,7 @@ class GlutenBloomFilterAggregateQuerySuite
}
}
- testGluten("Test bloom_filter_agg fallback with might_contain offloaded") {
+ testGluten("Test bloom_filter_agg agg fallback") {
val table = "bloom_filter_test"
val numEstimatedItems = 5000000L
val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index 89302fa37..0c75db830 100644
---
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -68,7 +68,7 @@ class GlutenBloomFilterAggregateQuerySuite
Row(null))
}
- testGluten("Test bloom_filter_agg fallback") {
+ testGluten("Test bloom_filter_agg filter fallback") {
val table = "bloom_filter_test"
val numEstimatedItems = 5000000L
val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
@@ -113,4 +113,37 @@ class GlutenBloomFilterAggregateQuerySuite
}
}
}
+
+ testGluten("Test bloom_filter_agg agg fallback") {
+ val table = "bloom_filter_test"
+ val numEstimatedItems = 5000000L
+ val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
+ val sqlString = s"""
+ |SELECT col positive_membership_test
+ |FROM $table
+ |WHERE might_contain(
+ | (SELECT bloom_filter_agg(col,
+ | cast($numEstimatedItems as long),
+ | cast($numBits as long))
+ | FROM $table), col)
+ """.stripMargin
+
+ withTempView(table) {
+ (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
+ .toDF("col")
+ .createOrReplaceTempView(table)
+ withSQLConf(
+ GlutenConfig.COLUMNAR_HASHAGG_ENABLED.key -> "false" // Turn columnar hash aggregation off for this query.
+ ) {
+ val df = spark.sql(sqlString)
+ df.collect // Only needs to complete without throwing; the rows are not inspected.
+ assert(
+ collectWithSubqueries(df.queryExecution.executedPlan) {
+ case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
+ }.isEmpty, // No hash-aggregate transformer may remain anywhere in the plan, subqueries included.
+ df.queryExecution.executedPlan
+ )
+ }
+ }
+ }
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index 730ce6062..ddd4cf1d4 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -69,7 +69,7 @@ class GlutenBloomFilterAggregateQuerySuite
Row(null))
}
- testGluten("Test bloom_filter_agg fallback") {
+ testGluten("Test bloom_filter_agg filter fallback") {
val table = "bloom_filter_test"
val numEstimatedItems = 5000000L
val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
@@ -115,7 +115,7 @@ class GlutenBloomFilterAggregateQuerySuite
}
}
- testGluten("Test bloom_filter_agg fallback with might_contain offloaded") {
+ testGluten("Test bloom_filter_agg agg fallback") {
val table = "bloom_filter_test"
val numEstimatedItems = 5000000L
val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]