This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 12a507464f10 [SPARK-47566][SQL] Support SubstringIndex function to work with collated strings
12a507464f10 is described below

commit 12a507464f106d299511d16c2a436cbc0257bc8a
Author: Milan Dankovic <milan.danko...@databricks.com>
AuthorDate: Tue Apr 30 17:19:01 2024 +0800

    [SPARK-47566][SQL] Support SubstringIndex function to work with collated strings
    
    ### What changes were proposed in this pull request?
    Extend built-in string functions to support non-binary, non-lowercase collations for: substring_index.
    
    ### Why are the changes needed?
    Update collation support for built-in string functions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use COLLATE within the arguments of the built-in string function SUBSTRING_INDEX in Spark SQL queries, using non-binary collations such as UNICODE_CI.
    
    ### How was this patch tested?
    Unit tests for `CollationSupport.SubstringIndex` (`CollationSupportSuite.java`) and for queries using SubstringIndex (`CollationStringExpressionsSuite.scala`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    ### To consider:
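    SubstringIndex itself does not check that the string and the delimiter have matching collations. That check is done by implicit casting (`CollationTypeCasts`), which brings both arguments to a single collation and rejects conflicting explicit collations with `COLLATION_MISMATCH`.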
    
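    As a follow-up, the original `public UTF8String subStringIndex(UTF8String delim, int count)` method could be removed. The existing behavior would then come from a collation-aware overload such as `subStringIndex(delim, count, 0)`, where the last argument is the collation id.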
    
    Closes #45725 from miland-db/miland-db/substringIndex-stringLocate.
    
    Authored-by: Milan Dankovic <milan.danko...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationSupport.java  | 169 +++++++++++++++++++++
 .../org/apache/spark/unsafe/types/UTF8String.java  |  28 +++-
 .../spark/unsafe/types/CollationSupportSuite.java  |  83 ++++++++++
 .../sql/catalyst/analysis/CollationTypeCasts.scala |   5 +
 .../catalyst/expressions/stringExpressions.scala   |  15 +-
 .../sql/CollationStringExpressionsSuite.scala      |  31 ++++
 6 files changed, 323 insertions(+), 8 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 0c81b99de916..9778ca31209e 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -28,6 +28,9 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.regex.Pattern;
 
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+import static org.apache.spark.unsafe.Platform.copyMemory;
+
 /**
  * Static entry point for collation-aware expressions (StringExpressions, 
RegexpExpressions, and
  * other expressions that require custom collation support), as well as 
private utility methods for
@@ -441,6 +444,45 @@ public final class CollationSupport {
     }
   }
 
+  /**
+   * Collation-aware SUBSTRING_INDEX(string, delimiter, count).
+   * For count > 0 it returns the part of `string` before the count-th occurrence of
+   * `delimiter`, counted from the left.
+   * For count < 0 it returns the part after the |count|-th occurrence, counted from the right.
+   * Occurrences are matched with the equality semantics of the collation `collationId`.
+   */
+  public static class SubstringIndex {
+    /** Evaluates SUBSTRING_INDEX by dispatching on the collation's equality semantics. */
+    public static UTF8String exec(final UTF8String string, final UTF8String delimiter,
+        final int count, final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(string, delimiter, count);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(string, delimiter, count);
+      } else {
+        return execICU(string, delimiter, count, collationId);
+      }
+    }
+
+    /**
+     * Generates Java source that calls the exec* variant for the collation directly, so the
+     * generated code does not re-dispatch on the collation for every row.
+     *
+     * @param string      Java expression evaluating to the input UTF8String
+     * @param delimiter   Java expression evaluating to the delimiter UTF8String
+     * @param count       Java expression evaluating to the int count. It is spliced in
+     *                    verbatim, so it may be a generated variable name as well as a literal.
+     * @param collationId collation of the expression, fixed at code-generation time
+     * @return the Java call expression
+     */
+    public static String genCode(final String string, final String delimiter,
+        final String count, final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.SubstringIndex.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %s, %d)", string, delimiter, count,
+          collationId);
+      }
+    }
+
+    /**
+     * Variant of {@link #genCode(String, String, String, int)} for a count that is already
+     * known when the code is generated.
+     */
+    public static String genCode(final String string, final String delimiter,
+        final int count, final int collationId) {
+      return genCode(string, delimiter, Integer.toString(count), collationId);
+    }
+
+    /** Binary collations: byte-wise matching via UTF8String.subStringIndex. */
+    public static UTF8String execBinary(final UTF8String string, final UTF8String delimiter,
+        final int count) {
+      return string.subStringIndex(delimiter, count);
+    }
+
+    /** Lowercase-equality collations (e.g. UTF8_BINARY_LCASE). */
+    public static UTF8String execLowercase(final UTF8String string, final UTF8String delimiter,
+        final int count) {
+      return CollationAwareUTF8String.lowercaseSubStringIndex(string, delimiter, count);
+    }
+
+    /** All other collations: matching is delegated to ICU via the collation id. */
+    public static UTF8String execICU(final UTF8String string, final UTF8String delimiter,
+        final int count, final int collationId) {
+      return CollationAwareUTF8String.subStringIndex(string, delimiter, count, collationId);
+    }
+  }
+
   // TODO: Add more collation-aware string expressions.
 
   /**
@@ -639,6 +681,133 @@ public final class CollationSupport {
       return stringSearch.next();
     }
 
+    /**
+     * Returns the byte offset of the first collation-aware match of `pattern` in `target`,
+     * starting the search at the character that begins at (or follows) byte offset `start`.
+     * Returns -1 when there is no match.
+     */
+    private static int find(UTF8String target, UTF8String pattern, int start,
+        int collationId) {
+      assert (pattern.numBytes() > 0);
+
+      StringSearch stringSearch =
+        CollationFactory.getStringSearch(target, pattern, collationId);
+      // ICU StringSearch positions are UTF-16 code-unit indices into the searched text.
+      // UTF8String character positions count code points. The two differ for supplementary
+      // characters (e.g. emoji), so convert explicitly in both directions.
+      // NOTE(review): assumes getStringSearch searches over target.toString() -- confirm.
+      String targetString = target.toString();
+      int numCodePoints = targetString.codePointCount(0, targetString.length());
+      // Clamp so setIndex never receives a position past the end of the text.
+      int startCodePoint = Math.min(target.bytePosToChar(start), numCodePoints);
+      stringSearch.setIndex(targetString.offsetByCodePoints(0, startCodePoint));
+
+      int matchIndex = stringSearch.next();
+      if (matchIndex == StringSearch.DONE) {
+        return -1;
+      }
+      return target.charPosToByte(targetString.codePointCount(0, matchIndex));
+    }
+
+    /**
+     * SUBSTRING_INDEX for collations that are neither binary nor lowercase; delimiter
+     * occurrences are found with ICU StringSearch.
+     *
+     * For count > 0: returns the prefix of `string` before the count-th match from the left.
+     * Each search restarts at the character after the previous match start, so matches may
+     * overlap.
+     *
+     * For count < 0: returns the suffix after the |count|-th match from the right. Matches are
+     * counted among those reported by one left-to-right StringSearch pass.
+     *
+     * Returns `string` itself when there are not enough matches. Returns an empty string when
+     * `delimiter` or `string` is empty, or when count is 0.
+     */
+    private static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter,
+        int count, final int collationId) {
+      if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) {
+        return UTF8String.EMPTY_UTF8;
+      }
+      if (count > 0) {
+        int idx = -1;
+        while (count > 0) {
+          idx = find(string, delimiter, idx + 1, collationId);
+          if (idx >= 0) {
+            count--;
+          } else {
+            // can not find enough delim
+            return string;
+          }
+        }
+        if (idx == 0) {
+          return UTF8String.EMPTY_UTF8;
+        }
+        byte[] bytes = new byte[idx];
+        copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET,
+          idx);
+        return UTF8String.fromBytes(bytes);
+      }
+
+      // count < 0: record the start and length of every match in one pass, then select the
+      // |count|-th match from the end.
+      // Both values are in UTF-16 code units, which is the unit ICU reports.
+      // The length must be that of the selected match: under case-insensitive collations,
+      // equal delimiters can differ in length (the single character U+0130 versus the
+      // two-character i + U+0307).
+      StringSearch stringSearch =
+        CollationFactory.getStringSearch(string, delimiter, collationId);
+      List<Integer> matchStarts = new ArrayList<>();
+      List<Integer> matchLengths = new ArrayList<>();
+      for (int matchStart = stringSearch.next(); matchStart != StringSearch.DONE;
+          matchStart = stringSearch.next()) {
+        matchStarts.add(matchStart);
+        matchLengths.add(stringSearch.getMatchLength());
+      }
+
+      long wanted = -(long) count; // widened so that Integer.MIN_VALUE negates safely
+      if (matchStarts.size() < wanted) {
+        // can not find enough delim
+        return string;
+      }
+      int selected = (int) (matchStarts.size() - wanted);
+      int resultStartUtf16 = matchStarts.get(selected) + matchLengths.get(selected);
+      // Map the UTF-16 index back to a code-point index, the unit UTF8String.substring uses.
+      // NOTE(review): assumes getStringSearch searches over string.toString() -- confirm.
+      String text = string.toString();
+      int resultStart = text.codePointCount(0, Math.min(resultStartUtf16, text.length()));
+      int numChars = string.numChars();
+      if (resultStart >= numChars) {
+        return UTF8String.EMPTY_UTF8;
+      }
+      return string.substring(resultStart, numChars);
+    }
+
+    private static UTF8String lowercaseSubStringIndex(final UTF8String string,
+      final UTF8String delimiter, int count) {
+      if (delimiter.numBytes() == 0 || count == 0) {
+        return UTF8String.EMPTY_UTF8;
+      }
+
+      UTF8String lowercaseString = string.toLowerCase();
+      UTF8String lowercaseDelimiter = delimiter.toLowerCase();
+
+      if (count > 0) {
+        int idx = -1;
+        while (count > 0) {
+          idx = lowercaseString.find(lowercaseDelimiter, idx + 1);
+          if (idx >= 0) {
+            count --;
+          } else {
+            // can not find enough delim
+            return string;
+          }
+        }
+        if (idx == 0) {
+          return UTF8String.EMPTY_UTF8;
+        }
+        byte[] bytes = new byte[idx];
+        copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, 
BYTE_ARRAY_OFFSET, idx);
+        return UTF8String.fromBytes(bytes);
+
+      } else {
+        int idx = string.numBytes() - delimiter.numBytes() + 1;
+        count = -count;
+        while (count > 0) {
+          idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1);
+          if (idx >= 0) {
+            count --;
+          } else {
+            // can not find enough delim
+            return string;
+          }
+        }
+        if (idx + delimiter.numBytes() == string.numBytes()) {
+          return UTF8String.EMPTY_UTF8;
+        }
+        int size = string.numBytes() - delimiter.numBytes() - idx;
+        byte[] bytes = new byte[size];
+        copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + 
delimiter.numBytes(),
+                bytes, BYTE_ARRAY_OFFSET, size);
+        return UTF8String.fromBytes(bytes);
+      }
+    }
+
   }
 
 }
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index ca6198df2bbf..2a5d14580353 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -926,10 +926,34 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     return -1;
   }
 
+  /**
+   * Converts a character (code point) index into a byte offset in this string by stepping over
+   * UTF-8 lead bytes from the start.
+   * A negative index yields -1, so a "not found" result passes through unchanged. The walk
+   * stops as soon as the end of the bytes is reached, so indexes at or beyond the character
+   * count land at the end offset.
+   */
+  public int charPosToByte(int charPos) {
+    if (charPos < 0) {
+      return -1;
+    }
+    int byteOffset = 0;
+    for (int seen = 0; seen < charPos && byteOffset < numBytes; seen++) {
+      byteOffset += numBytesForFirstByte(getByte(byteOffset));
+    }
+    return byteOffset;
+  }
+
+  /**
+   * Converts a byte offset into a character (code point) index, namely the number of
+   * characters that start before `bytePos`.
+   * An offset inside a multi-byte character therefore maps to the index of the next character.
+   * Offsets at or past the end map to the total number of characters.
+   */
+  public int bytePosToChar(int bytePos) {
+    int charCount = 0;
+    for (int offset = 0; offset < numBytes && offset < bytePos;
+        offset += numBytesForFirstByte(getByte(offset))) {
+      charCount++;
+    }
+    return charCount;
+  }
+
   /**
    * Find the `str` from left to right.
    */
-  private int find(UTF8String str, int start) {
+  public int find(UTF8String str, int start) {
     assert (str.numBytes > 0);
     while (start <= numBytes - str.numBytes) {
       if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, 
str.offset, str.numBytes)) {
@@ -943,7 +967,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
   /**
    * Find the `str` from right to left.
    */
-  private int rfind(UTF8String str, int start) {
+  public int rfind(UTF8String str, int start) {
     assert (str.numBytes > 0);
     while (start >= 0) {
       if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, 
str.offset, str.numBytes)) {
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 030c7a7a1e3c..2f05b9ad88c9 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -717,6 +717,89 @@ public class CollationSupportSuite {
     assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); // 12 instead 
of 11
   }
 
+  /**
+   * Asserts that CollationSupport.SubstringIndex.exec(string, delimiter, count), evaluated
+   * under the collation named `collationName`, yields `expected`.
+   */
+  private void assertSubstringIndex(String string, String delimiter, Integer count,
+      String collationName, String expected) throws SparkException {
+    int collationId = CollationFactory.collationNameToId(collationName);
+    UTF8String actual = CollationSupport.SubstringIndex.exec(
+      UTF8String.fromString(string), UTF8String.fromString(delimiter), count, collationId);
+    assertEquals(expected, actual.toString());
+  }
+
+  /**
+   * SUBSTRING_INDEX across the four collation families. Each case is
+   * (string, delimiter, count, collation) -> expected result. count > 0 keeps the text before
+   * the count-th delimiter from the left; count < 0 keeps the text after the |count|-th
+   * delimiter from the right; too few delimiters returns the whole string.
+   */
+  @Test
+  public void testSubstringIndex() throws SparkException {
+    // UTF8_BINARY: exact byte matching; successive matches may overlap ("aaaaaaaaaa" / "aa").
+    assertSubstringIndex("wwwgapachegorg", "g", -3, "UTF8_BINARY", "apachegorg");
+    assertSubstringIndex("www||apache||org", "||", 2, "UTF8_BINARY", "www||apache");
+    assertSubstringIndex("aaaaaaaaaa", "aa", 2, "UTF8_BINARY", "a");
+    // UTF8_BINARY_LCASE: delimiter matched case-insensitively; the result keeps original case.
+    assertSubstringIndex("AaAaAaAaAa", "aa", 2, "UTF8_BINARY_LCASE", "A");
+    assertSubstringIndex("www.apache.org", ".", 3, "UTF8_BINARY_LCASE", "www.apache.org");
+    assertSubstringIndex("wwwXapacheXorg", "x", 2, "UTF8_BINARY_LCASE", "wwwXapache");
+    assertSubstringIndex("wwwxapachexorg", "X", 1, "UTF8_BINARY_LCASE", "www");
+    assertSubstringIndex("www.apache.org", ".", 0, "UTF8_BINARY_LCASE", "");
+    assertSubstringIndex("www.apache.ORG", ".", -3, "UTF8_BINARY_LCASE", "www.apache.ORG");
+    assertSubstringIndex("wwwGapacheGorg", "g", 1, "UTF8_BINARY_LCASE", "www");
+    assertSubstringIndex("wwwGapacheGorg", "g", 3, "UTF8_BINARY_LCASE", "wwwGapacheGor");
+    assertSubstringIndex("gwwwGapacheGorg", "g", 3, "UTF8_BINARY_LCASE", "gwwwGapache");
+    assertSubstringIndex("wwwGapacheGorg", "g", -3, "UTF8_BINARY_LCASE", "apacheGorg");
+    assertSubstringIndex("wwwmapacheMorg", "M", -2, "UTF8_BINARY_LCASE", "apacheMorg");
+    assertSubstringIndex("www.apache.org", ".", -1, "UTF8_BINARY_LCASE", "org");
+    assertSubstringIndex("www.apache.org.", ".", -1, "UTF8_BINARY_LCASE", "");
+    assertSubstringIndex("", ".", -2, "UTF8_BINARY_LCASE", "");
+    assertSubstringIndex("test大千世界X大千世界", "x", -1, "UTF8_BINARY_LCASE", "大千世界");
+    assertSubstringIndex("test大千世界X大千世界", "X", 1, "UTF8_BINARY_LCASE", "test大千世界");
+    assertSubstringIndex("test大千世界大千世界", "千", 2, "UTF8_BINARY_LCASE", "test大千世界大");
+    assertSubstringIndex("www||APACHE||org", "||", 2, "UTF8_BINARY_LCASE", "www||APACHE");
+    assertSubstringIndex("www||APACHE||org", "||", -1, "UTF8_BINARY_LCASE", "org");
+    // UNICODE: ICU collation matching that is still case-sensitive ("y" does not match "Y").
+    assertSubstringIndex("AaAaAaAaAa", "Aa", 2, "UNICODE", "Aa");
+    assertSubstringIndex("wwwYapacheyorg", "y", 3, "UNICODE", "wwwYapacheyorg");
+    assertSubstringIndex("www.apache.org", ".", 2, "UNICODE", "www.apache");
+    assertSubstringIndex("wwwYapacheYorg", "Y", 1, "UNICODE", "www");
+    assertSubstringIndex("wwwYapacheYorg", "y", 1, "UNICODE", "wwwYapacheYorg");
+    assertSubstringIndex("wwwGapacheGorg", "g", 1, "UNICODE", "wwwGapacheGor");
+    assertSubstringIndex("GwwwGapacheGorG", "G", 3, "UNICODE", "GwwwGapache");
+    assertSubstringIndex("wwwGapacheGorG", "G", -3, "UNICODE", "apacheGorG");
+    assertSubstringIndex("www.apache.org", ".", 0, "UNICODE", "");
+    assertSubstringIndex("www.apache.org", ".", -3, "UNICODE", "www.apache.org");
+    assertSubstringIndex("www.apache.org", ".", -2, "UNICODE", "apache.org");
+    assertSubstringIndex("www.apache.org", ".", -1, "UNICODE", "org");
+    assertSubstringIndex("", ".", -2, "UNICODE", "");
+    assertSubstringIndex("test大千世界X大千世界", "X", -1, "UNICODE", "大千世界");
+    assertSubstringIndex("test大千世界X大千世界", "X", 1, "UNICODE", "test大千世界");
+    assertSubstringIndex("大x千世界大千世x界", "x", 1, "UNICODE", "大");
+    assertSubstringIndex("大x千世界大千世x界", "x", -1, "UNICODE", "界");
+    assertSubstringIndex("大x千世界大千世x界", "x", -2, "UNICODE", "千世界大千世x界");
+    assertSubstringIndex("大千世界大千世界", "千", 2, "UNICODE", "大千世界大");
+    assertSubstringIndex("www||apache||org", "||", 2, "UNICODE", "www||apache");
+    // UNICODE_CI: ICU case-insensitive matching.
+    assertSubstringIndex("AaAaAaAaAa", "aa", 2, "UNICODE_CI", "A");
+    assertSubstringIndex("www.apache.org", ".", 3, "UNICODE_CI", "www.apache.org");
+    assertSubstringIndex("wwwXapacheXorg", "x", 2, "UNICODE_CI", "wwwXapache");
+    assertSubstringIndex("wwwxapacheXorg", "X", 1, "UNICODE_CI", "www");
+    assertSubstringIndex("www.apache.org", ".", 0, "UNICODE_CI", "");
+    assertSubstringIndex("wwwGapacheGorg", "G", 3, "UNICODE_CI", "wwwGapacheGor");
+    assertSubstringIndex("gwwwGapacheGorg", "g", 3, "UNICODE_CI", "gwwwGapache");
+    assertSubstringIndex("gwwwGapacheGorg", "g", -3, "UNICODE_CI", "apacheGorg");
+    assertSubstringIndex("www.apache.ORG", ".", -3, "UNICODE_CI", "www.apache.ORG");
+    assertSubstringIndex("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg");
+    assertSubstringIndex("www.apache.org", ".", -1, "UNICODE_CI", "org");
+    assertSubstringIndex("", ".", -2, "UNICODE_CI", "");
+    assertSubstringIndex("test大千世界X大千世界", "X", -1, "UNICODE_CI", "大千世界");
+    assertSubstringIndex("test大千世界X大千世界", "X", 1, "UNICODE_CI", "test大千世界");
+    assertSubstringIndex("test大千世界大千世界", "千", 2, "UNICODE_CI", "test大千世界大");
+    assertSubstringIndex("www||APACHE||org", "||", 2, "UNICODE_CI", "www||APACHE");
+    // Case-insensitive matches of different lengths: U+0130 (one character) is equal to
+    // i + U+0307 (two characters), so the matched text can be shorter or longer than the
+    // delimiter.
+    assertSubstringIndex("abİo12", "i̇o", 1, "UNICODE_CI", "ab");
+    assertSubstringIndex("abİo12", "i̇o", -1, "UNICODE_CI", "12");
+    assertSubstringIndex("abi̇o12", "İo", 1, "UNICODE_CI", "ab");
+    assertSubstringIndex("abi̇o12", "İo", -1, "UNICODE_CI", "12");
+    assertSubstringIndex("ai̇bi̇o12", "İo", 1, "UNICODE_CI", "ai̇b");
+    assertSubstringIndex("ai̇bi̇o12i̇o", "İo", 2, "UNICODE_CI", "ai̇bi̇o12");
+    assertSubstringIndex("ai̇bi̇o12i̇o", "İo", -1, "UNICODE_CI", "");
+    assertSubstringIndex("ai̇bi̇o12i̇o", "İo", -2, "UNICODE_CI", "12i̇o");
+    assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UNICODE_CI", "İo12İoi̇o");
+    assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UNICODE_CI", "İo12İoi̇o");
+    assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UNICODE_CI", "i̇o12i̇oİo");
+    assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo");
+  }
+
   // TODO: Test more collation-aware string expressions.
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index f69218812d36..1130677d5f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -49,6 +49,11 @@ object CollationTypeCasts extends TypeCoercionRule {
       stringLocate.withNewChildren(collateToSingleType(
         Seq(stringLocate.first, stringLocate.second)) :+ stringLocate.third)
 
+    // SubstringIndex: only the string (first) and delimiter (second) children go through
+    // collateToSingleType, which presumably gives them one common collation.
+    // The integer count (third) is re-appended unchanged.
+    case substringIndex: SubstringIndex =>
+      substringIndex.withNewChildren(
+        collateToSingleType(
+          Seq(substringIndex.first, substringIndex.second)) :+ substringIndex.third)
+
     case eltExpr: Elt =>
       eltExpr.withNewChildren(eltExpr.children.head +: 
collateToSingleType(eltExpr.children.tail))
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 2d7f9652986a..b0352046b920 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1406,21 +1406,24 @@ case class StringInstr(str: Expression, substr: Expression)
 case class SubstringIndex(strExpr: Expression, delimExpr: Expression, 
countExpr: Expression)
  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
-  override def dataType: DataType = StringType
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+  // Collation of the input string. It is passed to CollationSupport.SubstringIndex, which
+  // selects the delimiter-matching semantics from it.
+  final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId
+
+  // The result has the input string's type, including its collation.
+  override def dataType: DataType = strExpr.dataType
+  // String and delimiter accept any collation; CollationTypeCasts aligns them beforehand.
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
   override def first: Expression = strExpr
   override def second: Expression = delimExpr
   override def third: Expression = countExpr
   override def prettyName: String = "substring_index"
 
   override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
-    str.asInstanceOf[UTF8String].subStringIndex(
-      delim.asInstanceOf[UTF8String],
-      count.asInstanceOf[Int])
+    // Interpreted path: collation-aware SUBSTRING_INDEX on non-null inputs.
+    val string = str.asInstanceOf[UTF8String]
+    val delimiter = delim.asInstanceOf[UTF8String]
+    CollationSupport.SubstringIndex.exec(
+      string, delimiter, count.asInstanceOf[Int], collationId)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
+    // `count` is the Java expression generated for the count child. It is typically a variable
+    // name such as `value_3`, and only a numeric literal when the child is a constant.
+    // It must therefore be spliced into the generated code, not parsed here:
+    // Integer.parseInt would throw for every non-constant count.
+    // exec(...) selects the collation-specific implementation from collationId.
+    defineCodeGen(ctx, ev, (str, delim, count) =>
+      s"CollationSupport.SubstringIndex.exec($str, $delim, $count, $collationId)")
   }
 
   override protected def withNewChildrenInternal(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index d88c15fb2325..989e418b7477 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -131,6 +131,37 @@ class CollationStringExpressionsSuite
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
 
+  // End-to-end SQL coverage of substring_index with collated arguments: result value, result
+  // collation, implicit casting of an uncollated argument, and rejection of conflicting
+  // explicit collations.
+  test("Support SubstringIndex expression with collation") {
+    case class SubstringIndexTestCase[R](string: String, delimiter: String, count: Integer,
+      c: String, result: R)
+    // One case per collation family; `c` is the collation name applied with collate().
+    val testCases = Seq(
+      SubstringIndexTestCase("wwwgapachegorg", "g", -3, "UTF8_BINARY", 
"apachegorg"),
+      SubstringIndexTestCase("www||apache||org", "||", 2, "UTF8_BINARY", 
"www||apache"),
+      SubstringIndexTestCase("wwwXapacheXorg", "x", 2, "UTF8_BINARY_LCASE", 
"wwwXapache"),
+      SubstringIndexTestCase("aaaaaaaaaa", "aa", 2, "UNICODE", "a"),
+      SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", 
"apacheMorg")
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT substring_index(collate('${t.string}','${t.c}')," +
+        s"collate('${t.delimiter}','${t.c}'),${t.count})"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(
+        StringType(CollationFactory.collationNameToId(t.c))))
+      // Implicit casting
+      // Only one argument is collated; the uncollated one must be cast to that collation,
+      // giving the same result.
+      checkAnswer(sql(s"SELECT 
substring_index(collate('${t.string}','${t.c}')," +
+        s"'${t.delimiter}',${t.count})"), Row(t.result))
+      checkAnswer(sql(s"SELECT 
substring_index('${t.string}',collate('${t.delimiter}','${t.c}')," +
+        s"${t.count})"), Row(t.result))
+    })
+    // Collation mismatch
+    // Different explicit collations on string and delimiter must fail analysis.
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT substring_index(collate('abcde','UTF8_BINARY_LCASE')," +
+        "collate('C','UNICODE_CI'),1)")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
   test("Support StringInStr string expression with collation") {
     case class StringInStrTestCase[R](string: String, substring: String, c: 
String, result: R)
     val testCases = Seq(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to