dbatomic commented on code in PR #46640:
URL: https://github.com/apache/spark/pull/46640#discussion_r1622375244
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -999,6 +1000,113 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}
+  test("RewriteGroupByCollation rule works for string") {
+    val tempView = "tempTable"
+    // withTempView drops the view when the test finishes, so it cannot leak into
+    // other tests of this suite (the original never dropped it).
+    withTempView(tempView) {
+      val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI"))
+      val schema = StructType(Seq(StructField("name", dataType)))
+      val data = Seq(Row("AA"), Row("aa"), Row("BB"))
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      df.createOrReplaceTempView(tempView)
+
+      val dfGroupBy = spark.sql(s"SELECT COUNT(*) FROM $tempView GROUP BY name")
+      // Idempotence: applying the rule to its own output must not change the plan.
+      val analyzed = dfGroupBy.queryExecution.analyzed
+      val rewrittenOnce = RewriteGroupByCollation(analyzed)
+      val rewrittenTwice = RewriteGroupByCollation(rewrittenOnce)
+      assert(rewrittenOnce == rewrittenTwice)
+
+      // Under UNICODE_CI, "AA" and "aa" land in the same group and "BB" is alone.
+      checkAnswer(dfGroupBy, Seq(Row(2), Row(1)))
+    }
+  }
+
+  test("hash aggregation should be used for collated strings") {
+    val t1 = "T_1"
+
+    // result1: expected rows of `SELECT COUNT(*) ... GROUP BY x`;
+    // result2: expected rows of `SELECT x, COUNT(*) ... GROUP BY x`.
+    case class HashAggregationTestCase[R](collation: String, result1: R, result2: R)
+    val testCases = Seq(
+      HashAggregationTestCase("UTF8_BINARY",
+        Seq(Row(1), Row(1), Row(1)),
+        Seq(Row("aa", 1), Row("AA", 1), Row("bb", 1))
+      ),
+      HashAggregationTestCase("UTF8_BINARY_LCASE",
+        Seq(Row(2), Row(1)),
+        Seq(Row("aa", 2), Row("bb", 1))
+      ),
+      HashAggregationTestCase("UNICODE",
+        Seq(Row(1), Row(1), Row(1)),
+        Seq(Row("aa", 1), Row("AA", 1), Row("bb", 1))
+      ),
+      HashAggregationTestCase("UNICODE_CI",
+        Seq(Row(2), Row(1)),
+        Seq(Row("aa", 2), Row("bb", 1))
+      )
+    )
+
+    testCases.foreach { t =>
+      withTable(t1) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa'), ('AA'), ('bb')")
+
+        // df1 only groups by x; df2 additionally projects x out of the aggregate.
+        val df1 = sql(s"SELECT COUNT(*) FROM $t1 GROUP BY x")
+        checkAnswer(df1, t.result1)
+
+        val df2 = sql(s"SELECT x, COUNT(*) FROM $t1 GROUP BY x")
+        checkAnswer(df2, t.result2)
+
+        val queryPlan1 = df1.queryExecution.executedPlan
+        val queryPlan2 = df2.queryExecution.executedPlan
+
+        if (CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
+          // hash agg can always be used for binary collations
+          assert(collectFirst(queryPlan1) { case _: HashAggregateExec => () }.nonEmpty)
+          assert(collectFirst(queryPlan1) { case _: SortAggregateExec => () }.isEmpty)
+          assert(collectFirst(queryPlan2) { case _: HashAggregateExec => () }.nonEmpty)
+          assert(collectFirst(queryPlan2) { case _: SortAggregateExec => () }.isEmpty)
+        } else {
+          // hash agg can also be used if a non-binary collation is only used for grouping
+          assert(collectFirst(queryPlan1) { case _: HashAggregateExec => () }.nonEmpty)
+          assert(collectFirst(queryPlan1) { case _: SortAggregateExec => () }.isEmpty)
+          // however, sort agg will be used if a non-binary collation is present in the aggregate
+          assert(collectFirst(queryPlan2) { case _: HashAggregateExec => () }.isEmpty)
+          assert(collectFirst(queryPlan2) { case _: SortAggregateExec => () }.nonEmpty)
+          // For non-binary collations the grouping key should have been rewritten to a
+          // binary collation key, which shows up in the physical aggregate as a
+          // BinaryType grouping expression. Use `.contains(true)` rather than `.nonEmpty`:
+          // `.nonEmpty` only proves the operator exists and ignores the Boolean result.
+          assert(collectFirst(queryPlan1) { case s: HashAggregateExec =>
+            s.groupingExpressions.head.dataType.isInstanceOf[BinaryType]
+          }.contains(true))
+          assert(collectFirst(queryPlan2) { case s: SortAggregateExec =>
+            s.groupingExpressions.head.dataType.isInstanceOf[BinaryType]
+          }.contains(true))
+        }
+      }
+    }
+  }
+
+ test("hash aggregation should be used for collated struct field") {
+ val t1 = "T_1"
+ withTable(t1) {
+ val schema = "x: STRING, y: STRING COLLATE UTF8_BINARY_LCASE"
+ sql(s"CREATE TABLE $t1 (c struct<$schema>) USING PARQUET")
+ sql(s"INSERT INTO $t1 VALUES (named_struct('x', 'aa', 'y', 'aa'))")
+ sql(s"INSERT INTO $t1 VALUES (named_struct('x', 'AA', 'y', 'AA'))")
+ sql(s"INSERT INTO $t1 VALUES (named_struct('x', 'bb', 'y', 'bb'))")
+
+ val df1 = sql(s"SELECT COUNT(*) FROM $t1 GROUP BY c.x")
+ val result1 = Seq(Row(1), Row(1), Row(1))
+ checkAnswer(df1, result1)
+
+ val df2 = sql(s"SELECT COUNT(*) FROM $t1 GROUP BY c.y")
+ val result2 = Seq(Row(2), Row(1))
+ checkAnswer(df2, result2)
+
+ val queryPlan1 = df1.queryExecution.executedPlan
+ val queryPlan2 = df2.queryExecution.executedPlan
+
+ // hash aggregation can be used for all collations in these examples
+ assert(collectFirst(queryPlan1) { case _: HashAggregateExec => ()
}.nonEmpty)
+ assert(collectFirst(queryPlan1) { case _: SortAggregateExec => ()
}.isEmpty)
+ assert(collectFirst(queryPlan2) { case _: HashAggregateExec => ()
}.nonEmpty)
+ assert(collectFirst(queryPlan2) { case _: SortAggregateExec => ()
}.isEmpty)
+ // check that CollationKey is injected into the Aggregate logical plan
+ assert(collectFirst(queryPlan1) { case s: HashAggregateExec =>
Review Comment:
It would be cleaner to explicitly check whether `head` is an instance of
`CollationKey` instead of relying on its `dataType`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]