This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 58f0f3be2204 [SPARK-49668][SQL] Implement collation key support for
trim collation
58f0f3be2204 is described below
commit 58f0f3be22045d29143a25f721fadf9361abd68d
Author: Jovan Pavlovic <[email protected]>
AuthorDate: Fri Oct 4 09:34:09 2024 +0200
[SPARK-49668][SQL] Implement collation key support for trim collation
### What changes were proposed in this pull request?
Implementing support for collation key for trim collations.
### Why are the changes needed?
Needed for correct calculation of collation key for trim collation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added test in CollationExpressionSuite.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48313 from jovanpavl-db/implement_collation_key_support.
Authored-by: Jovan Pavlovic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../spark/sql/catalyst/util/CollationFactory.java | 28 ++++-
.../expressions/CollationExpressionSuite.scala | 12 +++
.../org/apache/spark/sql/CollationSuite.scala | 113 ++++++++++++---------
3 files changed, 105 insertions(+), 48 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index e368e2479a3a..ac8ed11c7dca 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -341,6 +341,22 @@ public final class CollationFactory {
SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK)];
}
+ protected static UTF8String applyTrimmingPolicy(UTF8String s, int
collationId) {
+ return applyTrimmingPolicy(s, getSpaceTrimming(collationId));
+ }
+
+ /**
+ * Utility function to trim spaces when collation uses space trimming.
+ */
+ protected static UTF8String applyTrimmingPolicy(UTF8String s,
SpaceTrimming spaceTrimming) {
+ return switch (spaceTrimming) {
+ case LTRIM -> s.trimLeft();
+ case RTRIM -> s.trimRight();
+ case TRIM -> s.trim();
+ default -> s; // NOTRIM
+ };
+ }
+
/**
* Main entry point for retrieving `Collation` instance from collation
ID.
*/
@@ -1130,24 +1146,32 @@ public final class CollationFactory {
public static UTF8String getCollationKey(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
+ if (usesTrimCollation(collationId)) {
+ input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId);
+ }
if (collation.supportsBinaryEquality) {
return input;
} else if (collation.supportsLowercaseEquality) {
return CollationAwareUTF8String.lowerCaseCodePoints(input);
} else {
- CollationKey collationKey =
collation.collator.getCollationKey(input.toValidString());
+ CollationKey collationKey = collation.collator.getCollationKey(
+ input.toValidString());
return UTF8String.fromBytes(collationKey.toByteArray());
}
}
public static byte[] getCollationKeyBytes(UTF8String input, int collationId)
{
Collation collation = fetchCollation(collationId);
+ if (usesTrimCollation(collationId)) {
+ input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId);
+ }
if (collation.supportsBinaryEquality) {
return input.getBytes();
} else if (collation.supportsLowercaseEquality) {
return CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes();
} else {
- return
collation.collator.getCollationKey(input.toValidString()).toByteArray();
+ return collation.collator.getCollationKey(
+ input.toValidString()).toByteArray();
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
index e34b54c7086c..c5c79021c421 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
@@ -169,17 +169,29 @@ class CollationExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
("", "UTF8_BINARY", UTF8String.fromString("").getBytes),
("aa", "UTF8_BINARY", UTF8String.fromString("aa").getBytes),
("AA", "UTF8_BINARY", UTF8String.fromString("AA").getBytes),
+ (" AA ", "UTF8_BINARY_TRIM", UTF8String.fromString("AA").getBytes),
+ (" AA ", "UTF8_BINARY_LTRIM", UTF8String.fromString("AA ").getBytes),
+ (" AA ", "UTF8_BINARY_RTRIM", UTF8String.fromString(" AA").getBytes),
("aA", "UTF8_BINARY", UTF8String.fromString("aA").getBytes),
("", "UTF8_LCASE", UTF8String.fromString("").getBytes),
("aa", "UTF8_LCASE", UTF8String.fromString("aa").getBytes),
("AA", "UTF8_LCASE", UTF8String.fromString("aa").getBytes),
+ (" AA ", "UTF8_LCASE_TRIM", UTF8String.fromString("aa").getBytes),
+ (" AA ", "UTF8_LCASE_LTRIM", UTF8String.fromString("aa ").getBytes),
+ (" AA ", "UTF8_LCASE_RTRIM", UTF8String.fromString(" aa").getBytes),
("aA", "UTF8_LCASE", UTF8String.fromString("aa").getBytes),
("", "UNICODE", Array[Byte](1, 1, 0)),
("aa", "UNICODE", Array[Byte](42, 42, 1, 6, 1, 6, 0)),
("AA", "UNICODE", Array[Byte](42, 42, 1, 6, 1, -36, -36, 0)),
("aA", "UNICODE", Array[Byte](42, 42, 1, 6, 1, -59, -36, 0)),
+ (" aa ", "UNICODE_TRIM", Array[Byte](42, 42, 1, 6, 1, 6, 0)),
+ (" aa", "UNICODE_LTRIM", Array[Byte](42, 42, 1, 6, 1, 6, 0)),
+ ("aa ", "UNICODE_RTRIM", Array[Byte](42, 42, 1, 6, 1, 6, 0)),
("", "UNICODE_CI", Array[Byte](1, 0)),
("aa", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
+ (" aa ", "UNICODE_CI_TRIM", Array[Byte](42, 42, 1, 6, 0)),
+ (" aa", "UNICODE_CI_LTRIM", Array[Byte](42, 42, 1, 6, 0)),
+ ("aa ", "UNICODE_CI_RTRIM", Array[Byte](42, 42, 1, 6, 0)),
("AA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
("aA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0))
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 03d3ed6ac7cb..b8f33bcb1977 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -1101,7 +1101,8 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
}
}
- for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI",
"")) {
+ for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI",
+ "UNICODE_CI_TRIM", "")) {
for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) {
val collationSetup = if (collation.isEmpty) "" else " COLLATE " +
collation
val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE"
||
@@ -1294,21 +1295,23 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val t1 = "T_1"
val t2 = "T_2"
- case class HashJoinTestCase[R](collation: String, result: R)
+ case class HashJoinTestCase[R](collation: String, data1: String, data2:
String, result: R)
val testCases = Seq(
- HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
- HashJoinTestCase("UTF8_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1,
"aa", 2))),
- HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
- HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1,
"aa", 2)))
+ HashJoinTestCase("UTF8_BINARY", "aa", "AA", Seq(Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UTF8_LCASE", "aa", "AA", Seq(Row("aa", 1, "AA", 2),
Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UNICODE", "aa", "AA", Seq(Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UNICODE_CI", "aa", "AA", Seq(Row("aa", 1, "AA", 2),
Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UNICODE_CI_TRIM", "aa", " AA ", Seq(Row("aa", 1, " AA
", 2),
+ Row("aa", 1, "aa", 2)))
)
testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING
PARQUET")
- sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+ sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)")
sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING
PARQUET")
- sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+ sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)")
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)
@@ -1345,25 +1348,27 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val t1 = "T_1"
val t2 = "T_2"
- case class HashJoinTestCase[R](collation: String, result: R)
+ case class HashJoinTestCase[R](collation: String, data1: String, data2:
String, result: R)
val testCases = Seq(
- HashJoinTestCase("UTF8_BINARY",
+ HashJoinTestCase("UTF8_BINARY", "aa", "AA",
Seq(Row(Seq("aa"), 1, Seq("aa"), 2))),
- HashJoinTestCase("UTF8_LCASE",
+ HashJoinTestCase("UTF8_LCASE", "aa", "AA",
Seq(Row(Seq("aa"), 1, Seq("AA"), 2), Row(Seq("aa"), 1, Seq("aa"), 2))),
- HashJoinTestCase("UNICODE",
+ HashJoinTestCase("UNICODE", "aa", "AA",
Seq(Row(Seq("aa"), 1, Seq("aa"), 2))),
- HashJoinTestCase("UNICODE_CI",
- Seq(Row(Seq("aa"), 1, Seq("AA"), 2), Row(Seq("aa"), 1, Seq("aa"), 2)))
+ HashJoinTestCase("UNICODE_CI", "aa", "AA",
+ Seq(Row(Seq("aa"), 1, Seq("AA"), 2), Row(Seq("aa"), 1, Seq("aa"), 2))),
+ HashJoinTestCase("UNICODE_CI_TRIM", "aa", " AA ",
+ Seq(Row(Seq("aa"), 1, Seq(" AA "), 2), Row(Seq("aa"), 1, Seq("aa"),
2)))
)
testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x ARRAY<STRING COLLATE ${t.collation}>, i int)
USING PARQUET")
- sql(s"INSERT INTO $t1 VALUES (array('aa'), 1)")
+ sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)")
sql(s"CREATE TABLE $t2 (y ARRAY<STRING COLLATE ${t.collation}>, j int)
USING PARQUET")
- sql(s"INSERT INTO $t2 VALUES (array('AA'), 2), (array('aa'), 2)")
+ sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2),
(array('${t.data1}'), 2)")
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)
@@ -1401,27 +1406,30 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val t1 = "T_1"
val t2 = "T_2"
- case class HashJoinTestCase[R](collation: String, result: R)
+ case class HashJoinTestCase[R](collation: String, data1: String, data2:
String, result: R)
val testCases = Seq(
- HashJoinTestCase("UTF8_BINARY",
+ HashJoinTestCase("UTF8_BINARY", "aa", "AA",
Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))),
- HashJoinTestCase("UTF8_LCASE",
+ HashJoinTestCase("UTF8_LCASE", "aa", "AA",
Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA")), 2), Row(Seq(Seq("aa")), 1,
Seq(Seq("aa")), 2))),
- HashJoinTestCase("UNICODE",
+ HashJoinTestCase("UNICODE", "aa", "AA",
Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))),
- HashJoinTestCase("UNICODE_CI",
- Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA")), 2), Row(Seq(Seq("aa")), 1,
Seq(Seq("aa")), 2)))
+ HashJoinTestCase("UNICODE_CI", "aa", "AA",
+ Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA")), 2), Row(Seq(Seq("aa")), 1,
Seq(Seq("aa")), 2))),
+ HashJoinTestCase("UNICODE_CI_TRIM", "aa", " AA ",
+ Seq(Row(Seq(Seq("aa")), 1, Seq(Seq(" AA ")), 2), Row(Seq(Seq("aa")),
1, Seq(Seq("aa")), 2)))
)
testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x ARRAY<ARRAY<STRING COLLATE ${t.collation}>>,
i int) USING " +
s"PARQUET")
- sql(s"INSERT INTO $t1 VALUES (array(array('aa')), 1)")
+ sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)")
sql(s"CREATE TABLE $t2 (y ARRAY<ARRAY<STRING COLLATE ${t.collation}>>,
j int) USING " +
s"PARQUET")
- sql(s"INSERT INTO $t2 VALUES (array(array('AA')), 2),
(array(array('aa')), 2)")
+ sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," +
+ s" (array(array('${t.data1}')), 2)")
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)
@@ -1460,24 +1468,27 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val t1 = "T_1"
val t2 = "T_2"
- case class HashJoinTestCase[R](collation: String, result: R)
+ case class HashJoinTestCase[R](collation: String, data1 : String, data2:
String, result: R)
val testCases = Seq(
- HashJoinTestCase("UTF8_BINARY",
+ HashJoinTestCase("UTF8_BINARY", "aa", "AA",
Seq(Row(Row("aa"), 1, Row("aa"), 2))),
- HashJoinTestCase("UTF8_LCASE",
+ HashJoinTestCase("UTF8_LCASE", "aa", "AA",
Seq(Row(Row("aa"), 1, Row("AA"), 2), Row(Row("aa"), 1, Row("aa"), 2))),
- HashJoinTestCase("UNICODE",
+ HashJoinTestCase("UNICODE", "aa", "AA",
Seq(Row(Row("aa"), 1, Row("aa"), 2))),
- HashJoinTestCase("UNICODE_CI",
- Seq(Row(Row("aa"), 1, Row("AA"), 2), Row(Row("aa"), 1, Row("aa"), 2)))
+ HashJoinTestCase("UNICODE_CI", "aa", "AA",
+ Seq(Row(Row("aa"), 1, Row("AA"), 2), Row(Row("aa"), 1, Row("aa"), 2))),
+ HashJoinTestCase("UNICODE_CI_TRIM", "aa", " AA ",
+ Seq(Row(Row("aa"), 1, Row(" AA "), 2), Row(Row("aa"), 1, Row("aa"),
2)))
)
testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x STRUCT<f:STRING COLLATE ${t.collation}>, i
int) USING PARQUET")
- sql(s"INSERT INTO $t1 VALUES (named_struct('f', 'aa'), 1)")
+ sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)")
sql(s"CREATE TABLE $t2 (y STRUCT<f:STRING COLLATE ${t.collation}>, j
int) USING PARQUET")
- sql(s"INSERT INTO $t2 VALUES (named_struct('f', 'AA'), 2),
(named_struct('f', 'aa'), 2)")
+ sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," +
+ s" (named_struct('f', '${t.data1}'), 2)")
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)
@@ -1510,29 +1521,33 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val t1 = "T_1"
val t2 = "T_2"
- case class HashJoinTestCase[R](collation: String, result: R)
+ case class HashJoinTestCase[R](collation: String, data1: String, data2:
String, result: R)
val testCases = Seq(
- HashJoinTestCase("UTF8_BINARY",
+ HashJoinTestCase("UTF8_BINARY", "aa", "AA",
Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))),
- HashJoinTestCase("UTF8_LCASE",
+ HashJoinTestCase("UTF8_LCASE", "aa", "AA",
Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA"))), 2),
Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))),
- HashJoinTestCase("UNICODE",
+ HashJoinTestCase("UNICODE", "aa", "AA",
Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))),
- HashJoinTestCase("UNICODE_CI",
+ HashJoinTestCase("UNICODE_CI", "aa", "AA",
Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA"))), 2),
+ Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))),
+ HashJoinTestCase("UNICODE_CI_TRIM", "aa", " AA ",
+ Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row(" AA "))), 2),
Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2)))
)
testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x STRUCT<f:ARRAY<STRUCT<f:STRING COLLATE
${t.collation}>>>, " +
s"i int) USING PARQUET")
- sql(s"INSERT INTO $t1 VALUES (named_struct('f',
array(named_struct('f', 'aa'))), 1)")
+ sql(s"INSERT INTO $t1 VALUES (named_struct('f',
array(named_struct('f', '${t.data1}'))), 1)"
+ )
sql(s"CREATE TABLE $t2 (y STRUCT<f:ARRAY<STRUCT<f:STRING COLLATE
${t.collation}>>>, " +
s"j int) USING PARQUET")
- sql(s"INSERT INTO $t2 VALUES (named_struct('f',
array(named_struct('f', 'AA'))), 2), " +
- s"(named_struct('f', array(named_struct('f', 'aa'))), 2)")
+ sql(s"INSERT INTO $t2 VALUES (named_struct('f',
array(named_struct('f', '${t.data2}'))), 2)"
+ + s", (named_struct('f', array(named_struct('f', '${t.data1}'))),
2)")
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)
@@ -1613,7 +1628,9 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE
UTF8_LCASE",
"'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)),
HashMultiJoinTestCase("STRING COLLATE UTF8_LCASE", "STRING COLLATE
UNICODE_CI",
- "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1))
+ "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1)),
+ HashMultiJoinTestCase("STRING COLLATE UTF8_LCASE", "STRING COLLATE
UNICODE_CI_TRIM",
+ "'a', 'a', 1", "'A', ' A ', 1", Row("a", "a", 1, "A", " A ", 1))
)
testCases.foreach(t => {
@@ -1646,15 +1663,19 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
test("hll sketch aggregate should respect collation") {
case class HllSketchAggTestCase[R](c: String, result: R)
val testCases = Seq(
- HllSketchAggTestCase("UTF8_BINARY", 4),
- HllSketchAggTestCase("UTF8_LCASE", 3),
- HllSketchAggTestCase("UNICODE", 4),
- HllSketchAggTestCase("UNICODE_CI", 3)
+ HllSketchAggTestCase("UTF8_BINARY", 5),
+ HllSketchAggTestCase("UTF8_BINARY_TRIM", 4),
+ HllSketchAggTestCase("UTF8_LCASE", 4),
+ HllSketchAggTestCase("UTF8_LCASE_TRIM", 3),
+ HllSketchAggTestCase("UNICODE", 5),
+ HllSketchAggTestCase("UNICODE_TRIM", 4),
+ HllSketchAggTestCase("UNICODE_CI", 4),
+ HllSketchAggTestCase("UNICODE_CI_TRIM", 3)
)
testCases.foreach(t => {
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.c) {
val q = "SELECT hll_sketch_estimate(hll_sketch_agg(col)) FROM " +
- "VALUES ('a'), ('A'), ('b'), ('b'), ('c') tab(col)"
+ "VALUES ('a'), ('A'), ('b'), ('b'), ('c'), (' c ') tab(col)"
val df = sql(q)
checkAnswer(df, Seq(Row(t.result)))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]