This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 068be4b10255 [SPARK-48578][SQL] add UTF8 string validation related 
functions
068be4b10255 is described below

commit 068be4b10255b7ebf03c89963f00814d2f5aaa10
Author: Uros Bojanic <[email protected]>
AuthorDate: Tue Jun 25 21:42:34 2024 +0800

    [SPARK-48578][SQL] add UTF8 string validation related functions
    
    ### What changes were proposed in this pull request?
    Introduced 4 new string expressions in Spark SQL: `IsValidUTF8`, 
`MakeValidUTF8`, `ValidateUTF8`, `TryValidateUTF8`.
    
    ### Why are the changes needed?
    These expressions offer a complete set of user-facing expressions that 
allow for UTF8String validation in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, 4 new string expressions are available.
    
    ### How was this patch tested?
    Unit tests in `UTF8StringSuite` and `CollationSupportSuite` and e2e sql 
tests in `string-functions.sql`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Yes.
    
    Closes #46845 from uros-db/string-validation.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/unsafe/types/CollationSupportSuite.java  |   4 +-
 .../src/main/resources/error/error-conditions.json |   6 +
 .../catalyst/expressions/ExpressionImplUtils.java  |  26 +++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   4 +
 .../catalyst/expressions/stringExpressions.scala   | 187 ++++++++++++++++++++-
 .../spark/sql/errors/QueryExecutionErrors.scala    |   9 +
 .../expressions/ExpressionImplUtilsSuite.scala     |  63 ++++++-
 .../sql-functions/sql-expression-schema.md         |   4 +
 .../analyzer-results/ansi/string-functions.sql.out |  84 +++++++++
 .../analyzer-results/string-functions.sql.out      |  84 +++++++++
 .../sql-tests/inputs/string-functions.sql          |  13 ++
 .../results/ansi/string-functions.sql.out          | 103 ++++++++++++
 .../sql-tests/results/string-functions.sql.out     | 103 ++++++++++++
 .../sql/CollationStringExpressionsSuite.scala      |  91 +++++++++-
 14 files changed, 775 insertions(+), 6 deletions(-)

diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 99f35ef81dc6..436dff1db0e0 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -171,10 +171,10 @@ public class CollationSupportSuite {
     // Surrogate pairs are treated as invalid UTF8 sequences
     assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
       {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 
0x80}),
-      UTF8String.fromString("\ufffd\ufffd"), false);
+      UTF8String.fromString("\uFFFD\uFFFD"), false);
     assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
       {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 
0x80}),
-      UTF8String.fromString("\ufffd\ufffd"), true);
+      UTF8String.fromString("\uFFFD\uFFFD"), true);
   }
 
   /**
diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 975536c076dd..bf251f057af5 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2909,6 +2909,12 @@
     ],
     "sqlState" : "42000"
   },
+  "INVALID_UTF8_STRING" : {
+    "message" : [
+      "Invalid UTF8 byte sequence found in string: <str>."
+    ],
+    "sqlState" : "22029"
+  },
   "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : {
     "message" : [
       "Variable type must be string type but got <varType>."
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index 8fe59cb7fae5..07a9409bc57a 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -112,6 +112,32 @@ public class ExpressionImplUtils {
     return checkSum % 10 == 0;
   }
 
+  /**
+   * Validates that the given string consists only of well-formed UTF-8 byte sequences.
+   *
+   * @param utf8String the input string to check for invalid byte sequences
+   * @return the very same {@code utf8String} instance when it is valid UTF-8
+   * @throws org.apache.spark.SparkIllegalArgumentException with error class INVALID_UTF8_STRING
+   *         if any invalid byte sequence is found
+   */
+  public static UTF8String validateUTF8String(UTF8String utf8String) {
+    if (!utf8String.isValid()) throw QueryExecutionErrors.invalidUTF8StringError(utf8String);
+    return utf8String;
+  }
+
+  /**
+   * Lenient counterpart of {@link #validateUTF8String(UTF8String)} that returns null instead.
+   *
+   * @param utf8String the input string to check for invalid byte sequences
+   * @return the very same {@code utf8String} instance when it is valid UTF-8,
+   *         or {@code null} if any invalid byte sequence is found
+   */
+  public static UTF8String tryValidateUTF8String(UTF8String utf8String) {
+    // An invalid string is reported as null rather than raising an error.
+    final boolean valid = utf8String.isValid();
+    return valid ? utf8String : null;
+  }
+
   public static byte[] aesEncrypt(byte[] input,
                                   byte[] key,
                                   UTF8String mode,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 20da1c030b53..cd113cc1b34c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -601,6 +601,10 @@ object FunctionRegistry {
     expression[RegExpCount]("regexp_count"),
     expression[RegExpSubStr]("regexp_substr"),
     expression[RegExpInStr]("regexp_instr"),
+    expression[IsValidUTF8]("is_valid_utf8"),
+    expression[MakeValidUTF8]("make_valid_utf8"),
+    expression[ValidateUTF8]("validate_utf8"),
+    expression[TryValidateUTF8]("try_validate_utf8"),
 
     // url functions
     expression[UrlEncode]("url_encode"),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 055ef074d621..476b18fac310 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -33,8 +33,8 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, 
UPPER_OR_LOWER}
 import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
CollationSupport, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
@@ -696,6 +696,189 @@ case class EndsWith(left: Expression, right: Expression) 
extends StringPredicate
     newLeft: Expression, newRight: Expression): EndsWith = copy(left = 
newLeft, right = newRight)
 }
 
