This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 885e98ecbe64 [SPARK-47412][SQL] Add Collation Support for LPad/RPad
885e98ecbe64 is described below

commit 885e98ecbe64ea01dbf542d46aeac706f2761a05
Author: GideonPotok <g.pot...@gmail.com>
AuthorDate: Tue Apr 23 14:22:39 2024 +0800

    [SPARK-47412][SQL] Add Collation Support for LPad/RPad
    
    Add collation support for LPAD and RPAD
    
    ### What changes were proposed in this pull request?
    
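    Add collation support for the LPAD and RPAD string expressions: `StringLPad` and
    `StringRPad` now declare `StringTypeAnyCollation` for their string and pad inputs,
    and `CollationTypeCasts` casts those two arguments to a single collation
    (the length argument is left unchanged).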
    
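    ### Why are the changes needed?
    
    So that `lpad` and `rpad` can be used with collated strings, with the string and
    pad arguments resolved to one common collation.
    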
    ### Does this PR introduce _any_ user-facing change?
    
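    Yes. `lpad` and `rpad` now accept collated string arguments; passing string and pad
    arguments with two different explicit collations fails with
    `COLLATION_MISMATCH.EXPLICIT`.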
    
    ### How was this patch tested?
    
    Unit tests (added to `CollationStringExpressionsSuite`) and spark-shell
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46041 from GideonPotok/spark_47412_collation_lpad_rpad.
    
    Authored-by: GideonPotok <g.pot...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CollationTypeCasts.scala |  7 +-
 .../catalyst/expressions/stringExpressions.scala   |  6 +-
 .../sql/CollationStringExpressionsSuite.scala      | 74 ++++++++++++++++++++++
 3 files changed, 84 insertions(+), 3 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index cffdd2872224..3affd91dd3b8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -22,7 +22,7 @@ import javax.annotation.Nullable
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, 
haveSameType}
-import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, 
CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, 
Expression, Greatest, If, In, InSubquery, Least, Overlay}
+import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, 
CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, 
Expression, Greatest, If, In, InSubquery, Least, Overlay, StringLPad, 
StringRPad}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
@@ -52,6 +52,11 @@ object CollationTypeCasts extends TypeCoercionRule {
       overlay.withNewChildren(collateToSingleType(Seq(overlay.input, 
overlay.replace))
         ++ Seq(overlay.pos, overlay.len))
 
+    case stringPadExpr @ (_: StringRPad | _: StringLPad) =>
+      val Seq(str, len, pad) = stringPadExpr.children
+      val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad))
+      stringPadExpr.withNewChildren(Seq(newStr, len, newPad))
+
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: 
Greatest | _: Least |
       _: Coalesce | _: BinaryExpression | _: ConcatWs) =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 2b7703ed82b3..cd21a6f5fdc2 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1586,7 +1586,8 @@ case class StringLPad(str: Expression, len: Expression, 
pad: Expression)
   override def third: Expression = pad
 
   override def dataType: DataType = str.dataType
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, 
IntegerType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation)
 
   override def nullSafeEval(string: Any, len: Any, pad: Any): Any = {
     string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], 
pad.asInstanceOf[UTF8String])
@@ -1665,7 +1666,8 @@ case class StringRPad(str: Expression, len: Expression, 
pad: Expression = Litera
   override def third: Expression = pad
 
   override def dataType: DataType = str.dataType
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, 
IntegerType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation)
 
   override def nullSafeEval(string: Any, len: Any, pad: Any): Any = {
     string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], 
pad.asInstanceOf[UTF8String])
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 123d642ed4cd..9c207df95dad 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -534,6 +534,80 @@ class CollationStringExpressionsSuite
     }
   }
 
+  test("Support StringRPad string expressions with collation") {
+    // Supported collations
+    case class StringRPadTestCase[R](s: String, len: Int, pad: String, c: 
String, result: R)
+    val testCases = Seq(
+      StringRPadTestCase("", 5, " ", "UTF8_BINARY", "     "),
+      StringRPadTestCase("abc", 5, " ", "UNICODE", "abc  "),
+      StringRPadTestCase("Hello", 7, "Wörld", "UTF8_BINARY_LCASE", "HelloWö"),
+      StringRPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"),
+      StringRPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"),
+      StringRPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", 
"UTF8_BINARY_LCASE", "ÀÃ"),
+      StringRPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", 
"ĂȦÄäåäáÀÃÂĀĂȦÄäåäáâã"),
+      StringRPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "aȦÄäa1a1")
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT rpad(collate('${t.s}', '${t.c}')," +
+        s" ${t.len}, collate('${t.pad}', '${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+      // Implicit casting
+      checkAnswer(
+        sql(s"SELECT rpad(collate('${t.s}', '${t.c}'), ${t.len}, '${t.pad}')"),
+        Row(t.result))
+      checkAnswer(
+        sql(s"SELECT rpad('${t.s}', ${t.len}, collate('${t.pad}', '${t.c}'))"),
+        Row(t.result))
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT rpad(collate('abcde', 'UNICODE_CI'),1,collate('C', 
'UTF8_BINARY_LCASE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  test("Support StringLPad string expressions with collation") {
+    // Supported collations
+    case class StringLPadTestCase[R](s: String, len: Int, pad: String, c: 
String, result: R)
+    val testCases = Seq(
+      StringLPadTestCase("", 5, " ", "UTF8_BINARY", "     "),
+      StringLPadTestCase("abc", 5, " ", "UNICODE", "  abc"),
+      StringLPadTestCase("Hello", 7, "Wörld", "UTF8_BINARY_LCASE", "WöHello"),
+      StringLPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"),
+      StringLPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"),
+      StringLPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", 
"UTF8_BINARY_LCASE", "ÀÃ"),
+      StringLPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", 
"ÀÃÂĀĂȦÄäåäáâãĂȦÄäåäá"),
+      StringLPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "a1a1aȦÄä")
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT lpad(collate('${t.s}', '${t.c}')," +
+        s" ${t.len}, collate('${t.pad}', '${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+      // Implicit casting
+      checkAnswer(
+        sql(s"SELECT lpad(collate('${t.s}', '${t.c}'), ${t.len}, '${t.pad}')"),
+        Row(t.result))
+      checkAnswer(
+        sql(s"SELECT lpad('${t.s}', ${t.len}, collate('${t.pad}', '${t.c}'))"),
+        Row(t.result))
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT lpad(collate('abcde', 'UNICODE_CI'),1,collate('C', 
'UTF8_BINARY_LCASE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  test("Support StringLPad string expressions with explicit collation on 
second parameter") {
+    val query = "SELECT lpad('abc', collate('5', 'unicode_ci'), ' ')"
+    checkAnswer(sql(query), Row("  abc"))
+    assert(sql(query).schema.fields.head.dataType.sameType(StringType(0)))
+  }
+
   // TODO: Add more tests for other string expressions
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to