This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 922844fff65 [SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with 'BloomFilterAggregate' expression
922844fff65 is described below
commit 922844fff65ac38fd93bd0c914dcc7e5cf879996
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 17 10:11:36 2023 -0500
[SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with 'BloomFilterAggregate' expression
### What changes were proposed in this pull request?
Simplify the 'DataFrameStatFunctions.bloomFilter' function with the 'BloomFilterAggregate' expression.
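In sketch form, the rewrite replaces a hand-rolled RDD `treeAggregate` with a single `BloomFilterAggregate` expression whose serialized result is deserialized back into a `BloomFilter` (a paraphrase of the patch below; `df`, `col`, `expectedNumItems`, and `numBits` are the names used in the patched method):

```scala
// Paraphrase of the new bloomFilter(col, expectedNumItems, numBits) body.
val bloomFilterAgg = new BloomFilterAggregate(
  col.expr,                             // column fed into the filter
  Literal(expectedNumItems, LongType),
  Literal(numBits, LongType))
// The aggregate yields a single row holding the serialized filter bytes.
val bytes = df.select(Column(bloomFilterAgg.toAggregateExpression(false)))
  .head().getAs[Array[Byte]](0)
bloomFilterAgg.deserialize(bytes)       // back to a usable BloomFilter
```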
### Why are the changes needed?
The existing implementation was based on RDD operations; it can be simplified with DataFrame operations.
### Does this PR introduce _any_ user-facing change?
When the input parameters or data types are invalid, an `AnalysisException` is now thrown instead of an `IllegalArgumentException`.
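For example, building a filter over an unsupported column type now fails during analysis (a minimal repro sketch; the `spark` session and the column are assumptions for illustration, not part of the patch):

```scala
// Hypothetical repro: DoubleType is not a supported bloom-filter input.
val df = spark.range(100).selectExpr("cast(id as double) as d")
df.stat.bloomFilter("d", expectedNumItems = 100L, fpp = 0.03)
// before this patch: java.lang.IllegalArgumentException from buildBloomFilter
// after this patch:  org.apache.spark.sql.AnalysisException from BloomFilterAggregate
```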
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43391 from zhengruifeng/sql_reimpl_stat_bloomFilter.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
---
.../apache/spark/sql/DataFrameStatFunctions.scala | 68 +++++-----------------
1 file changed, 14 insertions(+), 54 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9d4f83c53a3..de3b100cd6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -23,6 +23,8 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
@@ -535,7 +537,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
- buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
+ bloomFilter(Column(colName), expectedNumItems, fpp)
}
/**
@@ -547,7 +549,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
- buildBloomFilter(col, expectedNumItems, -1L, fpp)
+ val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+ bloomFilter(col, expectedNumItems, numBits)
}
/**
@@ -559,7 +562,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
- buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
+ bloomFilter(Column(colName), expectedNumItems, numBits)
}
/**
@@ -571,57 +574,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
- buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
- }
-
- private def buildBloomFilter(col: Column, expectedNumItems: Long,
- numBits: Long,
- fpp: Double): BloomFilter = {
- val singleCol = df.select(col)
- val colType = singleCol.schema.head.dataType
-
- require(colType == StringType || colType.isInstanceOf[IntegralType],
- s"Bloom filter only supports string type and integral types, but got
$colType.")
-
- val updater: (BloomFilter, InternalRow) => Unit = colType match {
- // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
- // instead of `putString` to avoid unnecessary conversion.
- case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
- case ByteType => (filter, row) => filter.putLong(row.getByte(0))
- case ShortType => (filter, row) => filter.putLong(row.getShort(0))
- case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
- case LongType => (filter, row) => filter.putLong(row.getLong(0))
- case _ =>
- throw new IllegalArgumentException(
- s"Bloom filter only supports string type and integral types, " +
- s"and does not support type $colType."
- )
- }
-
- singleCol.queryExecution.toRdd.treeAggregate(null.asInstanceOf[BloomFilter])(
- (filter: BloomFilter, row: InternalRow) => {
- val theFilter =
- if (filter == null) {
- if (fpp.isNaN) {
- BloomFilter.create(expectedNumItems, numBits)
- } else {
- BloomFilter.create(expectedNumItems, fpp)
- }
- } else {
- filter
- }
- updater(theFilter, row)
- theFilter
- },
- (filter1, filter2) => {
- if (filter1 == null) {
- filter2
- } else if (filter2 == null) {
- filter1
- } else {
- filter1.mergeInPlace(filter2)
- }
- }
+ val bloomFilterAgg = new BloomFilterAggregate(
+ col.expr,
+ Literal(expectedNumItems, LongType),
+ Literal(numBits, LongType)
+ )
+ val bytes = df.select(
+ Column(bloomFilterAgg.toAggregateExpression(false))
+ ).head().getAs[Array[Byte]](0)
+ bloomFilterAgg.deserialize(bytes)
}
}
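A note on the `BloomFilter.optimalNumOfBits(expectedNumItems, fpp)` call introduced in the second hunk: it is understood to apply the standard bloom-filter sizing formula, sketched below (the rounding behavior of Spark's actual implementation is an assumption here, so treat this as an approximation):

```scala
// Standard sizing formula: numBits = -n * ln(fpp) / (ln 2)^2.
def approxOptimalNumOfBits(expectedNumItems: Long, fpp: Double): Long =
  math.ceil(-expectedNumItems * math.log(fpp) / (math.log(2) * math.log(2))).toLong

// e.g. 1,000,000 items at fpp = 0.03 needs roughly 7.3 million bits (~0.9 MB).
println(approxOptimalNumOfBits(1000000L, 0.03))
```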
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]