+/**
+ * Returns whether the input string consists solely of well-formed UTF-8 byte sequences.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(str) - Returns true if `str` is a valid UTF-8 string, otherwise returns false.",
+  arguments = """
+    Arguments:
+      * str - a string expression
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       true
+      > SELECT _FUNC_(x'61');
+       true
+      > SELECT _FUNC_(x'80');
+       false
+      > SELECT _FUNC_(x'61C262');
+       false
+  """,
+  since = "4.0.0",
+  group = "string_funcs")
+case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
+  with UnaryLike[Expression] with NullIntolerant {
+
+  override def child: Expression = input
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+
+  // Evaluated by invoking `isValid` on the (string-typed) input value.
+  override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType)
+
+  override def nodeName: String = "is_valid_utf8"
+
+  override def nullable: Boolean = true
+
+  override protected def withNewChildInternal(newChild: Expression): IsValidUTF8 =
+    copy(input = newChild)
+
+}
+
+/**
+ * A function that converts an invalid UTF8 string to a valid UTF8 string by replacing invalid
+ * UTF-8 byte sequences with the Unicode replacement character (U+FFFD), according to the UNICODE
+ * standard rules (Section 3.9, Paragraph D86, Table 3-7). Valid strings remain unchanged.
+ */
+// scalastyle:off
+@ExpressionDescription(
+  usage = "_FUNC_(str) - Returns the original string if `str` is a valid UTF-8 string, " +
+    "otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the " +
+    "UNICODE replacement character U+FFFD.",
+  arguments = """
+    Arguments:
+      * str - a string expression
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       Spark
+      > SELECT _FUNC_(x'61');
+       a
+      > SELECT _FUNC_(x'80');
+       �
+      > SELECT _FUNC_(x'61C262');
+       a�b
+  """,
+  since = "4.0.0",
+  group = "string_funcs")
+// scalastyle:on
+case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
+  with UnaryLike[Expression] with NullIntolerant {
+
+  // Keep the input's string type (and so its collation) instead of the session default type.
+  override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+
+  override def nodeName: String = "make_valid_utf8"
+
+  override def nullable: Boolean = true
+
+  override def child: Expression = input
+
+  override protected def withNewChildInternal(newChild: Expression): MakeValidUTF8 = {
+    copy(input = newChild)
+  }
+
+}
+
+/**
+ * Passes a valid UTF-8 string through unchanged; throws INVALID_UTF8_STRING for an invalid one.
+ */
+// scalastyle:off
+@ExpressionDescription(
+  usage = "_FUNC_(str) - Returns the original string if `str` is a valid UTF-8 string, " +
+    "otherwise throws an exception.",
+  arguments = """
+    Arguments:
+      * str - a string expression
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       Spark
+      > SELECT _FUNC_(x'61');
+       a
+  """,
+  since = "4.0.0",
+  group = "string_funcs")
+// scalastyle:on
+case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
+  with UnaryLike[Expression] with NullIntolerant {
+
+  override def child: Expression = input
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+
+  // Delegates to ExpressionImplUtils.validateUTF8String, which raises INVALID_UTF8_STRING.
+  override lazy val replacement: Expression = StaticInvoke(
+    classOf[ExpressionImplUtils],
+    input.dataType,
+    "validateUTF8String",
+    Seq(input),
+    inputTypes)
+
+  override def nodeName: String = "validate_utf8"
+
+  override def nullable: Boolean = true
+
+  override protected def withNewChildInternal(newChild: Expression): ValidateUTF8 =
+    copy(input = newChild)
+
+}
+
+/**
+ * Lenient UTF-8 validation: passes a valid string through unchanged and yields NULL otherwise.
+ */
+// scalastyle:off
+@ExpressionDescription(
+  usage = "_FUNC_(str) - Returns the original string if `str` is a valid UTF-8 string, " +
+    "otherwise returns NULL.",
+  arguments = """
+    Arguments:
+      * str - a string expression
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       Spark
+      > SELECT _FUNC_(x'61');
+       a
+      > SELECT _FUNC_(x'80');
+       NULL
+      > SELECT _FUNC_(x'61C262');
+       NULL
+  """,
+  since = "4.0.0",
+  group = "string_funcs")
+// scalastyle:on
+case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
+  with UnaryLike[Expression] with NullIntolerant {
+
+  override def child: Expression = input
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+
+  // Delegates to ExpressionImplUtils.tryValidateUTF8String, which returns null when invalid.
+  override lazy val replacement: Expression = StaticInvoke(
+    classOf[ExpressionImplUtils],
+    input.dataType,
+    "tryValidateUTF8String",
+    Seq(input),
+    inputTypes)
+
+  override def nodeName: String = "try_validate_utf8"
+
+  override def nullable: Boolean = true
+
+  override protected def withNewChildInternal(newChild: Expression): TryValidateUTF8 =
+    copy(input = newChild)
+
+}
+
 /**
  * Replace all occurrences with string.
  */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 8af931976b2e..6fb09bdeffc5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -211,6 +211,15 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
       summary = getSummary(context))
   }
 
