uros-db commented on code in PR #46206:
URL: https://github.com/apache/spark/pull/46206#discussion_r1579045485


##########
sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala:
##########
@@ -536,6 +538,134 @@ class CollationStringExpressionsSuite
     }
   }
 
+  test("StringTrim* functions - unit tests for both paths (codegen and eval)") 
{
+    // Without trimString param.
+    checkEvaluation(StringTrim(Literal.create( "  asd  ", StringType(0))), 
"asd")
+    checkEvaluation(StringTrimLeft(Literal.create("  asd  ", StringType(0))), 
"asd  ")
+    checkEvaluation(StringTrimRight(Literal.create("  asd  ", StringType(0))), 
"  asd")
+
+    // With trimString param.
+    checkEvaluation(
+      StringTrim(Literal.create("  asd  ", StringType(0)), Literal.create(" ", 
StringType(0))),
+      "asd")
+    checkEvaluation(
+      StringTrimLeft(Literal.create("  asd  ", StringType(0)), 
Literal.create(" ", StringType(0))),
+      "asd  ")
+    checkEvaluation(
+      StringTrimRight(Literal.create("  asd  ", StringType(0)), 
Literal.create(" ", StringType(0))),
+      "  asd")
+
+    checkEvaluation(
+      StringTrim(Literal.create("xxasdxx", StringType(0)), Literal.create("x", 
StringType(0))),
+      "asd")
+    checkEvaluation(
+      StringTrimLeft(Literal.create("xxasdxx", StringType(0)), 
Literal.create("x", StringType(0))),
+      "asdxx")
+    checkEvaluation(
+      StringTrimRight(Literal.create("xxasdxx", StringType(0)), 
Literal.create("x", StringType(0))),
+      "xxasd")
+  }
+
+  test("StringTrim* functions - E2E tests") {
+    case class StringTrimTestCase(
+      collation: String,
+      trimFunc: String,
+      sourceString: String,
+      hasTrimString: Boolean,
+      trimString: String,
+      expectedResultString: String)
+
+    val testCases = Seq(
+      // Without trimString param.
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "  asd  ", false, null, "asd"),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "  asd  ", false, null, 
"asd"),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "  asd  ", false, null, "asd  
"),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "  asd  ", false, null, "  
asd"),
+
+      // With null trimString param.
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "  asd  ", true, null, null),
+
+      // With " " trimString param.
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "  asd  ", true, " ", "asd"),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "  asd  ", true, " ", "asd"),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "  asd  ", true, " ", "asd  
"),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "  asd  ", true, " ", "  
asd"),
+
+      // Try the same with any other collation.
+      // Without trimString param.
+      StringTrimTestCase("UNICODE_CI", "TRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UNICODE_CI", "TRIM", "  asd  ", false, null, "asd"),
+      StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UNICODE_CI", "BTRIM", "  asd  ", false, null, "asd"),
+      StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UNICODE_CI", "LTRIM", "  asd  ", false, null, "asd  
"),
+      StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", false, null, "asd"),
+      StringTrimTestCase("UNICODE_CI", "RTRIM", "  asd  ", false, null, "  
asd"),
+
+      // With null trimString param.
+      StringTrimTestCase("UNICODE_CI", "TRIM", "asd", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "TRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "BTRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "LTRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", true, null, null),
+      StringTrimTestCase("UNICODE_CI", "RTRIM", "  asd  ", true, null, null),
+
+      // With " " trimString param.
+      StringTrimTestCase("UNICODE_CI", "TRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UNICODE_CI", "TRIM", "  asd  ", true, " ", "asd"),
+      StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UNICODE_CI", "BTRIM", "  asd  ", true, " ", "asd"),
+      StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UNICODE_CI", "LTRIM", "  asd  ", true, " ", "asd  "),
+      StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", true, " ", "asd"),
+      StringTrimTestCase("UNICODE_CI", "RTRIM", "  asd  ", true, " ", "  asd")
+
+      // Other more complex cases can be found in unit tests in 
CollationSupportSuite.java.
+    )
+
+    // scalastyle:off
+
+    testCases.foreach(testCase => {
+      var df: DataFrame = null
+
+      if (testCase.trimFunc.equalsIgnoreCase("BTRIM")) {
+        // BTRIM has arguments in (srcStr, trimStr) order
+        df = sql(s"SELECT ${testCase.trimFunc}(" +
+          s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" +
+          (if (!testCase.hasTrimString) "" else if (testCase.trimString == 
null) ", null" else s", COLLATE('${testCase.trimString}', 
'${testCase.collation}')") +
+          ")")
+      }
+      else {
+        // While other functions have arguments in (trimStr, srcStr) order
+        df = sql(s"SELECT ${testCase.trimFunc}(" +
+          (if (!testCase.hasTrimString) "" else if (testCase.trimString == 
null) "null, " else s"COLLATE('${testCase.trimString}', 
'${testCase.collation}'), ") +
+          s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" +
+          ")")
+      }
+

Review Comment:
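   Since you added `_: StringTrim | _: StringTrimLeft | _: StringTrimRight` to
the last case in CollationTypeCasts, casting should now be enforced for these
trim expressions.
   
   So all that remains is to add tests here confirming that it works (e.g.
trim calls whose source and trim string arguments have different collations).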



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to