This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 31036274fc1c [SPARK-47410][SQL] Refactor UTF8String and CollationFactory 31036274fc1c is described below commit 31036274fc1c8013c6428735659959f46afea5d8 Author: Uros Bojanic <157381213+uros...@users.noreply.github.com> AuthorDate: Thu Apr 11 22:21:21 2024 +0800 [SPARK-47410][SQL] Refactor UTF8String and CollationFactory ### What changes were proposed in this pull request? This PR introduces comprehensive support for collation-aware expressions in Spark, focusing on improving code structure, clarity, and testing coverage for various expressions (including Contains, StartsWith, and EndsWith). ### Why are the changes needed? The changes are essential to improve the maintainability and readability of collation-related code in Spark expressions. By restructuring and centralizing collation support through the new CollationSupport class, we simplify the addition of new collation-aware operations and ensure consistent testing across different collation types. ### Does this PR introduce _any_ user-facing change? No, this PR is focused on internal refactoring and testing enhancements for collation-aware expression support. ### How was this patch tested? Unit tests in CollationSupportSuite.java; E2E tests in CollationStringExpressionsSuite.scala. ### Was this patch authored or co-authored using generative AI tooling? Yes. Closes #45978 from uros-db/SPARK-47410.
Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationFactory.java | 54 +- .../spark/sql/catalyst/util/CollationSupport.java | 174 ++++++ .../org/apache/spark/unsafe/types/UTF8String.java | 54 -- .../spark/unsafe/types/CollationSupportSuite.java | 266 +++++++++ .../unsafe/types/UTF8StringWithCollationSuite.java | 103 ---- .../expressions/codegen/CodeGenerator.scala | 3 +- .../catalyst/expressions/stringExpressions.scala | 41 +- .../sql/CollationRegexpExpressionsSuite.scala | 616 +++++++++------------ .../sql/CollationStringExpressionsSuite.scala | 179 ++++-- .../org/apache/spark/sql/CollationSuite.scala | 84 --- 10 files changed, 874 insertions(+), 700 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 72a6e574707f..ff7bc450f851 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -78,6 +78,14 @@ public final class CollationFactory { */ public final boolean supportsBinaryOrdering; + /** + * Support for Lowercase Equality implies that it is possible to check equality on + * byte by byte level, but only after calling "UTF8String.toLowerCase" on both arguments. + * This allows custom collation support for UTF8_BINARY_LCASE collation in various Spark + * expressions, as this particular collation is not supported by the external ICU library. 
+ */ + public final boolean supportsLowercaseEquality; + public Collation( String collationName, Collator collator, @@ -85,7 +93,8 @@ public final class CollationFactory { String version, ToLongFunction<UTF8String> hashFunction, boolean supportsBinaryEquality, - boolean supportsBinaryOrdering) { + boolean supportsBinaryOrdering, + boolean supportsLowercaseEquality) { this.collationName = collationName; this.collator = collator; this.comparator = comparator; @@ -93,9 +102,12 @@ public final class CollationFactory { this.hashFunction = hashFunction; this.supportsBinaryEquality = supportsBinaryEquality; this.supportsBinaryOrdering = supportsBinaryOrdering; + this.supportsLowercaseEquality = supportsLowercaseEquality; // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality assert(!supportsBinaryOrdering || supportsBinaryEquality); + // No Collation can simultaneously support binary equality and lowercase equality + assert(!supportsBinaryEquality || !supportsLowercaseEquality); if (supportsBinaryEquality) { this.equalsFunction = UTF8String::equals; @@ -112,7 +124,8 @@ public final class CollationFactory { Collator collator, String version, boolean supportsBinaryEquality, - boolean supportsBinaryOrdering) { + boolean supportsBinaryOrdering, + boolean supportsLowercaseEquality) { this( collationName, collator, @@ -120,7 +133,8 @@ public final class CollationFactory { version, s -> (long)collator.getCollationKey(s.toString()).hashCode(), supportsBinaryEquality, - supportsBinaryOrdering); + supportsBinaryOrdering, + supportsLowercaseEquality); } } @@ -141,7 +155,8 @@ public final class CollationFactory { "1.0", s -> (long)s.hashCode(), true, - true); + true, + false); // Case-insensitive UTF8 binary collation. // TODO: Do in place comparisons instead of creating new strings. 
@@ -152,17 +167,18 @@ public final class CollationFactory { "1.0", (s) -> (long)s.toLowerCase().hashCode(), false, - false); + false, + true); // UNICODE case sensitive comparison (ROOT locale, in ICU). collationTable[2] = new Collation( - "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true, false); + "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true, false, false); collationTable[2].collator.setStrength(Collator.TERTIARY); collationTable[2].collator.freeze(); // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary strength). collationTable[3] = new Collation( - "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false, false); + "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false, false, false); collationTable[3].collator.setStrength(Collator.SECONDARY); collationTable[3].collator.freeze(); @@ -172,19 +188,31 @@ public final class CollationFactory { } /** - * Auxiliary methods for collation aware string operations. + * Returns a StringSearch object for the given pattern and target strings, under collation + * rules corresponding to the given collationId. The external ICU library StringSearch object can + * be used to find occurrences of the pattern in the target string, while respecting collation. */ - public static StringSearch getStringSearch( - final UTF8String left, - final UTF8String right, + final UTF8String targetUTF8String, + final UTF8String patternUTF8String, final int collationId) { - String pattern = right.toString(); - CharacterIterator target = new StringCharacterIterator(left.toString()); + String pattern = patternUTF8String.toString(); + CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString()); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(pattern, target, (RuleBasedCollator) collator); } + /** + * Returns a collation-unaware StringSearch object for the given pattern and target strings. 
+ * While this object does not respect collation, it can be used to find occurrences of the pattern + * in the target string for UTF8_BINARY or UTF8_BINARY_LCASE (if arguments are lowercased). + */ + public static StringSearch getStringSearch( + final UTF8String targetUTF8String, + final UTF8String patternUTF8String) { + return new StringSearch(patternUTF8String.toString(), targetUTF8String.toString()); + } + /** * Returns the collation id for the given collation name. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java new file mode 100644 index 000000000000..fe1952921b7f --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.util; + +import com.ibm.icu.text.StringSearch; + +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and + * other expressions that require custom collation support), as well as private utility methods for + * collation-aware UTF8String operations needed to implement . + */ +public final class CollationSupport { + + /** + * Collation-aware string expressions. + */ + + public static class Contains { + public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(l, r); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(l, r); + } else { + return execICU(l, r, collationId); + } + } + public static String genCode(final String l, final String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.Contains.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", l, r); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", l, r); + } else { + return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + } + } + public static boolean execBinary(final UTF8String l, final UTF8String r) { + return l.contains(r); + } + public static boolean execLowercase(final UTF8String l, final UTF8String r) { + return l.toLowerCase().contains(r.toLowerCase()); + } + public static boolean execICU(final UTF8String l, final UTF8String r, + final int collationId) { + if (r.numBytes() == 0) return true; + if (l.numBytes() == 0) return false; + StringSearch stringSearch = CollationFactory.getStringSearch(l, r, collationId); + return stringSearch.first() != 
StringSearch.DONE; + } + } + + public static class StartsWith { + public static boolean exec(final UTF8String l, final UTF8String r, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(l, r); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(l, r); + } else { + return execICU(l, r, collationId); + } + } + public static String genCode(final String l, final String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StartsWith.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", l, r); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", l, r); + } else { + return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + } + } + public static boolean execBinary(final UTF8String l, final UTF8String r) { + return l.startsWith(r); + } + public static boolean execLowercase(final UTF8String l, final UTF8String r) { + return l.toLowerCase().startsWith(r.toLowerCase()); + } + public static boolean execICU(final UTF8String l, final UTF8String r, + final int collationId) { + return CollationAwareUTF8String.matchAt(l, r, 0, collationId); + } + } + + public static class EndsWith { + public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(l, r); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(l, r); + } else { + return execICU(l, r, collationId); + } + } + public static String genCode(final String l, final String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String 
expr = "CollationSupport.EndsWith.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", l, r); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", l, r); + } else { + return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + } + } + public static boolean execBinary(final UTF8String l, final UTF8String r) { + return l.endsWith(r); + } + public static boolean execLowercase(final UTF8String l, final UTF8String r) { + return l.toLowerCase().endsWith(r.toLowerCase()); + } + public static boolean execICU(final UTF8String l, final UTF8String r, + final int collationId) { + return CollationAwareUTF8String.matchAt(l, r, l.numBytes() - r.numBytes(), collationId); + } + } + + // TODO: Add more collation-aware string expressions. + + /** + * Collation-aware regexp expressions. + */ + + // TODO: Add more collation-aware regexp expressions. + + /** + * Other collation-aware expressions. + */ + + // TODO: Add other collation-aware expressions. + + /** + * Utility class for collation-aware UTF8String operations. 
+ */ + + private static class CollationAwareUTF8String { + + private static boolean matchAt(final UTF8String target, final UTF8String pattern, + final int pos, final int collationId) { + if (pattern.numChars() + pos > target.numChars() || pos < 0) { + return false; + } + if (pattern.numBytes() == 0 || target.numBytes() == 0) { + return pattern.numBytes() == 0; + } + return CollationFactory.getStringSearch(target.substring( + pos, pos + pattern.numChars()), pattern, collationId).last() == 0; + } + + } + +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 2006efb07a04..2009f1d20442 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,7 +30,6 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; -import com.ibm.icu.text.StringSearch; import org.apache.spark.sql.catalyst.util.CollationFactory; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UTF8StringBuilder; @@ -342,28 +341,6 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return false; } - public boolean contains(final UTF8String substring, int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return this.contains(substring); - } - if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - return this.toLowerCase().contains(substring.toLowerCase()); - } - return collatedContains(substring, collationId); - } - - private boolean collatedContains(final UTF8String substring, int collationId) { - if (substring.numBytes == 0) return true; - if (this.numBytes == 0) return false; - StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId); - while 
(stringSearch.next() != StringSearch.DONE) { - if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) { - return true; - } - } - return false; - } - /** * Returns the byte at position `i`. */ @@ -378,45 +355,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); } - private boolean matchAt(final UTF8String s, int pos, int collationId) { - if (s.numChars() + pos > this.numChars() || pos < 0) { - return false; - } - if (s.numBytes == 0 || this.numBytes == 0) { - return s.numBytes == 0; - } - return CollationFactory.getStringSearch(this.substring(pos, pos + s.numChars()), - s, collationId).last() == 0; - } - public boolean startsWith(final UTF8String prefix) { return matchAt(prefix, 0); } - public boolean startsWith(final UTF8String prefix, int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return this.startsWith(prefix); - } - if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - return this.toLowerCase().startsWith(prefix.toLowerCase()); - } - return matchAt(prefix, 0, collationId); - } - public boolean endsWith(final UTF8String suffix) { return matchAt(suffix, numBytes - suffix.numBytes); } - public boolean endsWith(final UTF8String suffix, int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return this.endsWith(suffix); - } - if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - return this.toLowerCase().endsWith(suffix.toLowerCase()); - } - return matchAt(suffix, numBytes - suffix.numBytes, collationId); - } - /** * Returns the upper case of this string */ diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java new file mode 100644 index 000000000000..bfb696c35fff --- 
/dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.unsafe.types; + +import org.apache.spark.SparkException; +import org.apache.spark.sql.catalyst.util.CollationFactory; +import org.apache.spark.sql.catalyst.util.CollationSupport; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + + +public class CollationSupportSuite { + + /** + * Collation-aware string expressions. 
+ */ + + private void assertContains(String pattern, String target, String collationName, boolean value) + throws SparkException { + UTF8String l = UTF8String.fromString(pattern); + UTF8String r = UTF8String.fromString(target); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(CollationSupport.Contains.exec(l, r, collationId), value); + } + + @Test + public void testContains() throws SparkException { + // Edge cases + assertContains("", "", "UTF8_BINARY", true); + assertContains("c", "", "UTF8_BINARY", true); + assertContains("", "c", "UTF8_BINARY", false); + assertContains("", "", "UNICODE", true); + assertContains("c", "", "UNICODE", true); + assertContains("", "c", "UNICODE", false); + assertContains("", "", "UTF8_BINARY_LCASE", true); + assertContains("c", "", "UTF8_BINARY_LCASE", true); + assertContains("", "c", "UTF8_BINARY_LCASE", false); + assertContains("", "", "UNICODE_CI", true); + assertContains("c", "", "UNICODE_CI", true); + assertContains("", "c", "UNICODE_CI", false); + // Basic tests + assertContains("abcde", "bcd", "UTF8_BINARY", true); + assertContains("abcde", "bde", "UTF8_BINARY", false); + assertContains("abcde", "fgh", "UTF8_BINARY", false); + assertContains("abcde", "abcde", "UNICODE", true); + assertContains("abcde", "aBcDe", "UNICODE", false); + assertContains("abcde", "fghij", "UNICODE", false); + assertContains("abcde", "C", "UTF8_BINARY_LCASE", true); + assertContains("abcde", "AbCdE", "UTF8_BINARY_LCASE", true); + assertContains("abcde", "X", "UTF8_BINARY_LCASE", false); + assertContains("abcde", "c", "UNICODE_CI", true); + assertContains("abcde", "bCD", "UNICODE_CI", true); + assertContains("abcde", "123", "UNICODE_CI", false); + // Case variation + assertContains("aBcDe", "bcd", "UTF8_BINARY", false); + assertContains("aBcDe", "BcD", "UTF8_BINARY", true); + assertContains("aBcDe", "abcde", "UNICODE", false); + assertContains("aBcDe", "aBcDe", "UNICODE", true); + assertContains("aBcDe", "bcd", 
"UTF8_BINARY_LCASE", true); + assertContains("aBcDe", "BCD", "UTF8_BINARY_LCASE", true); + assertContains("aBcDe", "abcde", "UNICODE_CI", true); + assertContains("aBcDe", "AbCdE", "UNICODE_CI", true); + // Accent variation + assertContains("aBcDe", "bćd", "UTF8_BINARY", false); + assertContains("aBcDe", "BćD", "UTF8_BINARY", false); + assertContains("aBcDe", "abćde", "UNICODE", false); + assertContains("aBcDe", "aBćDe", "UNICODE", false); + assertContains("aBcDe", "bćd", "UTF8_BINARY_LCASE", false); + assertContains("aBcDe", "BĆD", "UTF8_BINARY_LCASE", false); + assertContains("aBcDe", "abćde", "UNICODE_CI", false); + assertContains("aBcDe", "AbĆdE", "UNICODE_CI", false); + // Variable byte length characters + assertContains("ab世De", "b世D", "UTF8_BINARY", true); + assertContains("ab世De", "B世d", "UTF8_BINARY", false); + assertContains("äbćδe", "bćδ", "UTF8_BINARY", true); + assertContains("äbćδe", "BcΔ", "UTF8_BINARY", false); + assertContains("ab世De", "ab世De", "UNICODE", true); + assertContains("ab世De", "AB世dE", "UNICODE", false); + assertContains("äbćδe", "äbćδe", "UNICODE", true); + assertContains("äbćδe", "ÄBcΔÉ", "UNICODE", false); + assertContains("ab世De", "b世D", "UTF8_BINARY_LCASE", true); + assertContains("ab世De", "B世d", "UTF8_BINARY_LCASE", true); + assertContains("äbćδe", "bćδ", "UTF8_BINARY_LCASE", true); + assertContains("äbćδe", "BcΔ", "UTF8_BINARY_LCASE", false); + assertContains("ab世De", "ab世De", "UNICODE_CI", true); + assertContains("ab世De", "AB世dE", "UNICODE_CI", true); + assertContains("äbćδe", "ÄbćδE", "UNICODE_CI", true); + assertContains("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); + } + + private void assertStartsWith(String pattern, String prefix, String collationName, boolean value) + throws SparkException { + UTF8String l = UTF8String.fromString(pattern); + UTF8String r = UTF8String.fromString(prefix); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(CollationSupport.StartsWith.exec(l, r, collationId), value); 
+ } + + @Test + public void testStartsWith() throws SparkException { + // Edge cases + assertStartsWith("", "", "UTF8_BINARY", true); + assertStartsWith("c", "", "UTF8_BINARY", true); + assertStartsWith("", "c", "UTF8_BINARY", false); + assertStartsWith("", "", "UNICODE", true); + assertStartsWith("c", "", "UNICODE", true); + assertStartsWith("", "c", "UNICODE", false); + assertStartsWith("", "", "UTF8_BINARY_LCASE", true); + assertStartsWith("c", "", "UTF8_BINARY_LCASE", true); + assertStartsWith("", "c", "UTF8_BINARY_LCASE", false); + assertStartsWith("", "", "UNICODE_CI", true); + assertStartsWith("c", "", "UNICODE_CI", true); + assertStartsWith("", "c", "UNICODE_CI", false); + // Basic tests + assertStartsWith("abcde", "abc", "UTF8_BINARY", true); + assertStartsWith("abcde", "abd", "UTF8_BINARY", false); + assertStartsWith("abcde", "fgh", "UTF8_BINARY", false); + assertStartsWith("abcde", "abcde", "UNICODE", true); + assertStartsWith("abcde", "aBcDe", "UNICODE", false); + assertStartsWith("abcde", "fghij", "UNICODE", false); + assertStartsWith("abcde", "A", "UTF8_BINARY_LCASE", true); + assertStartsWith("abcde", "AbCdE", "UTF8_BINARY_LCASE", true); + assertStartsWith("abcde", "X", "UTF8_BINARY_LCASE", false); + assertStartsWith("abcde", "a", "UNICODE_CI", true); + assertStartsWith("abcde", "aBC", "UNICODE_CI", true); + assertStartsWith("abcde", "123", "UNICODE_CI", false); + // Case variation + assertStartsWith("aBcDe", "abc", "UTF8_BINARY", false); + assertStartsWith("aBcDe", "aBc", "UTF8_BINARY", true); + assertStartsWith("aBcDe", "abcde", "UNICODE", false); + assertStartsWith("aBcDe", "aBcDe", "UNICODE", true); + assertStartsWith("aBcDe", "abc", "UTF8_BINARY_LCASE", true); + assertStartsWith("aBcDe", "ABC", "UTF8_BINARY_LCASE", true); + assertStartsWith("aBcDe", "abcde", "UNICODE_CI", true); + assertStartsWith("aBcDe", "AbCdE", "UNICODE_CI", true); + // Accent variation + assertStartsWith("aBcDe", "abć", "UTF8_BINARY", false); + assertStartsWith("aBcDe", 
"aBć", "UTF8_BINARY", false); + assertStartsWith("aBcDe", "abćde", "UNICODE", false); + assertStartsWith("aBcDe", "aBćDe", "UNICODE", false); + assertStartsWith("aBcDe", "abć", "UTF8_BINARY_LCASE", false); + assertStartsWith("aBcDe", "ABĆ", "UTF8_BINARY_LCASE", false); + assertStartsWith("aBcDe", "abćde", "UNICODE_CI", false); + assertStartsWith("aBcDe", "AbĆdE", "UNICODE_CI", false); + // Variable byte length characters + assertStartsWith("ab世De", "ab世", "UTF8_BINARY", true); + assertStartsWith("ab世De", "aB世", "UTF8_BINARY", false); + assertStartsWith("äbćδe", "äbć", "UTF8_BINARY", true); + assertStartsWith("äbćδe", "äBc", "UTF8_BINARY", false); + assertStartsWith("ab世De", "ab世De", "UNICODE", true); + assertStartsWith("ab世De", "AB世dE", "UNICODE", false); + assertStartsWith("äbćδe", "äbćδe", "UNICODE", true); + assertStartsWith("äbćδe", "ÄBcΔÉ", "UNICODE", false); + assertStartsWith("ab世De", "ab世", "UTF8_BINARY_LCASE", true); + assertStartsWith("ab世De", "aB世", "UTF8_BINARY_LCASE", true); + assertStartsWith("äbćδe", "äbć", "UTF8_BINARY_LCASE", true); + assertStartsWith("äbćδe", "äBc", "UTF8_BINARY_LCASE", false); + assertStartsWith("ab世De", "ab世De", "UNICODE_CI", true); + assertStartsWith("ab世De", "AB世dE", "UNICODE_CI", true); + assertStartsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true); + assertStartsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); + } + + private void assertEndsWith(String pattern, String suffix, String collationName, boolean value) + throws SparkException { + UTF8String l = UTF8String.fromString(pattern); + UTF8String r = UTF8String.fromString(suffix); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(CollationSupport.EndsWith.exec(l, r, collationId), value); + } + + @Test + public void testEndsWith() throws SparkException { + // Edge cases + assertEndsWith("", "", "UTF8_BINARY", true); + assertEndsWith("c", "", "UTF8_BINARY", true); + assertEndsWith("", "c", "UTF8_BINARY", false); + assertEndsWith("", "", "UNICODE", 
true); + assertEndsWith("c", "", "UNICODE", true); + assertEndsWith("", "c", "UNICODE", false); + assertEndsWith("", "", "UTF8_BINARY_LCASE", true); + assertEndsWith("c", "", "UTF8_BINARY_LCASE", true); + assertEndsWith("", "c", "UTF8_BINARY_LCASE", false); + assertEndsWith("", "", "UNICODE_CI", true); + assertEndsWith("c", "", "UNICODE_CI", true); + assertEndsWith("", "c", "UNICODE_CI", false); + // Basic tests + assertEndsWith("abcde", "cde", "UTF8_BINARY", true); + assertEndsWith("abcde", "bde", "UTF8_BINARY", false); + assertEndsWith("abcde", "fgh", "UTF8_BINARY", false); + assertEndsWith("abcde", "abcde", "UNICODE", true); + assertEndsWith("abcde", "aBcDe", "UNICODE", false); + assertEndsWith("abcde", "fghij", "UNICODE", false); + assertEndsWith("abcde", "E", "UTF8_BINARY_LCASE", true); + assertEndsWith("abcde", "AbCdE", "UTF8_BINARY_LCASE", true); + assertEndsWith("abcde", "X", "UTF8_BINARY_LCASE", false); + assertEndsWith("abcde", "e", "UNICODE_CI", true); + assertEndsWith("abcde", "CDe", "UNICODE_CI", true); + assertEndsWith("abcde", "123", "UNICODE_CI", false); + // Case variation + assertEndsWith("aBcDe", "cde", "UTF8_BINARY", false); + assertEndsWith("aBcDe", "cDe", "UTF8_BINARY", true); + assertEndsWith("aBcDe", "abcde", "UNICODE", false); + assertEndsWith("aBcDe", "aBcDe", "UNICODE", true); + assertEndsWith("aBcDe", "cde", "UTF8_BINARY_LCASE", true); + assertEndsWith("aBcDe", "CDE", "UTF8_BINARY_LCASE", true); + assertEndsWith("aBcDe", "abcde", "UNICODE_CI", true); + assertEndsWith("aBcDe", "AbCdE", "UNICODE_CI", true); + // Accent variation + assertEndsWith("aBcDe", "ćde", "UTF8_BINARY", false); + assertEndsWith("aBcDe", "ćDe", "UTF8_BINARY", false); + assertEndsWith("aBcDe", "abćde", "UNICODE", false); + assertEndsWith("aBcDe", "aBćDe", "UNICODE", false); + assertEndsWith("aBcDe", "ćde", "UTF8_BINARY_LCASE", false); + assertEndsWith("aBcDe", "ĆDE", "UTF8_BINARY_LCASE", false); + assertEndsWith("aBcDe", "abćde", "UNICODE_CI", false); + 
assertEndsWith("aBcDe", "AbĆdE", "UNICODE_CI", false); + // Variable byte length characters + assertEndsWith("ab世De", "世De", "UTF8_BINARY", true); + assertEndsWith("ab世De", "世dE", "UTF8_BINARY", false); + assertEndsWith("äbćδe", "ćδe", "UTF8_BINARY", true); + assertEndsWith("äbćδe", "cΔé", "UTF8_BINARY", false); + assertEndsWith("ab世De", "ab世De", "UNICODE", true); + assertEndsWith("ab世De", "AB世dE", "UNICODE", false); + assertEndsWith("äbćδe", "äbćδe", "UNICODE", true); + assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE", false); + assertEndsWith("ab世De", "世De", "UTF8_BINARY_LCASE", true); + assertEndsWith("ab世De", "世dE", "UTF8_BINARY_LCASE", true); + assertEndsWith("äbćδe", "ćδe", "UTF8_BINARY_LCASE", true); + assertEndsWith("äbćδe", "cδE", "UTF8_BINARY_LCASE", false); + assertEndsWith("ab世De", "ab世De", "UNICODE_CI", true); + assertEndsWith("ab世De", "AB世dE", "UNICODE_CI", true); + assertEndsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true); + assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); + } + + // TODO: Test more collation-aware string expressions. + + /** + * Collation-aware regexp expressions. + */ + + // TODO: Test more collation-aware regexp expressions. + + /** + * Other collation-aware expressions. + */ + + // TODO: Test other collation-aware expressions. + +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringWithCollationSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringWithCollationSuite.java deleted file mode 100644 index b60da7b945a4..000000000000 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringWithCollationSuite.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.unsafe.types; - -import org.apache.spark.SparkException; -import org.apache.spark.sql.catalyst.util.CollationFactory; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; - - -public class UTF8StringWithCollationSuite { - - private void assertStartsWith(String pattern, String prefix, String collationName, boolean value) - throws SparkException { - assertEquals(UTF8String.fromString(pattern).startsWith(UTF8String.fromString(prefix), - CollationFactory.collationNameToId(collationName)), value); - } - - private void assertEndsWith(String pattern, String suffix, String collationName, boolean value) - throws SparkException { - assertEquals(UTF8String.fromString(pattern).endsWith(UTF8String.fromString(suffix), - CollationFactory.collationNameToId(collationName)), value); - } - - @Test - public void startsWithTest() throws SparkException { - assertStartsWith("", "", "UTF8_BINARY", true); - assertStartsWith("c", "", "UTF8_BINARY", true); - assertStartsWith("", "c", "UTF8_BINARY", false); - assertStartsWith("abcde", "a", "UTF8_BINARY", true); - assertStartsWith("abcde", "A", "UTF8_BINARY", false); - assertStartsWith("abcde", "bcd", "UTF8_BINARY", false); - assertStartsWith("abcde", "BCD", "UTF8_BINARY", false); - assertStartsWith("", "", "UNICODE", true); - assertStartsWith("c", "", "UNICODE", true); - assertStartsWith("", "c", 
"UNICODE", false); - assertStartsWith("abcde", "a", "UNICODE", true); - assertStartsWith("abcde", "A", "UNICODE", false); - assertStartsWith("abcde", "bcd", "UNICODE", false); - assertStartsWith("abcde", "BCD", "UNICODE", false); - assertStartsWith("", "", "UTF8_BINARY_LCASE", true); - assertStartsWith("c", "", "UTF8_BINARY_LCASE", true); - assertStartsWith("", "c", "UTF8_BINARY_LCASE", false); - assertStartsWith("abcde", "a", "UTF8_BINARY_LCASE", true); - assertStartsWith("abcde", "A", "UTF8_BINARY_LCASE", true); - assertStartsWith("abcde", "abc", "UTF8_BINARY_LCASE", true); - assertStartsWith("abcde", "BCD", "UTF8_BINARY_LCASE", false); - assertStartsWith("", "", "UNICODE_CI", true); - assertStartsWith("c", "", "UNICODE_CI", true); - assertStartsWith("", "c", "UNICODE_CI", false); - assertStartsWith("abcde", "a", "UNICODE_CI", true); - assertStartsWith("abcde", "A", "UNICODE_CI", true); - assertStartsWith("abcde", "abc", "UNICODE_CI", true); - assertStartsWith("abcde", "BCD", "UNICODE_CI", false); - } - - @Test - public void endsWithTest() throws SparkException { - assertEndsWith("", "", "UTF8_BINARY", true); - assertEndsWith("c", "", "UTF8_BINARY", true); - assertEndsWith("", "c", "UTF8_BINARY", false); - assertEndsWith("abcde", "e", "UTF8_BINARY", true); - assertEndsWith("abcde", "E", "UTF8_BINARY", false); - assertEndsWith("abcde", "bcd", "UTF8_BINARY", false); - assertEndsWith("abcde", "BCD", "UTF8_BINARY", false); - assertEndsWith("", "", "UNICODE", true); - assertEndsWith("c", "", "UNICODE", true); - assertEndsWith("", "c", "UNICODE", false); - assertEndsWith("abcde", "e", "UNICODE", true); - assertEndsWith("abcde", "E", "UNICODE", false); - assertEndsWith("abcde", "bcd", "UNICODE", false); - assertEndsWith("abcde", "BCD", "UNICODE", false); - assertEndsWith("", "", "UTF8_BINARY_LCASE", true); - assertEndsWith("c", "", "UTF8_BINARY_LCASE", true); - assertEndsWith("", "c", "UTF8_BINARY_LCASE", false); - assertEndsWith("abcde", "e", "UTF8_BINARY_LCASE", 
true); - assertEndsWith("abcde", "E", "UTF8_BINARY_LCASE", true); - assertEndsWith("abcde", "cde", "UTF8_BINARY_LCASE", true); - assertEndsWith("abcde", "BCD", "UTF8_BINARY_LCASE", false); - assertEndsWith("", "", "UNICODE_CI", true); - assertEndsWith("c", "", "UNICODE_CI", true); - assertEndsWith("", "c", "UNICODE_CI", false); - assertEndsWith("abcde", "e", "UNICODE_CI", true); - assertEndsWith("abcde", "E", "UNICODE_CI", true); - assertEndsWith("abcde", "cde", "UNICODE_CI", true); - assertEndsWith("abcde", "BCD", "UNICODE_CI", false); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 01f22720dd12..5aa766a60c10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.encoders.HashableWeakReference import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData, SQLOrderingUtil, UnsafeRowUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -1531,6 +1531,7 @@ object CodeGenerator extends Logging { classOf[TaskKilledException].getName, classOf[InputMetrics].getName, classOf[CollationFactory].getName, + classOf[CollationSupport].getName, QueryExecutionErrors.getClass.getName.stripSuffix("$") ) evaluator.setExtendedClass(classOf[GeneratedClass]) diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index cf6c9d4f1d94..9c862581bfe4 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeAnyCollation @@ -591,18 +591,11 @@ object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderB case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - l.contains(r) - } else { - l.contains(r, collationId) - } + CollationSupport.Contains.exec(l, r, collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2)") - } else { - defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2, $collationId)") - } + defineCodeGen(ctx, ev, (c1, c2) => + CollationSupport.Contains.genCode(c1, c2, collationId)) } override protected def 
withNewChildrenInternal( newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) @@ -638,19 +631,12 @@ object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilde case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - l.startsWith(r) - } else { - l.startsWith(r, collationId) - } + CollationSupport.StartsWith.exec(l, r, collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2)") - } else { - defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2, $collationId)") - } + defineCodeGen(ctx, ev, (c1, c2) => + CollationSupport.StartsWith.genCode(c1, c2, collationId)) } override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) @@ -686,19 +672,12 @@ object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderB case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - l.endsWith(r) - } else { - l.endsWith(r, collationId) - } + CollationSupport.EndsWith.exec(l, r, collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2)") - } else { - defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2, $collationId)") - } + defineCodeGen(ctx, ev, (c1, c2) => + CollationSupport.EndsWith.genCode(c1, c2, collationId)) } override protected def withNewChildrenInternal( 
newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index c547068a03c3..0876425847bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -20,420 +20,310 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, StringType} class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper { - case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) - case class CollationTestFail[R](s1: String, s2: String, collation: String) - - test("Support Like string expression with Collation") { - def prepareLike( - input: String, - regExp: String, - collation: String): Expression = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - Like(inputExpr, regExpExpr, '\\') - } + test("Support Like string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABC", "%B%", "UTF8_BINARY", true) - ) - checks.foreach(ct => - checkEvaluation(prepareLike(ct.s1, ct.s2, ct.collation), ct.expectedResult)) + case class LikeTestCase[R](l: 
String, r: String, c: String, result: R) + val testCases = Seq( + LikeTestCase("ABC", "%B%", "UTF8_BINARY", true) + ) + testCases.foreach(t => { + val query = s"SELECT like(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), - CollationTestFail("ABC", "%B%", "UNICODE"), - CollationTestFail("ABC", "%b%", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareLike(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class LikeTestFail(l: String, r: String, c: String) + val failCases = Seq( + LikeTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), + LikeTestFail("ABC", "%B%", "UNICODE"), + LikeTestFail("ABC", "%b%", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT like(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support ILike string expression with Collation") { - def prepareILike( - input: String, - regExp: String, - collation: String): Expression = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - ILike(inputExpr, regExpExpr, '\\').replacement - } - + test("Support ILike 
string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY", true) - ) - checks.foreach(ct => - checkEvaluation(prepareILike(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class ILikeTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true) + ) + testCases.foreach(t => { + val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), - CollationTestFail("ABC", "%b%", "UNICODE"), - CollationTestFail("ABC", "%b%", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareILike(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""lower('${ct.s1}' collate ${ct.collation})"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class ILikeTestFail(l: String, r: String, c: String) + val failCases = Seq( + ILikeTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), + ILikeTestFail("ABC", "%b%", "UNICODE"), + ILikeTestFail("ABC", "%b%", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RLike string expression with Collation") { - def prepareRLike( - input: String, - regExp: String, - collation: 
String): Expression = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RLike(inputExpr, regExpExpr) - } + test("Support RLike string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABC", ".B.", "UTF8_BINARY", true) - ) - checks.foreach(ct => - checkEvaluation(prepareRLike(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RLikeTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true) + ) + testCases.foreach(t => { + val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABC", ".b.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABC", ".B.", "UNICODE"), - CollationTestFail("ABC", ".b.", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareRLike(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RLikeTestFail(l: String, r: String, c: String) + val failCases = Seq( + RLikeTestFail("ABC", ".b.", "UTF8_BINARY_LCASE"), + RLikeTestFail("ABC", ".B.", "UNICODE"), + RLikeTestFail("ABC", ".b.", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + 
assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support StringSplit string expression with Collation") { - def prepareStringSplit( - input: String, - splitBy: String, - collation: String): Expression = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val splitByExpr = Literal.create(splitBy, StringType(collationId)) - StringSplit(inputExpr, splitByExpr, Literal(-1)) - } - + test("Support StringSplit string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")) - ) - checks.foreach(ct => - checkEvaluation(prepareStringSplit(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class StringSplitTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")) + ) + testCases.foreach(t => { + val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"), - CollationTestFail("ABC", "[B]", "UNICODE"), - CollationTestFail("ABC", "[b]", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareStringSplit(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class StringSplitTestFail(l: String, r: 
String, c: String) + val failCases = Seq( + StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"), + StringSplitTestFail("ABC", "[B]", "UNICODE"), + StringSplitTestFail("ABC", "[b]", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RegExpReplace string expression with Collation") { - def prepareRegExpReplace( - input: String, - regExp: String, - collation: String): RegExpReplace = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RegExpReplace(inputExpr, regExpExpr, Literal.create("FFF", StringType(collationId))) - } - + test("Support RegExpReplace string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") - ) - checks.foreach(ct => - checkEvaluation(prepareRegExpReplace(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RegExpReplaceTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") + ) + testCases.foreach(t => { + val query = + s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 'FFF')" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABCDE", ".C.", "UNICODE"), - CollationTestFail("ABCDE", ".c.", "UNICODE_CI") 
- ) - fails.foreach(ct => - assert(prepareRegExpReplace(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RegExpReplaceTestFail(l: String, r: String, c: String) + val failCases = Seq( + RegExpReplaceTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + RegExpReplaceTestFail("ABCDE", ".C.", "UNICODE"), + RegExpReplaceTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = + s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 'FFF')" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RegExpExtract string expression with Collation") { - def prepareRegExpExtract( - input: String, - regExp: String, - collation: String): RegExpExtract = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RegExpExtract(inputExpr, regExpExpr, Literal(0)) - } - + test("Support RegExpExtract string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") - ) - checks.foreach(ct => - checkEvaluation(prepareRegExpExtract(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RegExpExtractTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + testCases.foreach(t => { + val query = + s"SELECT regexp_extract(collate('${t.l}', 
'${t.c}'), collate('${t.r}', '${t.c}'), 0)" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABCDE", ".C.", "UNICODE"), - CollationTestFail("ABCDE", ".c.", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareRegExpExtract(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RegExpExtractTestFail(l: String, r: String, c: String) + val failCases = Seq( + RegExpExtractTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + RegExpExtractTestFail("ABCDE", ".C.", "UNICODE"), + RegExpExtractTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = + s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RegExpExtractAll string expression with Collation") { - def prepareRegExpExtractAll( - input: String, - regExp: String, - collation: String): RegExpExtractAll = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RegExpExtractAll(inputExpr, regExpExpr, Literal(0)) - } - + test("Support RegExpExtractAll string expression with collation") { // Supported collations - val 
checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")) - ) - checks.foreach(ct => - checkEvaluation(prepareRegExpExtractAll(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RegExpExtractAllTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")) + ) + testCases.foreach(t => { + val query = + s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABCDE", ".C.", "UNICODE"), - CollationTestFail("ABCDE", ".c.", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareRegExpExtractAll(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RegExpExtractAllTestFail(l: String, r: String, c: String) + val failCases = Seq( + RegExpExtractAllTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + RegExpExtractAllTestFail("ABCDE", ".C.", "UNICODE"), + RegExpExtractAllTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = + s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RegExpCount string 
expression with Collation") { - def prepareRegExpCount( - input: String, - regExp: String, - collation: String): Expression = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RegExpCount(inputExpr, regExpExpr).replacement - } - + test("Support RegExpCount string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) - ) - checks.foreach(ct => - checkEvaluation(prepareRegExpCount(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RegExpCountTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + ) + testCases.foreach(t => { + val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABCDE", ".C.", "UNICODE"), - CollationTestFail("ABCDE", ".c.", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareRegExpCount(ct.s1, ct.s2, ct.collation).asInstanceOf[Size].child - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RegExpCountTestFail(l: String, r: String, c: String) + val failCases = Seq( + RegExpCountTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + RegExpCountTestFail("ABCDE", ".C.", "UNICODE"), + RegExpCountTestFail("ABCDE", ".c.", 
"UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RegExpSubStr string expression with Collation") { - def prepareRegExpSubStr( - input: String, - regExp: String, - collation: String): Expression = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RegExpSubStr(inputExpr, regExpExpr).replacement.asInstanceOf[NullIf].left - } - + test("Support RegExpSubStr string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") - ) - checks.foreach(ct => - checkEvaluation(prepareRegExpSubStr(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RegExpSubStrTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + testCases.foreach(t => { + val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // TODO: Implicit casting (not currently supported) + }) // Unsupported collations - val fails = Seq( - CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABCDE", ".C.", "UNICODE"), - CollationTestFail("ABCDE", ".c.", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareRegExpSubStr(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - 
"paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RegExpSubStrTestFail(l: String, r: String, c: String) + val failCases = Seq( + RegExpSubStrTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + RegExpSubStrTestFail("ABCDE", ".C.", "UNICODE"), + RegExpSubStrTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } - test("Support RegExpInStr string expression with Collation") { - def prepareRegExpInStr( - input: String, - regExp: String, - collation: String): RegExpInStr = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExpr = Literal.create(input, StringType(collationId)) - val regExpExpr = Literal.create(regExp, StringType(collationId)) - RegExpInStr(inputExpr, regExpExpr, Literal(0)) - } - + test("Support RegExpInStr string expression with collation") { // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) - ) - checks.foreach(ct => - checkEvaluation(prepareRegExpInStr(ct.s1, ct.s2, ct.collation), ct.expectedResult) - ) + case class RegExpInStrTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) + ) + testCases.foreach(t => { + val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // TODO: Implicit casting (not currently supported) + }) // Unsupported 
collations - val fails = Seq( - CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - CollationTestFail("ABCDE", ".C.", "UNICODE"), - CollationTestFail("ABCDE", ".c.", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareRegExpInStr(ct.s1, ct.s2, ct.collation) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) - ) + case class RegExpInStrTestFail(l: String, r: String, c: String) + val failCases = Seq( + RegExpInStrTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + RegExpInStrTestFail("ABCDE", ".C.", "UNICODE"), + RegExpInStrTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // TODO: Collation mismatch (not currently supported) } + } class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED, true) + + // TODO: If needed, add more tests for other regexp expressions (with ANSI mode enabled) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index c26f3ae02255..97dea6697541 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,80 +20,157 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat} -import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{BooleanType, StringType} class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper { - case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) - case class CollationTestFail[R](s1: String, s2: String, collation: String) - - - test("Support ConcatWs string expression with Collation") { - def prepareConcatWs( - sep: String, - collation: String, - inputs: Any*): ConcatWs = { - val collationId = CollationFactory.collationNameToId(collation) - val inputExprs = inputs.map(s => Literal.create(s, StringType(collationId))) - val sepExpr = Literal.create(sep, StringType(collationId)) - ConcatWs(sepExpr +: inputExprs) - } - // Supported Collations - val checks = Seq( - CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL") + test("Support ConcatWs string expression with collation") { + // Supported collations + case class ConcatWsTestCase[R](s: String, a: Array[String], c: String, result: R) + val testCases = Seq( + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL") ) - checks.foreach(ct => - checkEvaluation(prepareConcatWs(" ", ct.collation, ct.s1, ct.s2), ct.expectedResult) + testCases.foreach(t => { + val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ") + var query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arrCollated)" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + 
assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // Implicit casting + val arr = t.a.map(s => s"'$s'").mkString(", ") + query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arr)" + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + query = s"SELECT concat_ws('${t.s}', $arrCollated)" + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + }) + // Unsupported collations + case class ConcatWsTestFail(s: String, a: Array[String], c: String) + val failCases = Seq( + ConcatWsTestFail(" ", Array("ABC", "%b%"), "UTF8_BINARY_LCASE"), + ConcatWsTestFail(" ", Array("ABC", "%B%"), "UNICODE"), + ConcatWsTestFail(" ", Array("ABC", "%b%"), "UNICODE_CI") ) + failCases.foreach(t => { + val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ") + val query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arrCollated)" + val unsupportedCollation = intercept[AnalysisException] { sql(query) } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT concat_ws(' ',collate('Spark', 'UTF8_BINARY_LCASE'),collate('SQL', 'UNICODE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } - // Unsupported Collations - val fails = Seq( - CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), - CollationTestFail("ABC", "%B%", "UNICODE"), - CollationTestFail("ABC", "%b%", "UNICODE_CI") - ) - fails.foreach(ct => - assert(prepareConcatWs(" ", ct.collation, ct.s1, ct.s2) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "first", - "requiredType" -> """"STRING"""", - "inputSql" -> s""""' ' collate ${ct.collation}"""", - "inputType" -> s""""STRING COLLATE ${ct.collation}"""" - ) - ) - ) 
+ test("Support Contains string expression with collation") { + // Supported collations + case class ContainsTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + ContainsTestCase("", "", "UTF8_BINARY", true), + ContainsTestCase("abcde", "C", "UNICODE", false), + ContainsTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false), + ContainsTestCase("abcde", "BCD", "UNICODE_CI", true) ) + testCases.foreach(t => { + val query = s"SELECT contains(collate('${t.l}','${t.c}'),collate('${t.r}','${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + // Implicit casting + checkAnswer(sql(s"SELECT contains(collate('${t.l}','${t.c}'),'${t.r}')"), Row(t.result)) + checkAnswer(sql(s"SELECT contains('${t.l}',collate('${t.r}','${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT contains(collate('abcde','UTF8_BINARY_LCASE'),collate('C','UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("REPEAT check output type on explicitly collated string") { - def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { - val s = Literal.create(input, StringType(collationId)) + test("Support StartsWith string expression with collation") { + // Supported collations + case class StartsWithTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + StartsWithTestCase("", "", "UTF8_BINARY", true), + StartsWithTestCase("abcde", "A", "UNICODE", false), + StartsWithTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false), + StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true) + ) + testCases.foreach(t => { + val query = s"SELECT startswith(collate('${t.l}','${t.c}'),collate('${t.r}','${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + 
assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + // Implicit casting + checkAnswer(sql(s"SELECT startswith(collate('${t.l}', '${t.c}'),'${t.r}')"), Row(t.result)) + checkAnswer(sql(s"SELECT startswith('${t.l}', collate('${t.r}', '${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT startswith(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } - checkEvaluation(Collation(StringRepeat(s, Literal.create(n))).replacement, expected) + test("Support EndsWith string expression with collation") { + // Supported collations + case class EndsWithTestCase[R](l: String, r: String, c: String, result: R) + val testCases = Seq( + EndsWithTestCase("", "", "UTF8_BINARY", true), + EndsWithTestCase("abcde", "E", "UNICODE", false), + EndsWithTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false), + EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true) + ) + testCases.foreach(t => { + val query = s"SELECT endswith(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + // Implicit casting + checkAnswer(sql(s"SELECT endswith(collate('${t.l}', '${t.c}'),'${t.r}')"), Row(t.result)) + checkAnswer(sql(s"SELECT endswith('${t.l}', collate('${t.r}', '${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT endswith(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))") } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } - testRepeat("UTF8_BINARY", 0, "abc", 2) - testRepeat("UTF8_BINARY_LCASE", 1, "abc", 2) - testRepeat("UNICODE", 2, "abc", 2) - testRepeat("UNICODE_CI", 3, "abc", 2) + test("Support StringRepeat string expression with collation") { + // 
Supported collations + case class StringRepeatTestCase[R](s: String, n: Int, c: String, result: R) + val testCases = Seq( + StringRepeatTestCase("", 1, "UTF8_BINARY", ""), + StringRepeatTestCase("a", 0, "UNICODE", ""), + StringRepeatTestCase("XY", 3, "UTF8_BINARY_LCASE", "XYXYXY"), + StringRepeatTestCase("123", 2, "UNICODE_CI", "123123") + ) + testCases.foreach(t => { + val query = s"SELECT repeat(collate('${t.s}', '${t.c}'), ${t.n})" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + }) } // TODO: Add more tests for other string expressions } -class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite { +class CollationStringExpressionsANSISuite extends CollationStringExpressionsSuite { override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED, true) + + // TODO: If needed, add more tests for other string expressions (with ANSI mode enabled) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index c0322387c804..c4ddd25c99b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -271,90 +271,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) } - case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R) - - test("Support contains string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("", "", "UTF8_BINARY", true), - CollationTestCase("c", "", "UTF8_BINARY", true), - CollationTestCase("", "c", "UTF8_BINARY", false), - CollationTestCase("abcde", "c", "UTF8_BINARY", true), - CollationTestCase("abcde", "C", "UTF8_BINARY", false), - CollationTestCase("abcde", "bcd", "UTF8_BINARY", true), - 
CollationTestCase("abcde", "BCD", "UTF8_BINARY", false), - CollationTestCase("abcde", "fgh", "UTF8_BINARY", false), - CollationTestCase("abcde", "FGH", "UTF8_BINARY", false), - CollationTestCase("", "", "UNICODE", true), - CollationTestCase("c", "", "UNICODE", true), - CollationTestCase("", "c", "UNICODE", false), - CollationTestCase("abcde", "c", "UNICODE", true), - CollationTestCase("abcde", "C", "UNICODE", false), - CollationTestCase("abcde", "bcd", "UNICODE", true), - CollationTestCase("abcde", "BCD", "UNICODE", false), - CollationTestCase("abcde", "fgh", "UNICODE", false), - CollationTestCase("abcde", "FGH", "UNICODE", false), - CollationTestCase("", "", "UTF8_BINARY_LCASE", true), - CollationTestCase("c", "", "UTF8_BINARY_LCASE", true), - CollationTestCase("", "c", "UTF8_BINARY_LCASE", false), - CollationTestCase("abcde", "c", "UTF8_BINARY_LCASE", true), - CollationTestCase("abcde", "C", "UTF8_BINARY_LCASE", true), - CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", true), - CollationTestCase("abcde", "BCD", "UTF8_BINARY_LCASE", true), - CollationTestCase("abcde", "fgh", "UTF8_BINARY_LCASE", false), - CollationTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false), - CollationTestCase("", "", "UNICODE_CI", true), - CollationTestCase("c", "", "UNICODE_CI", true), - CollationTestCase("", "c", "UNICODE_CI", false), - CollationTestCase("abcde", "c", "UNICODE_CI", true), - CollationTestCase("abcde", "C", "UNICODE_CI", true), - CollationTestCase("abcde", "bcd", "UNICODE_CI", true), - CollationTestCase("abcde", "BCD", "UNICODE_CI", true), - CollationTestCase("abcde", "fgh", "UNICODE_CI", false), - CollationTestCase("abcde", "FGH", "UNICODE_CI", false) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," + - s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) - }) - } - - test("Support startsWith string expression with Collation") { - // Supported 
collations - val checks = Seq( - CollationTestCase("abcde", "abc", "UTF8_BINARY", true), - CollationTestCase("abcde", "ABC", "UTF8_BINARY", false), - CollationTestCase("abcde", "abc", "UNICODE", true), - CollationTestCase("abcde", "ABC", "UNICODE", false), - CollationTestCase("abcde", "ABC", "UTF8_BINARY_LCASE", true), - CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false), - CollationTestCase("abcde", "ABC", "UNICODE_CI", true), - CollationTestCase("abcde", "bcd", "UNICODE_CI", false) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," + - s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) - }) - } - - test("Support endsWith string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("abcde", "cde", "UTF8_BINARY", true), - CollationTestCase("abcde", "CDE", "UTF8_BINARY", false), - CollationTestCase("abcde", "cde", "UNICODE", true), - CollationTestCase("abcde", "CDE", "UNICODE", false), - CollationTestCase("abcde", "CDE", "UTF8_BINARY_LCASE", true), - CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false), - CollationTestCase("abcde", "CDE", "UNICODE_CI", true), - CollationTestCase("abcde", "bcd", "UNICODE_CI", false) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," + - s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) - }) - } - test("aggregates count respects collation") { Seq( ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org