RyanBerti commented on code in PR #40615:
URL: https://github.com/apache/spark/pull/40615#discussion_r1162128291
##########
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala:
##########
@@ -1545,6 +1545,158 @@ class DataFrameAggregateSuite extends QueryTest
)
checkAnswer(res, Row(Array(1), Array(1)))
}
+
+ test("SPARK-16484: hllsketch_estimate positive tests") {
+ // Fixtures: df1 has 7 rows with 4 distinct values (a, b, c, d);
+ // df2 has 8 rows with 5 distinct values (a, c, d, e, f).
+ // Across both inputs there are 15 rows and 6 distinct values (a through f).
+ val df1 = Seq(
+ (1, "a"), (1, "a"), (1, "a"),
+ (1, "b"),
+ (1, "c"), (1, "c"),
+ (1, "d")
+ ).toDF("id", "value")
+ df1.createOrReplaceTempView("df1")
+
+ val df2 = Seq(
+ (1, "a"),
+ (1, "c"),
+ (1, "d"), (1, "d"), (1, "d"),
+ (1, "e"), (1, "e"),
+ (1, "f")
+ ).toDF("id", "value")
+ df2.createOrReplaceTempView("df2")
+
+ // empty column test: a column that is entirely null is expected to estimate 0
+ val res0 = df1.withColumn("empty_col", lit(null)).groupBy("id")
+ .agg(
+ hllsketch_estimate("empty_col").as("distinct_count")
+ )
+ checkAnswer(res0, Row(1, 0))
+
+ // first test hllsketch_estimate via dataframe + sql, with and without configs
+ // (the configured variant passes lgConfigK = 20 and tgtHllType = 'HLL_8');
+ // both variants are expected to report df1's 4 distinct values out of 7 rows
+ val res1 = df1.groupBy("id")
+ .agg(
+ count("value").as("count"),
+ hllsketch_estimate("value").as("distinct_count_1"),
+ hllsketch_estimate("value", 20, "HLL_8").as("distinct_count_2")
+ )
+ checkAnswer(res1, Row(1, 7, 4, 4))
+
+ val res2 = sql(
+ """select
+ | id,
+ | count(value) as count,
+ | hllsketch_estimate(value) as distinct_count_1,
+ | hllsketch_estimate(value, 20, 'HLL_8') as distinct_count_2
+ |from df1
+ |group by 1
+ |""".stripMargin)
+ checkAnswer(res2, Row(1, 7, 4, 4))
+
+ // next test hllsketch_binary via dataframe + sql, with and without configs;
+ // df3 builds per-group sketches over df1 through the DataFrame API
+ // (hllsketch_1 and hllsketch_3 use the default config, hllsketch_2 uses 20 / 'HLL_8')
+ val df3 = df1.groupBy("id")
+ .agg(
+ count("value").as("count"),
+ hllsketch_binary("value").as("hllsketch_1"),
+ hllsketch_binary("value", 20, "HLL_8").as("hllsketch_2"),
+ hllsketch_binary("value").as("hllsketch_3")
+ )
+ df3.createOrReplaceTempView("df3")
+
+ // now test hllsketch_union_estimate via dataframe + sql, with and without configs,
+ // unioning together sketches with default, non-default and different configurations.
+ // df4 builds sketches over df2 through SQL, so after the union of df3 and df4:
+ // hllsketch_1: default config on both sides
+ // hllsketch_2: (20, 'HLL_8') on both sides
+ // hllsketch_3: default config in df3 vs (20, 'HLL_8') in df4
+ val df4 = sql(
+ """select
+ | id,
+ | count(value),
+ | hllsketch_binary(value) as hllsketch_1,
+ | hllsketch_binary(value, 20, 'HLL_8') as hllsketch_2,
+ | hllsketch_binary(value, 20, 'HLL_8') as hllsketch_3
+ |from df2
+ |group by 1
+ |""".stripMargin)
+ df4.createOrReplaceTempView("df4")
+
+ // union matches columns by position, so df4's unaliased count(value) column lines up
+ // with df3's "count" column. Expected: 7 + 8 = 15 rows and 6 distinct values (a-f)
+ // for every sketch column, whatever configuration each side was built with.
+ val res3 = df3.union(df4).groupBy("id")
+ .agg(
+ sum("count").as("count"),
+ hllsketch_union_estimate("hllsketch_1", 20).as("distinct_count_1"),
+ hllsketch_union_estimate("hllsketch_2").as("distinct_count_2"),
+ hllsketch_union_estimate("hllsketch_3").as("distinct_count_3")
+ )
+ checkAnswer(res3, Row(1, 15, 6, 6, 6))
+
+ // same union check expressed in SQL; the expected counts match res3
+ val res4 = sql(
+ """select
+ | id,
+ | sum(count) as count,
+ | hllsketch_union_estimate(hllsketch_1, 20) as distinct_count_1,
+ | hllsketch_union_estimate(hllsketch_2) as distinct_count_2,
+ | hllsketch_union_estimate(hllsketch_3) as distinct_count_3
+ |from (select * from df3 union all select * from df4)
+ |group by 1
+ |""".stripMargin)
+ checkAnswer(res4, Row(1, 15, 6, 6, 6))
+ }
+
+ test("SPARK-16484: hllsketch_estimate negative tests") {
+
+ val df1 = Seq(
+ (1, "a"), (1, "a"), (1, "a"),
+ (1, "b"),
+ (1, "c"), (1, "c"),
+ (1, "d")
+ ).toDF("id", "value")
+
+ // validate that the functions error out when lgConfigK < 0
+ val error0 = intercept[AnalysisException] {
+ val res = df1.groupBy("id")
+ .agg(
+ hllsketch_estimate("value", -1, "HLL_4").as("hllsketch")
+ )
+ checkAnswer(res, Nil)
+ }
+ assert(error0.toString contains "DATATYPE_MISMATCH")
+
+ val error1 = intercept[AnalysisException] {
+ val res = df1.groupBy("id")
+ .agg(
+ hllsketch_binary("value", -1, "HLL_4").as("hllsketch")
+ )
+ checkAnswer(res, Nil)
+ }
+ assert(error1.toString contains "DATATYPE_MISMATCH")
+
+ val error2 = intercept[AnalysisException] {
+ val res = df1.groupBy("id")
+ .agg(
+ hllsketch_binary("value").as("hllsketch")
+ )
+ .agg(
+ hllsketch_union_estimate("hllsketch", -1)
+ )
+ checkAnswer(res, Nil)
+ }
+ assert(error2.toString contains "DATATYPE_MISMATCH")
+
+ // validate that the functions error out with unsupported tgtHllType
+ val error3 = intercept[SparkException] {
+ val res = df1.groupBy("id")
+ .agg(
+ hllsketch_estimate("value", 12, "HLL_5").as("hllsketch")
+ )
+ checkAnswer(res, Nil)
+ }
+ assert(error3.toString contains "IllegalArgumentException")
Review Comment:
I've removed the explicit checks in the aggregate function and instead am
relying on the HllSketch/Union arg checks/exceptions.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]