This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new a4bd7583ce6 [SPARK-42664][CONNECT] Support `bloomFilter` function for `DataFrameStatFunctions` a4bd7583ce6 is described below commit a4bd7583ce6d680f0091519007e48894d594b9f6 Author: yangjie01 <yangji...@baidu.com> AuthorDate: Tue Aug 15 19:12:03 2023 +0200 [SPARK-42664][CONNECT] Support `bloomFilter` function for `DataFrameStatFunctions` ### What changes were proposed in this pull request? This is pr using `BloomFilterAggregate` to implement `bloomFilter` function for `DataFrameStatFunctions`. ### Why are the changes needed? Add Spark connect jvm client api coverage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Add new test - Manually check Scala 2.13 Closes #42414 from LuciferYang/SPARK-42664-backup. Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit b9f11143d058ad05dcda2138133471c9500c8b92) Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../org/apache/spark/util/sketch/BloomFilter.java | 4 +- .../apache/spark/sql/DataFrameStatFunctions.scala | 88 +++++++++++++++++++++- .../spark/sql/ClientDataFrameStatSuite.scala | 87 +++++++++++++++++++++ .../CheckConnectJvmClientCompatibility.scala | 3 - .../sql/connect/planner/SparkConnectPlanner.scala | 31 ++++++++ .../aggregate/BloomFilterAggregate.scala | 43 ++++++++++- 6 files changed, 248 insertions(+), 8 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 5c01841e501..f3c2b05e7af 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -199,9 +199,9 @@ public abstract class BloomFilter { * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the 
formula. * * @param n expected insertions (must be positive) - * @param p false positive rate (must be 0 < p < 1) + * @param p false positive rate (must be 0 &lt; p &lt; 1) */ - private static long optimalNumOfBits(long n, double p) { + public static long optimalNumOfBits(long n, double p) { return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 0d4372b8738..4d35b4e8767 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} +import java.io.ByteArrayInputStream import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. @@ -584,6 +586,90 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } CountMinSketch.readFrom(ds.head()) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter.
+ * @since 3.5.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.5.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. + * @since 3.5.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. 
+ * @since 3.5.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, expectedNumItems, numBits, Double.NaN) + } + + private def buildBloomFilter( + col: Column, + expectedNumItems: Long, + numBits: Long, + fpp: Double): BloomFilter = { + def numBitsValue: Long = if (!fpp.isNaN) { + BloomFilter.optimalNumOfBits(expectedNumItems, fpp) + } else { + numBits + } + + if (fpp <= 0d || fpp >= 1d) { + throw new SparkException("False positive probability must be within range (0.0, 1.0)") + } + val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue)) + + val ds = sparkSession.newDataset(BinaryEncoder) { builder => + builder.getProjectBuilder + .setInput(root) + .addExpressions(agg.expr) + } + BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) + } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 62ff21332f1..7035dc99148 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -176,4 +176,91 @@ class ClientDataFrameStatSuite extends RemoteSparkSession { assert(sketch.relativeError() === 0.001) assert(sketch.confidence() === 0.99 +- 5e-3) } + + test("Bloom filter -- Long Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toLong) + val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767).map(_.toLong) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- Int Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997) + val df = data.toDF("id") + val 
negativeValues = Seq(-11, 1021, 32767) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- Short Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toShort) + val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767).map(_.toShort) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- Byte Column") { + val session = spark + import session.implicits._ + val data = Seq(-32, -5, 1, 17, 39, 43, 101, 127).map(_.toByte) + val df = data.toDF("id") + val negativeValues = Seq(-101, 55, 113).map(_.toByte) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- String Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toString) + val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767).map(_.toString) + checkBloomFilter(data, negativeValues, df) + } + + private def checkBloomFilter( + data: Seq[Any], + notContainValues: Seq[Any], + df: DataFrame): Unit = { + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(data.forall(filter1.mightContain)) + assert(notContainValues.forall(n => !filter1.mightContain(n))) + val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter2.bitSize() == 64 * 5) + assert(data.forall(filter2.mightContain)) + assert(notContainValues.forall(n => !filter2.mightContain(n))) + } + + test("Bloom filter -- Wrong dataType Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toDouble) + val message = intercept[AnalysisException] { + data.toDF("id").stat.bloomFilter("id", 1000, 0.03) + }.getMessage + assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE")) + } + + test("Bloom filter test invalid inputs") { + val df = spark.range(1000).toDF("id") + val message1 = intercept[SparkException] { + 
df.stat.bloomFilter("id", -1000, 100) + }.getMessage + assert(message1.contains("Expected insertions must be positive")) + + val message2 = intercept[SparkException] { + df.stat.bloomFilter("id", 1000, -100) + }.getMessage + assert(message2.contains("Number of bits must be positive")) + + val message3 = intercept[SparkException] { + df.stat.bloomFilter("id", 1000, -1.0) + }.getMessage + assert(message3.contains("False positive probability must be within range (0.0, 1.0)")) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 7356d4daa79..8f226eb2f7e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -163,9 +163,6 @@ object CheckConnectJvmClientCompatibility { // DataFrameNaFunctions ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"), - // DataFrameStatFunctions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"), - // Dataset ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.Dataset$" // private[sql] diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index dc77c52ef46..f3e87b7067d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, 
LocalTempView, Mu import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical @@ -1731,6 +1732,36 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { val ignoreNulls = extractBoolean(children(3), "ignoreNulls") Some(Lead(children.head, children(1), children(2), ignoreNulls)) + case "bloom_filter_agg" if fun.getArgumentsCount == 3 => + // [col, expectedNumItems: Long, numBits: Long] + val children = fun.getArgumentsList.asScala.map(transformExpression) + + // Check expectedNumItems is LongType and value greater than 0L + val expectedNumItemsExpr = children(1) + val expectedNumItems = expectedNumItemsExpr match { + case Literal(l: Long, LongType) => l + case _ => + throw InvalidPlanInput("Expected insertions must be long literal.") + } + if (expectedNumItems <= 0L) { + throw InvalidPlanInput("Expected insertions must be positive.") + } + + val numBitsExpr = children(2) + // Check numBits is LongType and value greater than 0L + numBitsExpr match { + case Literal(numBits: Long, LongType) => + if (numBits <= 0L) { + throw InvalidPlanInput("Number of bits must be positive.") + } + case _ => + throw InvalidPlanInput("Number of bits must be long literal.") + } + + Some( + new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr) + .toAggregateExpression()) + case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => val children = fun.getArgumentsList.asScala.map(transformExpression) val timeCol = 
children.head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 980785e764c..7cba462ce2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.BloomFilter /** @@ -78,7 +79,7 @@ case class BloomFilterAggregate( "exprName" -> "estimatedNumItems or numBits" ) ) - case (LongType, LongType, LongType) => + case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) => if (!estimatedNumItemsExpression.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -150,6 +151,15 @@ case class BloomFilterAggregate( Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) + // Mark as lazy so that `updater` is not evaluated during tree transformation. 
+ private lazy val updater: BloomFilterUpdater = child.dataType match { + case LongType => LongUpdater + case IntegerType => IntUpdater + case ShortType => ShortUpdater + case ByteType => ByteUpdater + case StringType => BinaryUpdater + } + override def first: Expression = child override def second: Expression = estimatedNumItemsExpression @@ -174,7 +184,7 @@ case class BloomFilterAggregate( if (value == null) { return buffer } - buffer.putLong(value.asInstanceOf[Long]) + updater.update(buffer, value) buffer } @@ -224,3 +234,32 @@ object BloomFilterAggregate { bloomFilter } } + +private trait BloomFilterUpdater { + def update(bf: BloomFilter, v: Any): Boolean +} + +private object LongUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Long]) +} + +private object IntUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Int]) +} + +private object ShortUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Short]) +} + +private object ByteUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Byte]) +} + +private object BinaryUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putBinary(v.asInstanceOf[UTF8String].getBytes) +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org