This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4339e0c0e4d [SPARK-45707][SQL] Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg` 4339e0c0e4d is described below commit 4339e0c0e4d7e502ae6cafa90444cd153017cb1a Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Sun Oct 29 22:42:01 2023 -0700 [SPARK-45707][SQL] Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg` ### What changes were proposed in this pull request? Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg` ### Why are the changes needed? To make it consistent with the SQL functions. ### Does this PR introduce _any_ user-facing change? Yes, better error messages: `IllegalArgumentException` -> `AnalysisException` ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No Closes #43560 from zhengruifeng/sql_reimpl_stat_countMinSketch. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../apache/spark/sql/DataFrameStatFunctions.scala | 44 +++++++--------------- .../org/apache/spark/sql/DataFrameStatSuite.scala | 2 +- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index de3b100cd6a..f3690773f6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,9 +22,8 @@ import java.{lang => jl, util => ju} import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, 
CountMinSketchAgg} import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ @@ -483,7 +482,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @since 2.0.0 */ def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(col, CountMinSketch.create(depth, width, seed)) + val eps = 2.0 / width + val confidence = 1 - 1 / Math.pow(2, depth) + countMinSketch(col, eps, confidence, seed) } /** @@ -497,35 +498,16 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @since 2.0.0 */ def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { - countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) - } - - private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { - val singleCol = df.select(col) - val colType = singleCol.schema.head.dataType - - val updater: (CountMinSketch, InternalRow) => Unit = colType match { - // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` - // instead of `addString` to avoid unnecessary conversion. - case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes) - case ByteType => (sketch, row) => sketch.addLong(row.getByte(0)) - case ShortType => (sketch, row) => sketch.addLong(row.getShort(0)) - case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0)) - case LongType => (sketch, row) => sketch.addLong(row.getLong(0)) - case _ => - throw new IllegalArgumentException( - s"Count-min Sketch only supports string type and integral types, " + - s"and does not support type $colType." 
- ) - } - - singleCol.queryExecution.toRdd.aggregate(zero)( - (sketch: CountMinSketch, row: InternalRow) => { - updater(sketch, row) - sketch - }, - (sketch1, sketch2) => sketch1.mergeInPlace(sketch2) + val countMinSketchAgg = new CountMinSketchAgg( + col.expr, + Literal(eps, DoubleType), + Literal(confidence, DoubleType), + Literal(seed, IntegerType) ) + val bytes = df.select( + Column(countMinSketchAgg.toAggregateExpression(false)) + ).head().getAs[Array[Byte]](0) + countMinSketchAgg.deserialize(bytes) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 1dece5c8285..430e3622102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -508,7 +508,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { assert(sketch4.relativeError() === 0.001 +- 1e04) assert(sketch4.confidence() === 0.99 +- 5e-3) - intercept[IllegalArgumentException] { + intercept[AnalysisException] { df.select($"id" cast DoubleType as "id") .stat .countMinSketch($"id", depth = 10, width = 20, seed = 42) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org