+  def invalidUTF8StringError(str: UTF8String): SparkIllegalArgumentException = {
+    // Render every byte as an escaped upper-case hex pair, e.g. byte 0x80 becomes \x80.
+    val hexBytes = str.getBytes.map(byte => f"\\x$byte%02X").mkString
+    new SparkIllegalArgumentException(
+      errorClass = "INVALID_UTF8_STRING",
+      messageParameters = Map("str" -> hexBytes)
+    )
+  }
+
   def invalidArrayIndexError(
       index: Int,
       numElements: Int,
diff --git 
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
index 4b33f9bc5278..f521cbcf2e0e 100644
--- 
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
+++ 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 
-import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
+import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException, 
SparkRuntimeException}
 import org.apache.spark.unsafe.types.UTF8String
 
 class ExpressionImplUtilsSuite extends SparkFunSuite {
@@ -353,4 +353,65 @@ class ExpressionImplUtilsSuite extends SparkFunSuite {
       parameters = t.errorParamsMap
     )
   }
+
+  test("Validate UTF8 string") {
+    def validateUTF8(str: UTF8String, expected: UTF8String, except: Boolean): Unit = {
+      if (except) {
+        checkError(
+          exception = intercept[SparkIllegalArgumentException] {
+            ExpressionImplUtils.validateUTF8String(str)
+          },
+          errorClass = "INVALID_UTF8_STRING",
+          parameters = Map(
+            "str" -> str.getBytes.map(byte => f"\\x$byte%02X").mkString
+          )
+        )
+      } else {
+        assert(ExpressionImplUtils.validateUTF8String(str) == expected)
+      }
+    }
+    validateUTF8(UTF8String.EMPTY_UTF8, UTF8String.fromString(""), except = false)
+    validateUTF8(UTF8String.fromString(""),
+      UTF8String.fromString(""), except = false)
+    validateUTF8(UTF8String.fromString("aa"),
+      UTF8String.fromString("aa"), except = false)
+    validateUTF8(UTF8String.fromString("\u0061"),
+      UTF8String.fromString("\u0061"), except = false)
+    validateUTF8(UTF8String.fromString("abc"),
+      UTF8String.fromString("abc"), except = false)
+    validateUTF8(UTF8String.fromString("hello"),
+      UTF8String.fromString("hello"), except = false)
+    validateUTF8(UTF8String.fromBytes(Array.empty[Byte]),
+      UTF8String.fromString(""), except = false)
+    validateUTF8(UTF8String.fromBytes(Array[Byte](0x41)),
+      UTF8String.fromString("A"), except = false)
+    validateUTF8(UTF8String.fromBytes(Array[Byte](0x61)),
+      UTF8String.fromString("a"), except = false)
+    validateUTF8(UTF8String.fromBytes(Array[Byte](0x80.toByte)),
+      UTF8String.fromString("\uFFFD"), except = true)
+    validateUTF8(UTF8String.fromBytes(Array[Byte](0xFF.toByte)),
+      UTF8String.fromString("\uFFFD"), except = true)
+    // 0xC2 opens a 2-byte sequence but is followed by 'b' (0x62), not a continuation byte.
+    validateUTF8(UTF8String.fromBytes(Array[Byte](0x61, 0xC2.toByte, 0x62)),
+      UTF8String.fromString("a\uFFFDb"), except = true)
+  }
+
+  test("TryValidate UTF8 string") {
+    def tryValidateUTF8(str: UTF8String, expected: UTF8String): Unit =
+      assert(ExpressionImplUtils.tryValidateUTF8String(str) == expected)
+    tryValidateUTF8(UTF8String.fromString(""), UTF8String.fromString(""))
+    tryValidateUTF8(UTF8String.fromString("aa"), UTF8String.fromString("aa"))
+    tryValidateUTF8(UTF8String.fromString("\u0061"), UTF8String.fromString("\u0061"))
+    tryValidateUTF8(UTF8String.EMPTY_UTF8, UTF8String.fromString(""))
+    tryValidateUTF8(UTF8String.fromString("abc"), UTF8String.fromString("abc"))
+    tryValidateUTF8(UTF8String.fromString("hello"), UTF8String.fromString("hello"))
+    tryValidateUTF8(UTF8String.fromBytes(Array.empty[Byte]), UTF8String.fromString(""))
+    tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x41)), UTF8String.fromString("A"))
+    tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x61)), UTF8String.fromString("a"))
+    tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x80.toByte)), null)
+    tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0xFF.toByte)), null)
+    // 0xC2 opens a 2-byte sequence but is followed by 'b' (0x62), not a continuation byte.
+    tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x61, 0xC2.toByte, 0x62)), null)
+  }
+
 }
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 27e56b24625b..cf218becdf1d 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -174,6 +174,7 @@
 | org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT 
