This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4339e0c0e4d [SPARK-45707][SQL] Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg`
4339e0c0e4d is described below
commit 4339e0c0e4d7e502ae6cafa90444cd153017cb1a
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Oct 29 22:42:01 2023 -0700
[SPARK-45707][SQL] Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg`
### What changes were proposed in this pull request?
Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg`
### Why are the changes needed?
to make it consistent with sql functions
### Does this PR introduce _any_ user-facing change?
better error messages: `IllegalArgumentException` -> `AnalysisException`
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43560 from zhengruifeng/sql_reimpl_stat_countMinSketch.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../apache/spark/sql/DataFrameStatFunctions.scala | 44 +++++++---------------
.../org/apache/spark/sql/DataFrameStatSuite.scala | 2 +-
2 files changed, 14 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index de3b100cd6a..f3690773f6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,9 +22,8 @@ import java.{lang => jl, util => ju}
import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, CountMinSketchAgg}
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
@@ -483,7 +482,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
- countMinSketch(col, CountMinSketch.create(depth, width, seed))
+ val eps = 2.0 / width
+ val confidence = 1 - 1 / Math.pow(2, depth)
+ countMinSketch(col, eps, confidence, seed)
}
/**
@@ -497,35 +498,16 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
- countMinSketch(col, CountMinSketch.create(eps, confidence, seed))
- }
-
- private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = {
- val singleCol = df.select(col)
- val colType = singleCol.schema.head.dataType
-
- val updater: (CountMinSketch, InternalRow) => Unit = colType match {
-      // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
- // instead of `addString` to avoid unnecessary conversion.
-      case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
- case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
- case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
- case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
- case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
- case _ =>
- throw new IllegalArgumentException(
- s"Count-min Sketch only supports string type and integral types, " +
- s"and does not support type $colType."
- )
- }
-
- singleCol.queryExecution.toRdd.aggregate(zero)(
- (sketch: CountMinSketch, row: InternalRow) => {
- updater(sketch, row)
- sketch
- },
- (sketch1, sketch2) => sketch1.mergeInPlace(sketch2)
+ val countMinSketchAgg = new CountMinSketchAgg(
+ col.expr,
+ Literal(eps, DoubleType),
+ Literal(confidence, DoubleType),
+ Literal(seed, IntegerType)
)
+ val bytes = df.select(
+ Column(countMinSketchAgg.toAggregateExpression(false))
+ ).head().getAs[Array[Byte]](0)
+ countMinSketchAgg.deserialize(bytes)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 1dece5c8285..430e3622102 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -508,7 +508,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
assert(sketch4.relativeError() === 0.001 +- 1e-4)
assert(sketch4.confidence() === 0.99 +- 5e-3)
- intercept[IllegalArgumentException] {
+ intercept[AnalysisException] {
df.select($"id" cast DoubleType as "id")
.stat
.countMinSketch($"id", depth = 10, width = 20, seed = 42)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]