dongjoon-hyun commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2462102987
##########
sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala:
##########
@@ -328,4 +331,326 @@ class ApproxTopKSuite extends QueryTest with
SharedSparkSession {
parameters = Map("maxItemsTracked" -> "5", "k" -> "10")
)
}
+
+ /////////////////////////////////
+ // approx_top_k_combine
+ /////////////////////////////////
+
+ def setupMixedSizeAccumulations(size1: Int, size2: Int): Unit = {
+ sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " +
+ "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " +
+ "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+ }
+
+ def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = {
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+ }
+
+ val mixedNumberTypes: Seq[(DataType, String, Seq[Any])] = Seq(
+ (IntegerType, "INT",
+ Seq(0, 0, 0, 1, 1, 2, 2, 3)),
+ (ByteType, "TINYINT",
+ Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")),
+ (ShortType, "SMALLINT",
+ Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")),
+ (LongType, "BIGINT",
+ Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")),
+ (FloatType, "FLOAT",
+ Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")),
+ (DoubleType, "DOUBLE",
+ Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")),
+ (DecimalType(4, 2), "DECIMAL(4,2)",
+ Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS
DECIMAL(4, 2))")),
+ (DecimalType(10, 2), "DECIMAL(10,2)",
+ Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS
DECIMAL(10, 2))")),
+ (DecimalType(20, 3), "DECIMAL(20,3)",
+ Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS
DECIMAL(20, 3))"))
+ )
+
+ val mixedDateTimeTypes: Seq[(DataType, String, Seq[String])] = Seq(
+ (DateType, "DATE",
+ Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
+ (TimestampType, "TIMESTAMP",
+ Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")),
+ (TimestampNTZType, "TIMESTAMP_NTZ",
+ Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01
00:00:00'")
+ )
+ )
+
+ // positive tests for approx_top_k_combine on all types
+ gridTest("SPARK-52798: same type, same size, specified combine size -
success")(itemsWithTopK) {
+ case (input, expected) =>
+ sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS
tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+ sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS
tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+ sql("SELECT approx_top_k_combine(acc, 30) as com " +
+ "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2);")
+ .createOrReplaceTempView("combined")
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ // expected should be doubled because we combine two identical sketches
+ val expectedDoubled = expected.map {
+ case Row(value: Any, count: Int) => Row(value, count * 2)
+ }
+ checkAnswer(est, Row(expectedDoubled))
+ }
+
+ test("SPARK-52798: same type, same size, specified combine size - success") {
+ setupMixedSizeAccumulations(10, 10)
+
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3),
Row(4, 2))))
+ }
+
+ test("SPARK-52798: same type, same size, unspecified combine size -
success") {
+ setupMixedSizeAccumulations(10, 10)
+
+ sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3),
Row(4, 2))))
+ }
+
+ test("SPARK-52798: same type, different size, specified combine size -
success") {
+ setupMixedSizeAccumulations(10, 20)
+
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combination")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3),
Row(4, 2))))
+ }
+
+ test("SPARK-52798: same type, different size, unspecified combine size -
fail") {
+ setupMixedSizeAccumulations(10, 20)
+
+ val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ comb.collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+ parameters = Map("size1" -> "10", "size2" -> "20")
+ )
+ }
+
+ gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10,
20))) {
+ case (size1, size2) =>
+ setupMixedSizeAccumulations(size1, size2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 0) as com FROM
unioned").collect()
+ },
+ condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+ parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+ )
+ }
+
+ test("SPARK-52798: among different number or datetime types - fail at
combine") {
+ def checkMixedTypeError(mixedTypeSeq: Seq[(DataType, String, Seq[Any])]):
Unit = {
+ for (i <- 0 until mixedTypeSeq.size - 1) {
+ for (j <- i + 1 until mixedTypeSeq.size) {
+ val (type1, _, seq1) = mixedTypeSeq(i)
+ val (type2, _, seq2) = mixedTypeSeq(j)
+ setupMixedTypeAccumulation(seq1, seq2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(type2))
+ )
+ }
+ }
+ }
+
+ checkMixedTypeError(mixedNumberTypes)
+ checkMixedTypeError(mixedDateTimeTypes)
+ }
+
+ // enumerate all combinations of number and datetime types
+ gridTest("SPARK-52798: number vs datetime - fail on UNION")(
+ for {
+ (type1, typeName1, seq1) <- mixedNumberTypes
+ (type2, typeName2, seq2) <- mixedDateTimeTypes
+ } yield ((type1, typeName1, seq1), (type2, typeName2, seq2))) {
+ case ((_, type1, seq1), (_, type2, seq2)) =>
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ setupMixedTypeAccumulation(seq1, seq2)
+ },
+ condition = "INCOMPATIBLE_COLUMN_TYPE",
+ parameters = Map(
+ "tableOrdinalNumber" -> "second",
+ "columnOrdinalNumber" -> "first",
+ "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: " + type1 + ", itemDataTypeDDL: STRING NOT
NULL>\""),
+ "operator" -> "UNION",
+ "hint" -> "",
+ "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: " + type2 + ", itemDataTypeDDL: STRING NOT NULL>\"")
+ ),
+ queryContext = Array(
+ ExpectedContext(
+ "SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2", 0, 68))
+ )
+ }
+
+ gridTest("SPARK-52798: number vs string - fail at
combine")(mixedNumberTypes) {
+ case (type1, _, seq1) =>
+ setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'",
"'c'", "'d'", "'d'"))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
+ )
+ }
+
+ gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes)
{
+ case (_, type1, seq1) =>
+ val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ setupMixedTypeAccumulation(seq1, seq2)
+ },
+ condition = "INCOMPATIBLE_COLUMN_TYPE",
+ parameters = Map(
+ "tableOrdinalNumber" -> "second",
+ "columnOrdinalNumber" -> "first",
+ "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: " + type1 + ", itemDataTypeDDL: STRING NOT
NULL>\""),
+ "operator" -> "UNION",
+ "hint" -> "",
+ "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: BOOLEAN, itemDataTypeDDL: STRING NOT NULL>\"")
+ ),
+ queryContext = Array(
+ ExpectedContext(
+ "SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2", 0, 68))
+ )
+ }
+
+ gridTest("SPARK-52798: datetime vs string - fail at
combine")(mixedDateTimeTypes) {
+ case (type1, _, seq1) =>
+ setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'",
"'c'", "'d'", "'d'"))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
+ )
+ }
+
+ gridTest("SPARK-52798: datetime vs boolean - fail at
UNION")(mixedDateTimeTypes) {
+ case (_, type1, seq1) =>
+ val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ setupMixedTypeAccumulation(seq1, seq2)
+ },
+ condition = "INCOMPATIBLE_COLUMN_TYPE",
+ parameters = Map(
+ "tableOrdinalNumber" -> "second",
+ "columnOrdinalNumber" -> "first",
+ "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: " + type1 + ", itemDataTypeDDL: STRING NOT
NULL>\""),
+ "operator" -> "UNION",
+ "hint" -> "",
+ "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: BOOLEAN, itemDataTypeDDL: STRING NOT NULL>\"")
+ ),
+ queryContext = Array(
+ ExpectedContext(
+ "SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2", 0, 68))
+ )
+ }
+
+ test("SPARK-52798: string vs boolean - fail at combine") {
Review Comment:
Hi, @yhuang-db and @gengliangwang. This test case seems to have been breaking the
non-ANSI CI for a week.
- https://github.com/apache/spark/actions/workflows/build_non_ansi.yml
- https://github.com/apache/spark/actions/runs/18623628999/job/53098232441
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]