isnan(cast('NaN' as double)) | struct<isnan(CAST(NaN AS DOUBLE)):boolean> |
 | org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT 
isnotnull(1) | struct<(1 IS NOT NULL):boolean> |
 | org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) 
| struct<(1 IS NULL):boolean> |
+| org.apache.spark.sql.catalyst.expressions.IsValidUTF8 | is_valid_utf8 | 
SELECT is_valid_utf8('Spark') | struct<is_valid_utf8(Spark):boolean> |
 | org.apache.spark.sql.catalyst.expressions.JsonObjectKeys | json_object_keys 
| SELECT json_object_keys('{}') | struct<json_object_keys({}):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.JsonToStructs | from_json | SELECT 
from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE') | struct<from_json({"a":1, 
"b":0.8}):struct<a:int,b:double>> |
 | org.apache.spark.sql.catalyst.expressions.JsonTuple | json_tuple | SELECT 
json_tuple('{"a":1, "b":2}', 'a', 'b') | struct<c0:string,c1:string> |
@@ -207,6 +208,7 @@
 | org.apache.spark.sql.catalyst.expressions.MakeTimestamp | make_timestamp | 
SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887) | 
struct<make_timestamp(2014, 12, 28, 6, 30, 45.887):timestamp> |
 | org.apache.spark.sql.catalyst.expressions.MakeTimestampLTZExpressionBuilder 
| make_timestamp_ltz | SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) | 
struct<make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887):timestamp> |
 | org.apache.spark.sql.catalyst.expressions.MakeTimestampNTZExpressionBuilder 
| make_timestamp_ntz | SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) | 
struct<make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887):timestamp_ntz> |
+| org.apache.spark.sql.catalyst.expressions.MakeValidUTF8 | make_valid_utf8 | 
SELECT make_valid_utf8('Spark') | struct<make_valid_utf8(Spark):string> |
 | org.apache.spark.sql.catalyst.expressions.MakeYMInterval | make_ym_interval 
| SELECT make_ym_interval(1, 2) | struct<make_ym_interval(1, 2):interval year 
to month> |
 | org.apache.spark.sql.catalyst.expressions.MapConcat | map_concat | SELECT 
