gengliangwang commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2430785760


##########
sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala:
##########
@@ -328,4 +331,326 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
       parameters = Map("maxItemsTracked" -> "5", "k" -> "10")
     )
   }
+
+  /////////////////////////////////
+  // approx_top_k_combine
+  /////////////////////////////////
+
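+  // Builds two accumulations over overlapping integer data with the given
+  // maxItemsTracked sizes, then unions the sketches into a temp view "unioned".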
+  def setupMixedSizeAccumulations(size1: Int, size2: Int): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " +
+      "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " +
+      "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+
+    sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2")
+      .createOrReplaceTempView("unioned")
+  }
+
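+  // Like above, but accumulates the two given literal sequences (possibly of
+  // different SQL types) with a fixed maxItemsTracked of 10.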
+  def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+
+    sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2")
+      .createOrReplaceTempView("unioned")
+  }
+
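+  // (DataType, SQL type name, sample literals) for each numeric type under test.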
+  val mixedNumberTypes: Seq[(DataType, String, Seq[Any])] = Seq(
+    (IntegerType, "INT",
+      Seq(0, 0, 0, 1, 1, 2, 2, 3)),
+    (ByteType, "TINYINT",
+      Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")),
+    (ShortType, "SMALLINT",
+      Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")),
+    (LongType, "BIGINT",
+      Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")),
+    (FloatType, "FLOAT",
+      Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")),
+    (DoubleType, "DOUBLE",
+      Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")),
+    (DecimalType(4, 2), "DECIMAL(4,2)",
+      Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS 
DECIMAL(4, 2))")),
+    (DecimalType(10, 2), "DECIMAL(10,2)",
+      Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS 
DECIMAL(10, 2))")),
+    (DecimalType(20, 3), "DECIMAL(20,3)",
+      Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS 
DECIMAL(20, 3))"))
+  )
+
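+  // (DataType, SQL type name, sample literals) for each datetime type under test.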
+  val mixedDateTimeTypes: Seq[(DataType, String, Seq[String])] = Seq(
+    (DateType, "DATE",
+      Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
+    (TimestampType, "TIMESTAMP",
+      Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")),
+    (TimestampNTZType, "TIMESTAMP_NTZ",
+      Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 
00:00:00'")
+    )
+  )
+
+  // positive tests for approx_top_k_combine on every supported type
+  gridTest("SPARK-52798: same type, same size, specified combine size - 
success")(itemsWithTopK) {
+    case (input, expected) =>
+      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS 
tab(expr);")
+        .createOrReplaceTempView("accumulation1")
+      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS 
tab(expr);")
+        .createOrReplaceTempView("accumulation2")
+      sql("SELECT approx_top_k_combine(acc, 30) as com " +
+        "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2);")
+        .createOrReplaceTempView("combined")
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      // expected should be doubled because we combine two identical sketches
+      val expectedDoubled = expected.map {
+        case Row(value: Any, count: Int) => Row(value, count * 2)
+      }
+      checkAnswer(est, Row(expectedDoubled))
+  }
+
+  test("SPARK-52798: same type, same size, specified combine size - success") {
+    setupMixedSizeAccumulations(10, 10)
+
+    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+      .createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, same size, unspecified combine size - 
success") {
+    setupMixedSizeAccumulations(10, 10)
+
+    sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+      .createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, different size, specified combine size - 
success") {
+    setupMixedSizeAccumulations(10, 20)
+
+    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+      .createOrReplaceTempView("combination")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, different size, unspecified combine size - 
fail") {
+    setupMixedSizeAccumulations(10, 20)
+
+    val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        comb.collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+      parameters = Map("size1" -> "10", "size2" -> "20")
+    )
+  }
+
+  gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 
20))) {
+    case (size1, size2) =>
+      setupMixedSizeAccumulations(size1, size2)
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 0) as com FROM 
unioned").collect()
+        },
+        condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+        parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+      )
+  }
+
+  test("SPARK-52798: among different number or datetime types - fail at 
combine") {

Review Comment:
   QQ: are there test cases for a **valid** combine with different input types? For example, int and long, or date and string.
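
   For illustration, a hypothetical test for the int/long case might look roughly like the sketch below. It assumes both sides are cast to a common type (LONG) before accumulating, so the combined sketches end up type-compatible; whether a direct cross-type combine should also succeed is exactly the open question here.

   ```scala
   // Hypothetical sketch, not part of this PR: make INT and BIGINT inputs combinable
   // by casting both sides to LONG before accumulating.
   test("SPARK-52798: int and long inputs cast to a common type - success") {
     sql("SELECT approx_top_k_accumulate(cast(expr AS LONG), 10) AS acc " +
       "FROM VALUES (0), (0), (0) AS tab(expr);")
       .createOrReplaceTempView("accumulation1")
     sql("SELECT approx_top_k_accumulate(expr, 10) AS acc " +
       "FROM VALUES (cast(0 AS LONG)), (cast(1 AS LONG)) AS tab(expr);")
       .createOrReplaceTempView("accumulation2")
     sql("SELECT approx_top_k_combine(acc, 10) AS com " +
       "FROM (SELECT acc FROM accumulation1 UNION ALL SELECT acc FROM accumulation2)")
       .createOrReplaceTempView("combined")
     val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
     // 0 occurs 3 + 1 = 4 times, 1 occurs once; counts are exact at this small scale
     checkAnswer(est, Row(Seq(Row(0L, 4), Row(1L, 1))))
   }
   ```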


