This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 701d6ee10eb [SPARK-40112][SQL] Improve the TO_BINARY() function
701d6ee10eb is described below

commit 701d6ee10eb03b384b54bf75dfd8aeb3a155569a
Author: Vitalii Li <vitalii...@databricks.com>
AuthorDate: Thu Sep 1 10:18:44 2022 +0800

    [SPARK-40112][SQL] Improve the TO_BINARY() function
    
    ### What changes were proposed in this pull request?
    
    Improvements for `TO_BINARY`:
    - `base64` behaves more strictly, i.e. it does not allow symbols outside the base64 alphabet (A-Za-z0-9+/) and verifies correct padding and symbol groups (see RFC 4648 § 4). Whitespace is ignored. The previous implementation allowed arbitrary strings, and invalid symbols were skipped.
    - `hex` converts only valid hexadecimal strings and throws an error otherwise. Whitespace is not allowed.
    - `utf-8` and `utf8` are interchangeable.
    - Correct errors are thrown and classified for invalid input (`CONVERSION_INVALID_INPUT`) and invalid format (`INVALID_PARAMETER_VALUE`); see the examples below.
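
    For example (illustrative queries; the invalid-input and alias cases below mirror the new SQL tests in this patch):

    ```sql
    SELECT to_binary('a?', 'base64');     -- now fails with CONVERSION_INVALID_INPUT
    SELECT to_binary('GG', 'hex');        -- now fails with CONVERSION_INVALID_INPUT
    SELECT to_binary('abc', 'utf8');      -- 'utf8' is now accepted as an alias of 'utf-8'
    SELECT try_to_binary('a?', 'base64'); -- returns NULL instead of throwing an error
    ```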
    
    ### Why are the changes needed?
    
    Better handling of malformed input, and improved parity with the implementations in other engines.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this changes existing function behavior.
    
    ### How was this patch tested?
    
    Unit tests and `SQLQueryTestSuite`.
    
    Closes #37483 from vitaliili-db/SC-89850.
    
    Authored-by: Vitalii Li <vitalii...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 core/src/main/resources/error/error-classes.json   |   5 +
 .../scala/org/apache/spark/SparkException.scala    |  10 +-
 docs/sql-migration-guide.md                        |   4 +
 .../org/apache/spark/sql/AnalysisException.scala   |  19 +-
 .../sql/catalyst/expressions/mathExpressions.scala |  44 ++-
 .../catalyst/expressions/stringExpressions.scala   | 129 +++++++--
 .../spark/sql/errors/QueryCompilationErrors.scala  |  15 +-
 .../spark/sql/errors/QueryExecutionErrors.scala    |  14 +
 .../expressions/MathExpressionsSuite.scala         |   2 +-
 .../sql-tests/inputs/string-functions.sql          |  43 ++-
 .../sql-tests/inputs/try-string-functions.sql      |  45 ++-
 .../results/ansi/string-functions.sql.out          | 312 +++++++++++++++++++--
 .../sql-tests/results/string-functions.sql.out     | 312 +++++++++++++++++++--
 .../sql-tests/results/try-string-functions.sql.out | Bin 1898 -> 5233 bytes
 .../sql/errors/QueryExecutionErrorsSuite.scala     |  15 +-
 15 files changed, 851 insertions(+), 118 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index df0f887a63c..6a9652b4c67 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -70,6 +70,11 @@
       "Another instance of this query was just started by a concurrent 
session."
     ]
   },
+  "CONVERSION_INVALID_INPUT" : {
+    "message" : [
+      "The value <str> (<fmt>) cannot be converted to <targetType> because it 
is malformed. Correct the value as per the syntax, or change its format. Use 
<suggestion> to tolerate malformed input and return NULL instead."
+    ]
+  },
   "DATETIME_OVERFLOW" : {
     "message" : [
       "Datetime operation overflow: <operation>."
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index 55471d7c002..67aa8cdfcac 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -303,14 +303,18 @@ private[spark] class SparkNoSuchMethodException(
 private[spark] class SparkIllegalArgumentException(
     errorClass: String,
     errorSubClass: Option[String] = None,
-    messageParameters: Array[String])
+    messageParameters: Array[String],
+    context: Array[QueryContext] = Array.empty,
+    summary: String = "")
   extends IllegalArgumentException(
-    SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters))
+    SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
     with SparkThrowable {
 
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
-  override def getErrorSubClass: String = errorSubClass.orNull}
+  override def getErrorSubClass: String = errorSubClass.orNull
+  override def getQueryContext: Array[QueryContext] = context
+}
 
 /**
  * Index out of bounds exception thrown from Spark with an error class.
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index d69f245d8e8..164e330148f 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -29,6 +29,10 @@ license: |
  - Since Spark 3.4, when ANSI SQL mode(configuration `spark.sql.ansi.enabled`) is on, Spark SQL always returns NULL result on getting a map value with a non-existing key. In Spark 3.3 or earlier, there will be an error.
  - Since Spark 3.4, the SQL CLI `spark-sql` does not print the prefix `Error in query:` before the error message of `AnalysisException`.
  - Since Spark 3.4, `split` function ignores trailing empty strings when `regex` parameter is empty.
+  - Since Spark 3.4, the `to_binary` function throws an error for a malformed `str` input. Use `try_to_binary` to tolerate malformed input and return NULL instead.
+    - A valid Base64 string should include symbols from the base64 alphabet (A-Za-z0-9+/), optional padding (`=`), and optional whitespace. Whitespace is skipped during conversion unless it is preceded by padding symbol(s). If padding is present, it should conclude the string and follow the rules described in RFC 4648 § 4.
+    - Valid hexadecimal strings should include only allowed symbols (0-9A-Fa-f).
+    - Valid values for `fmt` are case-insensitive: `hex`, `base64`, `utf-8`, `utf8`.
 
 ## Upgrading from Spark SQL 3.2 to 3.3
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 48e1f91990b..6c81cf8566c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.{SparkThrowable, SparkThrowableHelper}
+import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.Origin
@@ -37,7 +37,8 @@ class AnalysisException protected[sql] (
     val cause: Option[Throwable] = None,
     val errorClass: Option[String] = None,
     val errorSubClass: Option[String] = None,
-    val messageParameters: Array[String] = Array.empty)
+    val messageParameters: Array[String] = Array.empty,
+    val context: Array[QueryContext] = Array.empty)
  extends Exception(message, cause.orNull) with SparkThrowable with Serializable {
 
     // Needed for binary compatibility
@@ -65,6 +66,19 @@ class AnalysisException protected[sql] (
       messageParameters = messageParameters,
       cause = cause)
 
+  def this(
+      errorClass: String,
+      messageParameters: Array[String],
+      context: Array[QueryContext],
+      summary: String) =
+    this(
+      SparkThrowableHelper.getMessage(errorClass, null, messageParameters, summary),
+      errorClass = Some(errorClass),
+      errorSubClass = None,
+      messageParameters = messageParameters,
+      cause = null,
+      context = context)
+
   def this(errorClass: String, messageParameters: Array[String]) =
    this(errorClass = errorClass, messageParameters = messageParameters, cause = None)
 
@@ -138,4 +152,5 @@ class AnalysisException protected[sql] (
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass.orNull
   override def getErrorSubClass: String = errorSubClass.orNull
+  override def getQueryContext: Array[QueryContext] = context
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index dfbc041b259..5643598b4bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -1120,28 +1120,58 @@ case class Hex(child: Expression)
   """,
   since = "1.5.0",
   group = "math_funcs")
-case class Unhex(child: Expression)
+case class Unhex(child: Expression, failOnError: Boolean = false)
   extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
+  def this(expr: Expression) = this(expr, false)
+
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
 
   override def nullable: Boolean = true
   override def dataType: DataType = BinaryType
 
-  protected override def nullSafeEval(num: Any): Any =
-    Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
+  protected override def nullSafeEval(num: Any): Any = {
+    val result = Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
+    if (failOnError && result == null) {
+      // `failOnError` is set only by the `ToBinary` function, hence we can safely set the
+      // `hint` parameter to `try_to_binary`.
+      throw QueryExecutionErrors.invalidInputInConversionError(
+        BinaryType,
+        num.asInstanceOf[UTF8String],
+        UTF8String.fromString("HEX"),
+        "try_to_binary")
+    }
+    result
+  }
 
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (c) => {
+    nullSafeCodeGen(ctx, ev, c => {
       val hex = Hex.getClass.getName.stripSuffix("$")
+      val maybeFailOnErrorCode = if (failOnError) {
+        val format = UTF8String.fromString("BASE64");
+        val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
+        s"""
+           |if (${ev.value} == null) {
+           |  throw QueryExecutionErrors.invalidInputInConversionError(
+           |    $binaryType,
+           |    $c,
+           |    $format,
+           |    "try_to_binary");
+           |}
+           |""".stripMargin
+      } else {
+        s"${ev.isNull} = ${ev.value} == null;"
+      }
+
       s"""
         ${ev.value} = $hex.unhex($c.getBytes());
-        ${ev.isNull} = ${ev.value} == null;
+        $maybeFailOnErrorCode
        """
     })
   }
 
-  override protected def withNewChildInternal(newChild: Expression): Unhex = copy(child = newChild)
+  override protected def withNewChildInternal(newChild: Expression): Unhex =
+    copy(child = newChild, failOnError)
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index dffe0d56f33..1bc79f23846 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2300,24 +2300,105 @@ case class Base64(child: Expression)
   """,
   since = "1.5.0",
   group = "string_funcs")
-case class UnBase64(child: Expression)
+case class UnBase64(child: Expression, failOnError: Boolean = false)
   extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = BinaryType
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
-  protected override def nullSafeEval(string: Any): Any =
+  def this(expr: Expression) = this(expr, false)
+
+  protected override def nullSafeEval(string: Any): Any = {
+    if (failOnError && 
!UnBase64.isValidBase64(string.asInstanceOf[UTF8String])) {
+      // The failOnError is set only from `ToBinary` function - hence we might 
safely set `hint`
+      // parameter to `try_to_binary`.
+      throw QueryExecutionErrors.invalidInputInConversionError(
+        BinaryType,
+        string.asInstanceOf[UTF8String],
+        UTF8String.fromString("BASE64"),
+        "try_to_binary")
+    }
     JBase64.getMimeDecoder.decode(string.asInstanceOf[UTF8String].toString)
+  }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (child) => {
+    nullSafeCodeGen(ctx, ev, child => {
+      val maybeValidateInputCode = if (failOnError) {
+        val unbase64 = UnBase64.getClass.getName.stripSuffix("$")
+        val format = UTF8String.fromString("BASE64");
+        val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
+        s"""
+           |if (!$unbase64.isValidBase64($child)) {
+           |  throw QueryExecutionErrors.invalidInputInConversionError(
+           |    $binaryType,
+           |    $child,
+           |    $format,
+           |    "try_to_binary");
+           |}
+       """.stripMargin
+      } else {
+        ""
+      }
       s"""
+         $maybeValidateInputCode
          ${ev.value} = ${classOf[JBase64].getName}.getMimeDecoder().decode($child.toString());
        """})
   }
 
   override protected def withNewChildInternal(newChild: Expression): UnBase64 =
-    copy(child = newChild)
+    copy(child = newChild, failOnError)
+}
+
+object UnBase64 {
+  def isValidBase64(srcString: UTF8String): Boolean = {
+    // We follow RFC 4648. A valid base64 string contains zero or more groups of 4 symbols, plus
+    // a last group consisting of 2-4 valid symbols and optional padding.
+    // The last group should contain at least 2 valid symbols and up to 2 padding characters `=`.
+    // Valid symbols are A-Za-z0-9+/. Each group may contain an arbitrary number of whitespace
+    // characters, which are ignored.
+    // If padding is present, the last group should include exactly 4 symbols.
+    // Examples:
+    //    "abcd"      - Valid, single group of 4 valid symbols
+    //    "abc d"     - Valid, single group of 4 valid symbols, whitespace is 
skipped
+    //    "abc?"      - Invalid, group contains invalid symbol `?`
+    //    "abcdA"     - Invalid, last group should contain at least 2 valid 
symbols
+    //    "abcdAE"    - Valid, a group of 4 valid symbols and a group of 2 
valid symbols
+    //    "abcdAE=="  - Valid, last group includes 2 padding symbols and total 
number of symbols
+    //                  in a group is 4.
+    //    "abcdAE="   - Invalid, last group include padding symbols, therefore 
it should have
+    //                  exactly 4 symbols but contains only 3.
+    //    "ab==tm+1"  - Invalid, nothing should be after padding.
+    var position = 0
+    var padSize = 0
+    for (c: Char <- srcString.toString) {
+      c match {
+        case a
+          if (a >= '0' && a <= '9')
+            || (a >= 'A' && a <= 'Z')
+            || (a >= 'a' && a <= 'z')
+            || a == '/' || a == '+' =>
+          if (padSize != 0) return false // Padding symbols should conclude the string.
+          position += 1
+        case '=' =>
+          padSize += 1
+          // The group preceding the padding should have 2 or more symbols, and the padding
+          // size should be 2 or less.
+          if (padSize > 2 || position % 4 < 2) {
+            return false
+          }
+        case ws if Character.isWhitespace(ws) =>
+          if (padSize != 0) { // Padding symbols should conclude the string.
+            return false
+          }
+        case _ => return false
+      }
+    }
+    if (padSize > 0) { // When padding is present last group should have 
exactly 4 symbols.
+      (position + padSize) % 4 == 0
+    } else { // When padding is absent last group should include 2 or more 
symbols.
+      position % 4 != 1
+    }
+  }
 }
 
 object Decode {
@@ -2473,11 +2554,10 @@ case class Encode(value: Expression, charset: Expression)
 /**
 * Converts the input expression to a binary value based on the supplied format.
  */
-// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
     _FUNC_(str[, fmt]) - Converts the input `str` to a binary value based on the supplied `fmt`.
-      `fmt` can be a case-insensitive string literal of "hex", "utf-8", or "base64".
+      `fmt` can be a case-insensitive string literal of "hex", "utf-8", "utf8", or "base64".
       By default, the binary format for conversion is "hex" if `fmt` is omitted.
       The function returns NULL if at least one of the input parameters is NULL.
   """,
@@ -2488,12 +2568,11 @@ case class Encode(value: Expression, charset: Expression)
   """,
   since = "3.3.0",
   group = "string_funcs")
-// scalastyle:on line.size.limit
 case class ToBinary(
     expr: Expression,
     format: Option[Expression],
     nullOnInvalidFormat: Boolean = false) extends RuntimeReplaceable
-  with ImplicitCastInputTypes {
+    with ImplicitCastInputTypes {
 
   override lazy val replacement: Expression = format.map { f =>
     assert(f.foldable && (f.dataType == StringType || f.dataType == NullType))
@@ -2502,30 +2581,32 @@ case class ToBinary(
       Literal(null, BinaryType)
     } else {
       value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) match {
-        case "hex" => Unhex(expr)
-        case "utf-8" => Encode(expr, Literal("UTF-8"))
-        case "base64" => UnBase64(expr)
+        case "hex" => Unhex(expr, failOnError = true)
+        case "utf-8" | "utf8" => Encode(expr, Literal("UTF-8"))
+        case "base64" => UnBase64(expr, failOnError = true)
         case _ if nullOnInvalidFormat => Literal(null, BinaryType)
         case other => throw 
QueryCompilationErrors.invalidStringLiteralParameter(
-          "to_binary", "format", other,
-          Some("The value has to be a case-insensitive string literal of " +
-            "'hex', 'utf-8', or 'base64'."))
+              "to_binary",
+              "format",
+              other,
+              Some(
+                "The value has to be a case-insensitive string literal of " +
+                "'hex', 'utf-8', 'utf8', or 'base64'."))
       }
     }
-  }.getOrElse(Unhex(expr))
+  }.getOrElse(Unhex(expr, failOnError = true))
 
   def this(expr: Expression) = this(expr, None, false)
 
-  def this(expr: Expression, format: Expression) = this(expr, Some({
+  def this(expr: Expression, format: Expression) =
+    this(expr, Some({
+      // We perform this check in the constructor to make it eager and not go through type coercion.
+      if (format.foldable && (format.dataType == StringType || format.dataType == NullType)) {
         format
       } else {
+        throw QueryCompilationErrors.requireLiteralParameter("to_binary", "format", "string")
       }
-    }),
-    false
-    )
+    }), false)
 
   override def prettyName: String = "to_binary"
 
@@ -2535,11 +2616,11 @@ case class ToBinary(
 
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
-    if (format.isDefined) {
-      copy(expr = newChildren.head, format = Some(newChildren.last))
-    } else {
-      copy(expr = newChildren.head)
-    }
+      if (format.isDefined) {
+        copy(expr = newChildren.head, format = Some(newChildren.last))
+      } else {
+        copy(expr = newChildren.head)
+      }
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index ef4321a4fc7..d142be68b52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
-import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
+import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode}
 import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -795,6 +795,19 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
      s"The '$argName' parameter of function '$funcName' needs to be a $requiredType literal.")
   }
 
+  def invalidFormatInConversion(
+      argName: String,
+      funcName: String,
+      expected: String,
+      context: SQLQueryContext): Throwable = {
+    new AnalysisException(
+      errorClass = "INVALID_PARAMETER_VALUE",
+      messageParameters =
+        Array(toSQLId(argName), toSQLId(funcName), expected),
+      context = getQueryContext(context),
+      summary = getSummary(context))
+  }
+
   def invalidStringLiteralParameter(
       funcName: String,
       argName: String,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 8cb31f45c25..3dcefcc5368 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -168,6 +168,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       summary = getSummary(context))
   }
 
+  def invalidInputInConversionError(
+      to: DataType,
+      s: UTF8String,
+      fmt: UTF8String,
+      hint: String): SparkIllegalArgumentException = {
+    new SparkIllegalArgumentException(
+      errorClass = "CONVERSION_INVALID_INPUT",
+      messageParameters = Array(
+        toSQLValue(s, StringType),
+        toSQLValue(fmt, StringType),
+        toSQLType(to),
+        toSQLId(hint)))
+  }
+
   def cannotCastFromNullTypeError(to: DataType): Throwable = {
     new SparkException(errorClass = "CANNOT_CAST_DATATYPE",
       messageParameters = Array(NullType.typeName, to.typeName), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index c8e99112a15..c741b685a34 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -595,7 +595,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8))
     checkEvaluation(Unhex(Literal("三重的")), null)
     // scalastyle:on
-    checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType)
+    checkConsistencyBetweenInterpretedAndCodegen((e: Expression) => Unhex(e), StringType)
   }
 
   test("hypot") {
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index efbef2ab449..8af82efeab3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -174,12 +174,41 @@ select to_number('00,454.8-', '00,000.9MI');
 select to_number('<00,454.8>', '00,000.9PR');
 
 -- to_binary
-select to_binary('abc');
-select to_binary('abc', 'utf-8');
-select to_binary('abc', 'base64');
-select to_binary('abc', 'hex');
+-- base64 valid
+select to_binary('', 'base64');
+select to_binary('  ', 'base64');
+select to_binary(' ab cd ', 'base64');
+select to_binary(' ab c=', 'base64');
+select to_binary(' ab cdef= = ', 'base64');
+select to_binary(
+  concat(' b25lIHR3byB0aHJlZSBmb3VyIGZpdmUgc2l4IHNldmVuIGVpZ2h0IG5pbmUgdGVuIGVsZXZlbiB0',
+         'd2VsdmUgdGhpcnRlZW4gZm91cnRlZW4gZml2dGVlbiBzaXh0ZWVuIHNldmVudGVlbiBlaWdodGVl'), 'base64');
+-- base64 invalid
+select to_binary('a', 'base64');
+select to_binary('a?', 'base64');
+select to_binary('abcde', 'base64');
+select to_binary('abcd=', 'base64');
+select to_binary('a===', 'base64');
+select to_binary('ab==f', 'base64');
+-- utf-8
+select to_binary(
+  '∮ E⋅da = Q,  n → ∞, ∑ f(i) = ∏ g(i), ∀x∈ℝ: ⌈x⌉ = −⌊−x⌋, α ∧ ¬β = ¬(¬α ∨ β)', 'utf-8');
+select to_binary('大千世界', 'utf8');
+select to_binary('', 'utf-8');
+select to_binary('  ', 'utf8');
+-- hex valid
+select to_binary('737472696E67');
+select to_binary('737472696E67', 'hex');
+select to_binary('');
+select to_binary('1', 'hex');
+select to_binary('FF');
+-- hex invalid
+select to_binary('GG');
+select to_binary('01 AF', 'hex');
 -- 'format' parameter can be any foldable string value, not just literal.
 select to_binary('abc', concat('utf', '-8'));
+select to_binary(' ab cdef= = ', substr('base64whynot', 0, 6));
+select to_binary(' ab cdef= = ', replace('HEX0', '0'));
 -- 'format' parameter is case insensitive.
 select to_binary('abc', 'Hex');
 -- null inputs lead to null result.
@@ -187,10 +216,6 @@ select to_binary('abc', null);
 select to_binary(null, 'utf-8');
 select to_binary(null, null);
 select to_binary(null, cast(null as string));
--- 'format' parameter must be string type or void type.
-select to_binary(null, cast(null as int));
-select to_binary('abc', 1);
 -- invalid format
+select to_binary('abc', 1);
 select to_binary('abc', 'invalidFormat');
--- invalid string input
-select to_binary('a!', 'base64');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql
index 20f02374e78..d21a80d482a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql
@@ -1,10 +1,39 @@
 -- try_to_binary
-select try_to_binary('abc');
-select try_to_binary('abc', 'utf-8');
-select try_to_binary('abc', 'base64');
-select try_to_binary('abc', 'hex');
+-- base64 valid
+select try_to_binary('', 'base64');
+select try_to_binary('  ', 'base64');
+select try_to_binary(' ab cd ', 'base64');
+select try_to_binary(' ab c=', 'base64');
+select try_to_binary(' ab cdef= = ', 'base64');
+select try_to_binary(
+  concat(' b25lIHR3byB0aHJlZSBmb3VyIGZpdmUgc2l4IHNldmVuIGVpZ2h0IG5pbmUgdGVuIGVsZXZlbiB0',
+         'd2VsdmUgdGhpcnRlZW4gZm91cnRlZW4gZml2dGVlbiBzaXh0ZWVuIHNldmVudGVlbiBlaWdodGVl'), 'base64');
+-- base64 invalid
+select try_to_binary('a', 'base64');
+select try_to_binary('a?', 'base64');
+select try_to_binary('abcde', 'base64');
+select try_to_binary('abcd=', 'base64');
+select try_to_binary('a===', 'base64');
+select try_to_binary('ab==f', 'base64');
+-- utf-8
+select try_to_binary(
+  '∮ E⋅da = Q,  n → ∞, ∑ f(i) = ∏ g(i), ∀x∈ℝ: ⌈x⌉ = −⌊−x⌋, α ∧ ¬β = ¬(¬α ∨ β)', 'utf-8');
+select try_to_binary('大千世界', 'utf8');
+select try_to_binary('', 'utf-8');
+select try_to_binary('  ', 'utf8');
+-- hex valid
+select try_to_binary('737472696E67');
+select try_to_binary('737472696E67', 'hex');
+select try_to_binary('');
+select try_to_binary('1', 'hex');
+select try_to_binary('FF');
+-- hex invalid
+select try_to_binary('GG');
+select try_to_binary('01 AF', 'hex');
 -- 'format' parameter can be any foldable string value, not just literal.
 select try_to_binary('abc', concat('utf', '-8'));
+select try_to_binary(' ab cdef= = ', substr('base64whynot', 0, 6));
+select try_to_binary(' ab cdef= = ', replace('HEX0', '0'));
 -- 'format' parameter is case insensitive.
 select try_to_binary('abc', 'Hex');
 -- null inputs lead to null result.
@@ -12,10 +41,6 @@ select try_to_binary('abc', null);
 select try_to_binary(null, 'utf-8');
 select try_to_binary(null, null);
 select try_to_binary(null, cast(null as string));
--- 'format' parameter must be string type or void type.
-select try_to_binary(null, cast(null as int));
-select try_to_binary('abc', 1);
 -- invalid format
-select try_to_binary('abc', 'invalidFormat');
--- invalid string input
-select try_to_binary('a!', 'base64');
+select try_to_binary('abc', 1);
+select try_to_binary('abc', 'invalidFormat');
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index d08084a39c3..810f1942be2 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1147,35 +1147,271 @@ struct<to_number(<00,454.8>, 00,000.9PR):decimal(6,1)>
 
 
 -- !query
-select to_binary('abc')
+select to_binary('', 'base64')
 -- !query schema
-struct<to_binary(abc):binary>
+struct<to_binary(, base64):binary>
 -- !query output
[...]
-org.apache.spark.sql.AnalysisException
-The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7
-
-
 -- !query
 select to_binary('abc', 1)
 -- !query schema
@@ -1250,13 +1511,4 @@ select to_binary('abc', 'invalidFormat')
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', or 'base64'.
-
-
--- !query
-select to_binary('a!', 'base64')
--- !query schema
-struct<>
--- !query output
-java.lang.IllegalArgumentException
-Last unit does not have enough valid bits
+Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'.
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index d96000c2dff..a8ad802dd98 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1079,35 +1079,271 @@ struct<to_number(<00,454.8>, 00,000.9PR):decimal(6,1)>
 
 
 -- !query
-select to_binary('abc')
+select to_binary('', 'base64')
 -- !query schema
-struct<to_binary(abc):binary>
+struct<to_binary(, base64):binary>
 -- !query output
[...]
 struct<to_binary(NULL, CAST(NULL AS STRING)):binary>
 -- !query output
 NULL
 
 
--- !query
-select to_binary(null, cast(null as int))
--- !query schema
-struct<>
--- !query output
-org.apache.spark.sql.AnalysisException
-The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7
-
-
 -- !query
 select to_binary('abc', 1)
 -- !query schema
@@ -1182,13 +1443,4 @@ select to_binary('abc', 'invalidFormat')
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', or 'base64'.
-
-
--- !query
-select to_binary('a!', 'base64')
--- !query schema
-struct<>
--- !query output
-java.lang.IllegalArgumentException
-Last unit does not have enough valid bits
+Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'.
diff --git a/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out
index b3d3197ee7d..dacbc08a103 100644
Binary files a/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out and b/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index abb64f0f4a7..1b5fa2aa890 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.permission.FsPermission
 import org.mockito.Mockito.{mock, when}
 import test.org.apache.spark.sql.connector.JavaSimpleWritableDataSource
 
-import org.apache.spark.{SparkArithmeticException, SparkClassNotFoundException, SparkException, SparkFileNotFoundException, SparkIllegalArgumentException, SparkRuntimeException, SparkSecurityException, SparkSQLException, SparkUnsupportedOperationException, SparkUpgradeException}
+import org.apache.spark._
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.util.BadRecordException
 import org.apache.spark.sql.connector.SimpleWritableDataSource
@@ -53,6 +53,19 @@ class QueryExecutionErrorsSuite
 
   import testImplicits._
 
+  test("CONVERSION_INVALID_INPUT: to_binary conversion function") {
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        sql("select to_binary('???', 'base64')").collect()
+      },
+      errorClass = "CONVERSION_INVALID_INPUT",
+      parameters = Map(
+        "str" -> "'???'",
+        "fmt" -> "'BASE64'",
+        "targetType" -> "\"BINARY\"",
+        "suggestion" -> "`try_to_binary`"))
+  }
+
   private def getAesInputs(): (DataFrame, DataFrame) = {
     val encryptedText16 = "4Hv0UKCx6nfUeAoPZo1z+w=="
     val encryptedText24 = "NeTYNgA+PCQBN50DA//O2w=="


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