map_concat(map(1, 'a', 2, 'b'), map(3, 'c')) | struct<map_concat(map(1, a, 2, 
b), map(3, c)):map<int,string>> |
 | org.apache.spark.sql.catalyst.expressions.MapContainsKey | map_contains_key 
| SELECT map_contains_key(map(1, 'a', 2, 'b'), 1) | 
struct<map_contains_key(map(1, a, 2, b), 1):boolean> |
@@ -357,6 +359,7 @@
 | org.apache.spark.sql.catalyst.expressions.TryToBinary | try_to_binary | 
SELECT try_to_binary('abc', 'utf-8') | struct<try_to_binary(abc, utf-8):binary> 
|
 | org.apache.spark.sql.catalyst.expressions.TryToNumber | try_to_number | 
SELECT try_to_number('454', '999') | struct<try_to_number(454, 
999):decimal(3,0)> |
 | org.apache.spark.sql.catalyst.expressions.TryToTimestampExpressionBuilder | 
try_to_timestamp | SELECT try_to_timestamp('2016-12-31 00:12:00') | 
struct<try_to_timestamp(2016-12-31 00:12:00):timestamp> |
+| org.apache.spark.sql.catalyst.expressions.TryValidateUTF8 | 
try_validate_utf8 | SELECT try_validate_utf8('Spark') | 
struct<try_validate_utf8(Spark):string> |
 | org.apache.spark.sql.catalyst.expressions.TypeOf | typeof | SELECT typeof(1) 
| struct<typeof(1):string> |
 | org.apache.spark.sql.catalyst.expressions.UnBase64 | unbase64 | SELECT 
unbase64('U3BhcmsgU1FM') | struct<unbase64(U3BhcmsgU1FM):binary> |
 | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT 
negative(1) | struct<negative(1):int> |
@@ -372,6 +375,7 @@
 | org.apache.spark.sql.catalyst.expressions.UrlDecode | url_decode | SELECT 
url_decode('https%3A%2F%2Fspark.apache.org') | 
struct<url_decode(https%3A%2F%2Fspark.apache.org):string> |
 | org.apache.spark.sql.catalyst.expressions.UrlEncode | url_encode | SELECT 
url_encode('https://spark.apache.org') | 
struct<url_encode(https://spark.apache.org):string> |
 | org.apache.spark.sql.catalyst.expressions.Uuid | uuid | SELECT uuid() | 
struct<uuid():string> |
+| org.apache.spark.sql.catalyst.expressions.ValidateUTF8 | validate_utf8 | 
SELECT validate_utf8('Spark') | struct<validate_utf8(Spark):string> |
 | org.apache.spark.sql.catalyst.expressions.WeekDay | weekday | SELECT 
weekday('2009-07-30') | struct<weekday(2009-07-30):int> |
 | org.apache.spark.sql.catalyst.expressions.WeekOfYear | weekofyear | SELECT 
weekofyear('2008-02-20') | struct<weekofyear(2008-02-20):int> |
 | org.apache.spark.sql.catalyst.expressions.WidthBucket | width_bucket | 
SELECT width_bucket(5.3, 0.2, 10.6, 5) | struct<width_bucket(5.3, 0.2, 10.6, 
5):bigint> |
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out
index c7675b16384f..c4f002d84ea6 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out
@@ -1677,3 +1677,87 @@ select luhn_check(123.456)
 -- !query analysis
 Project [luhn_check(cast(123.456 as string)) AS luhn_check(123.456)#x]
 +- OneRowRelation
+
+
+-- !query
+select is_valid_utf8('')
+-- !query analysis
+Project [is_valid_utf8() AS is_valid_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select is_valid_utf8('abc')
+-- !query analysis
+Project [is_valid_utf8(abc) AS is_valid_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select is_valid_utf8(x'80')
+-- !query analysis
+Project [is_valid_utf8(cast(0x80 as string)) AS is_valid_utf8(X'80')#x]
++- OneRowRelation
+
+
+-- !query
+select make_valid_utf8('')
+-- !query analysis
+Project [make_valid_utf8() AS make_valid_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select make_valid_utf8('abc')
+-- !query analysis
+Project [make_valid_utf8(abc) AS make_valid_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select make_valid_utf8(x'80')
+-- !query analysis
+Project [make_valid_utf8(cast(0x80 as string)) AS make_valid_utf8(X'80')#x]
++- OneRowRelation
+
+
+-- !query
+select validate_utf8('')
+-- !query analysis
+Project [validate_utf8() AS validate_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select validate_utf8('abc')
+-- !query analysis
+Project [validate_utf8(abc) AS validate_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select validate_utf8(x'80')
+-- !query analysis
+Project [validate_utf8(cast(0x80 as string)) AS validate_utf8(X'80')#x]
++- OneRowRelation
+
+
+-- !query
+select try_validate_utf8('')
+-- !query analysis
+Project [try_validate_utf8() AS try_validate_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select try_validate_utf8('abc')
+-- !query analysis
+Project [try_validate_utf8(abc) AS try_validate_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select try_validate_utf8(x'80')
+-- !query analysis
+Project [try_validate_utf8(cast(0x80 as string)) AS try_validate_utf8(X'80')#x]
++- OneRowRelation
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out
index c7675b16384f..c4f002d84ea6 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out
@@ -1677,3 +1677,87 @@ select luhn_check(123.456)
 -- !query analysis
 Project [luhn_check(cast(123.456 as string)) AS luhn_check(123.456)#x]
 +- OneRowRelation
+
+
+-- !query
+select is_valid_utf8('')
+-- !query analysis
+Project [is_valid_utf8() AS is_valid_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select is_valid_utf8('abc')
+-- !query analysis
+Project [is_valid_utf8(abc) AS is_valid_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select is_valid_utf8(x'80')
+-- !query analysis
+Project [is_valid_utf8(cast(0x80 as string)) AS is_valid_utf8(X'80')#x]
++- OneRowRelation
+
+
+-- !query
+select make_valid_utf8('')
+-- !query analysis
+Project [make_valid_utf8() AS make_valid_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select make_valid_utf8('abc')
+-- !query analysis
+Project [make_valid_utf8(abc) AS make_valid_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select make_valid_utf8(x'80')
+-- !query analysis
+Project [make_valid_utf8(cast(0x80 as string)) AS make_valid_utf8(X'80')#x]
++- OneRowRelation
+
+
+-- !query
+select validate_utf8('')
+-- !query analysis
+Project [validate_utf8() AS validate_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select validate_utf8('abc')
+-- !query analysis
+Project [validate_utf8(abc) AS validate_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select validate_utf8(x'80')
+-- !query analysis
+Project [validate_utf8(cast(0x80 as string)) AS validate_utf8(X'80')#x]
++- OneRowRelation
+
+
+-- !query
+select try_validate_utf8('')
+-- !query analysis
+Project [try_validate_utf8() AS try_validate_utf8()#x]
++- OneRowRelation
+
+
+-- !query
+select try_validate_utf8('abc')
+-- !query analysis
+Project [try_validate_utf8(abc) AS try_validate_utf8(abc)#x]
++- OneRowRelation
+
+
+-- !query
+select try_validate_utf8(x'80')
+-- !query analysis
+Project [try_validate_utf8(cast(0x80 as string)) AS try_validate_utf8(X'80')#x]
++- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index 0d9c0f3a6a14..256b8e0d49fa 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -288,3 +288,16 @@ select luhn_check(6011111111111117);
 select luhn_check(6011111111111118);
 select luhn_check(123.456);
 
+--utf8 string validation
+select is_valid_utf8('');
+select is_valid_utf8('abc');
+select is_valid_utf8(x'80');
+select make_valid_utf8('');
+select make_valid_utf8('abc');
+select make_valid_utf8(x'80');
+select validate_utf8('');
+select validate_utf8('abc');
+select validate_utf8(x'80');
+select try_validate_utf8('');
+select try_validate_utf8('abc');
+select try_validate_utf8(x'80');
diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index 9f72e215ea54..24d4cfa74b5a 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -2194,3 +2194,106 @@ select luhn_check(123.456)
 struct<luhn_check(123.456):boolean>
 -- !query output
 false
+
+
+-- !query
+select is_valid_utf8('')
+-- !query schema
+struct<is_valid_utf8():boolean>
+-- !query output
+true
+
+
+-- !query
+select is_valid_utf8('abc')
+-- !query schema
+struct<is_valid_utf8(abc):boolean>
+-- !query output
+true
+
+
+-- !query
+select is_valid_utf8(x'80')
+-- !query schema
+struct<is_valid_utf8(X'80'):boolean>
+-- !query output
+false
+
+
+-- !query
+select make_valid_utf8('')
+-- !query schema
+struct<make_valid_utf8():string>
+-- !query output
+
+
+
+-- !query
+select make_valid_utf8('abc')
+-- !query schema
+struct<make_valid_utf8(abc):string>
+-- !query output
+abc
+
+
+-- !query
+select make_valid_utf8(x'80')
+-- !query schema
+struct<make_valid_utf8(X'80'):string>
+-- !query output
+�
+
+
+-- !query
+select validate_utf8('')
+-- !query schema
+struct<validate_utf8():string>
+-- !query output
+
+
+
+-- !query
+select validate_utf8('abc')
+-- !query schema
+struct<validate_utf8(abc):string>
+-- !query output
+abc
+
+
+-- !query
+select validate_utf8(x'80')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkIllegalArgumentException
+{
+  "errorClass" : "INVALID_UTF8_STRING",
+  "sqlState" : "22029",
+  "messageParameters" : {
+    "str" : "\\x80"
+  }
+}
+
+
+-- !query
+select try_validate_utf8('')
+-- !query schema
+struct<try_validate_utf8():string>
+-- !query output
+
+
+
+-- !query
+select try_validate_utf8('abc')
+-- !query schema
+struct<try_validate_utf8(abc):string>
+-- !query output
+abc
+
+
+-- !query
+select try_validate_utf8(x'80')
+-- !query schema
+struct<try_validate_utf8(X'80'):string>
+-- !query output
+NULL
diff --git 
a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index e6778cb539bd..53f516dac03c 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -2126,3 +2126,106 @@ select luhn_check(123.456)
 struct<luhn_check(123.456):boolean>
 -- !query output
 false
+
+
+-- !query
+select is_valid_utf8('')
+-- !query schema
+struct<is_valid_utf8():boolean>
+-- !query output
+true
+
+
+-- !query
+select is_valid_utf8('abc')
+-- !query schema
+struct<is_valid_utf8(abc):boolean>
+-- !query output
+true
+
+
+-- !query
+select is_valid_utf8(x'80')
+-- !query schema
+struct<is_valid_utf8(X'80'):boolean>
+-- !query output
+false
+
+
+-- !query
+select make_valid_utf8('')
+-- !query schema
+struct<make_valid_utf8():string>
+-- !query output
+
+
+
+-- !query
+select make_valid_utf8('abc')
+-- !query schema
+struct<make_valid_utf8(abc):string>
+-- !query output
+abc
+
+
+-- !query
+select make_valid_utf8(x'80')
+-- !query schema
+struct<make_valid_utf8(X'80'):string>
+-- !query output
+�
+
+
+-- !query
+select validate_utf8('')
+-- !query schema
+struct<validate_utf8():string>
+-- !query output
+
+
+
+-- !query
+select validate_utf8('abc')
+-- !query schema
+struct<validate_utf8(abc):string>
+-- !query output
+abc
+
+
+-- !query
+select validate_utf8(x'80')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkIllegalArgumentException
+{
+  "errorClass" : "INVALID_UTF8_STRING",
+  "sqlState" : "22029",
+  "messageParameters" : {
+    "str" : "\\x80"
+  }
+}
+
+
+-- !query
+select try_validate_utf8('')
+-- !query schema
+struct<try_validate_utf8():string>
+-- !query output
+
+
+
+-- !query
+select try_validate_utf8('abc')
+-- !query schema
+struct<try_validate_utf8(abc):string>
+-- !query output
+abc
+
+
+-- !query
+select try_validate_utf8(x'80')
+-- !query schema
+struct<try_validate_utf8(X'80'):string>
+-- !query output
+NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index a5e5b08cd9ff..78aee5b80e54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.SparkConf
+import scala.jdk.CollectionConverters.MapHasAsScala
+
+import org.apache.spark.{SparkConf, SparkIllegalArgumentException}
 import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, 
Literal, StringTrim, StringTrimLeft, StringTrimRight}
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.internal.SQLConf
@@ -667,6 +669,93 @@ class CollationStringExpressionsSuite
     })
   }
 
+  test("Support IsValidUTF8 string expression with collation") {
+    // One case per default collation. is_valid_utf8 returns a boolean (NULL for a
+    // NULL input), and its result type stays BooleanType whatever collation is set.
+    case class IsValidUTF8TestCase(input: String, collationName: String, result: Any)
+    val testCases = Seq(
+      IsValidUTF8TestCase("null", "UTF8_BINARY", result = null),
+      IsValidUTF8TestCase("''", "UTF8_LCASE", result = true),
+      IsValidUTF8TestCase("'abc'", "UNICODE", result = true),
+      IsValidUTF8TestCase("x'FF'", "UNICODE_CI", result = false)
+    )
+    for (tc <- testCases) {
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> tc.collationName) {
+        val query = s"SELECT is_valid_utf8(${tc.input})"
+        // Check both the returned value and the declared result type.
+        checkAnswer(sql(query), Row(tc.result))
+        val resultType = sql(query).schema.fields.head.dataType
+        assert(resultType.sameType(BooleanType))
+      }
+    }
+  }
+
+  test("Support MakeValidUTF8 string expression with collation") {
+    // One case per default collation. Invalid byte sequences are expected to be
+    // replaced by U+FFFD, and the result is a string carrying the active collation.
+    case class MakeValidUTF8TestCase(input: String, collationName: String, result: Any)
+    val testCases = Seq(
+      MakeValidUTF8TestCase("null", "UTF8_BINARY", result = null),
+      MakeValidUTF8TestCase("''", "UTF8_LCASE", result = ""),
+      MakeValidUTF8TestCase("'abc'", "UNICODE", result = "abc"),
+      MakeValidUTF8TestCase("x'FF'", "UNICODE_CI", result = "\uFFFD")
+    )
+    for (tc <- testCases) {
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> tc.collationName) {
+        val query = s"SELECT make_valid_utf8(${tc.input})"
+        // Check both the returned value and the collated result type.
+        checkAnswer(sql(query), Row(tc.result))
+        val expectedType = StringType(tc.collationName)
+        val resultType = sql(query).schema.fields.head.dataType
+        assert(resultType.sameType(expectedType))
+      }
+    }
+  }
+
+  test("Support ValidateUTF8 string expression with collation") {
+    // One case per default collation. `result = None` marks an input that is not
+    // valid UTF-8: validate_utf8 must then fail with INVALID_UTF8_STRING instead of
+    // returning a value. Any other `result` (including null) is the expected output.
+    case class ValidateUTF8TestCase(input: String, collationName: String, result: Any)
+    val testCases = Seq(
+      ValidateUTF8TestCase("null", "UTF8_BINARY", result = null),
+      ValidateUTF8TestCase("''", "UTF8_LCASE", result = ""),
+      ValidateUTF8TestCase("'abc'", "UNICODE", result = "abc"),
+      ValidateUTF8TestCase("x'FF'", "UNICODE_CI", result = None)
+    )
+    for (tc <- testCases) {
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> tc.collationName) {
+        val query = s"SELECT validate_utf8(${tc.input})"
+        tc.result match {
+          case None =>
+            // Invalid input: the error surfaces when the query is executed.
+            val e = intercept[SparkIllegalArgumentException] {
+              sql(query).collect()
+            }
+            assert(e.getErrorClass == "INVALID_UTF8_STRING")
+            assert(e.getMessageParameters.asScala == Map("str" -> "\\xFF"))
+          case expected =>
+            // Valid input: check both the returned value and the collated type.
+            checkAnswer(sql(query), Row(expected))
+            val expectedType = StringType(tc.collationName)
+            assert(sql(query).schema.fields.head.dataType.sameType(expectedType))
+        }
+      }
+    }
+  }
+
+  test("Support TryValidateUTF8 string expression with collation") {
+    // One case per default collation. Unlike validate_utf8, try_validate_utf8 is
+    // expected to return NULL for invalid UTF-8 rather than throwing; the result is
+    // a string carrying the active collation.
+    case class TryValidateUTF8TestCase(input: String, collationName: String, result: Any)
+    val testCases = Seq(
+      TryValidateUTF8TestCase("null", "UTF8_BINARY", result = null),
+      TryValidateUTF8TestCase("''", "UTF8_LCASE", result = ""),
+      TryValidateUTF8TestCase("'abc'", "UNICODE", result = "abc"),
+      TryValidateUTF8TestCase("x'FF'", "UNICODE_CI", result = null)
+    )
+    testCases.foreach { testCase =>
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> testCase.collationName) {
+        val query = s"SELECT try_validate_utf8(${testCase.input})"
+        // Result & data type
+        checkAnswer(sql(query), Row(testCase.result))
+        val dataType = StringType(testCase.collationName)
+        assert(sql(query).schema.fields.head.dataType.sameType(dataType))
+      }
+    }
+  }
+
   test("Support Left/Right/Substr with collation") {
     case class SubstringTestCase(
         method: String,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to