This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 701d6ee10eb [SPARK-40112][SQL] Improve the TO_BINARY() function 701d6ee10eb is described below commit 701d6ee10eb03b384b54bf75dfd8aeb3a155569a Author: Vitalii Li <vitalii...@databricks.com> AuthorDate: Thu Sep 1 10:18:44 2022 +0800 [SPARK-40112][SQL] Improve the TO_BINARY() function ### What changes were proposed in this pull request? Improvements for `TO_BINARY`: - `base64` behaves more strictly, i.e. it does not allow symbols outside the base64 alphabet (A-Za-z0-9+/) and verifies correct padding and symbol groups (see RFC 4648 § 4). Whitespace is ignored, except that it must not follow padding. The previous implementation accepted arbitrary strings and silently skipped invalid symbols. - `hex` converts only valid hexadecimal strings and throws errors otherwise. Whitespace is not allowed. - `utf-8` and `utf8` are interchangeable. - Correct errors are thrown and classified for invalid input (CONVERSION_INVALID_INPUT) and invalid format (CONVERSION_INVALID_FORMAT). ### Why are the changes needed? Better handling of malformed input and improved parity with the implementations in other engines. ### Does this PR introduce _any_ user-facing change? Yes, this changes existing function behavior. ### How was this patch tested? Unit tests and `SQLQueryTestSuite`. Closes #37483 from vitaliili-db/SC-89850. 
Authored-by: Vitalii Li <vitalii...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- core/src/main/resources/error/error-classes.json | 5 + .../scala/org/apache/spark/SparkException.scala | 10 +- docs/sql-migration-guide.md | 4 + .../org/apache/spark/sql/AnalysisException.scala | 19 +- .../sql/catalyst/expressions/mathExpressions.scala | 44 ++- .../catalyst/expressions/stringExpressions.scala | 129 +++++++-- .../spark/sql/errors/QueryCompilationErrors.scala | 15 +- .../spark/sql/errors/QueryExecutionErrors.scala | 14 + .../expressions/MathExpressionsSuite.scala | 2 +- .../sql-tests/inputs/string-functions.sql | 43 ++- .../sql-tests/inputs/try-string-functions.sql | 45 ++- .../results/ansi/string-functions.sql.out | 312 +++++++++++++++++++-- .../sql-tests/results/string-functions.sql.out | 312 +++++++++++++++++++-- .../sql-tests/results/try-string-functions.sql.out | Bin 1898 -> 5233 bytes .../sql/errors/QueryExecutionErrorsSuite.scala | 15 +- 15 files changed, 851 insertions(+), 118 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index df0f887a63c..6a9652b4c67 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -70,6 +70,11 @@ "Another instance of this query was just started by a concurrent session." ] }, + "CONVERSION_INVALID_INPUT" : { + "message" : [ + "The value <str> (<fmt>) cannot be converted to <targetType> because it is malformed. Correct the value as per the syntax, or change its format. Use <suggestion> to tolerate malformed input and return NULL instead." + ] + }, "DATETIME_OVERFLOW" : { "message" : [ "Datetime operation overflow: <operation>." 
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 55471d7c002..67aa8cdfcac 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -303,14 +303,18 @@ private[spark] class SparkNoSuchMethodException( private[spark] class SparkIllegalArgumentException( errorClass: String, errorSubClass: Option[String] = None, - messageParameters: Array[String]) + messageParameters: Array[String], + context: Array[QueryContext] = Array.empty, + summary: String = "") extends IllegalArgumentException( - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters)) + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass - override def getErrorSubClass: String = errorSubClass.orNull} + override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context +} /** * Index out of bounds exception thrown from Spark with an error class. diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index d69f245d8e8..164e330148f 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -29,6 +29,10 @@ license: | - Since Spark 3.4, when ANSI SQL mode(configuration `spark.sql.ansi.enabled`) is on, Spark SQL always returns NULL result on getting a map value with a non-existing key. In Spark 3.3 or earlier, there will be an error. - Since Spark 3.4, the SQL CLI `spark-sql` does not print the prefix `Error in query:` before the error message of `AnalysisException`. - Since Spark 3.4, `split` function ignores trailing empty strings when `regex` parameter is empty. + - Since Spark 3.4, the `to_binary` function throws an error for a malformed `str` input.
Use `try_to_binary` to tolerate malformed input and return NULL instead. + - A valid Base64 string may include only symbols from the base64 alphabet (A-Za-z0-9+/), optional padding (`=`), and optional whitespace. Whitespace is skipped during conversion, but it must not follow padding symbol(s). If padding is present, it must conclude the string and follow the rules described in RFC 4648 § 4. + - Valid hexadecimal strings should include only allowed symbols (0-9A-Fa-f). + - Valid values for `fmt` are the case-insensitive strings `hex`, `base64`, `utf-8`, and `utf8`. ## Upgrading from Spark SQL 3.2 to 3.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 48e1f91990b..6c81cf8566c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.{SparkThrowable, SparkThrowableHelper} +import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin @@ -37,7 +37,8 @@ class AnalysisException protected[sql] ( val cause: Option[Throwable] = None, val errorClass: Option[String] = None, val errorSubClass: Option[String] = None, - val messageParameters: Array[String] = Array.empty) + val messageParameters: Array[String] = Array.empty, + val context: Array[QueryContext] = Array.empty) extends Exception(message, cause.orNull) with SparkThrowable with Serializable { // Needed for binary compatibility @@ -65,6 +66,19 @@ class AnalysisException protected[sql] ( messageParameters = messageParameters, cause = cause) + def this( + errorClass: String, + messageParameters: Array[String], + context: Array[QueryContext], + summary: String) = + this( + 
SparkThrowableHelper.getMessage(errorClass, null, messageParameters, summary), + errorClass = Some(errorClass), + errorSubClass = None, + messageParameters = messageParameters, + cause = null, + context = context) + def this(errorClass: String, messageParameters: Array[String]) = this(errorClass = errorClass, messageParameters = messageParameters, cause = None) @@ -138,4 +152,5 @@ class AnalysisException protected[sql] ( override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass.orNull override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dfbc041b259..5643598b4bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils} -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1120,28 +1120,58 @@ case class Hex(child: Expression) """, since = "1.5.0", group = "math_funcs") -case class Unhex(child: Expression) +case class Unhex(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + def this(expr: Expression) = this(expr, false) + override def inputTypes: 
Seq[AbstractDataType] = Seq(StringType) override def nullable: Boolean = true override def dataType: DataType = BinaryType - protected override def nullSafeEval(num: Any): Any = - Hex.unhex(num.asInstanceOf[UTF8String].getBytes) + protected override def nullSafeEval(num: Any): Any = { + val result = Hex.unhex(num.asInstanceOf[UTF8String].getBytes) + if (failOnError && result == null) { + // The failOnError is set only from `ToBinary` function - hence we might safely set `hint` + // parameter to `try_to_binary`. + throw QueryExecutionErrors.invalidInputInConversionError( + BinaryType, + num.asInstanceOf[UTF8String], + UTF8String.fromString("HEX"), + "try_to_binary") + } + result + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (c) => { + nullSafeCodeGen(ctx, ev, c => { val hex = Hex.getClass.getName.stripSuffix("$") + val maybeFailOnErrorCode = if (failOnError) { + val format = UTF8String.fromString("BASE64"); + val binaryType = ctx.addReferenceObj("to", BinaryType, BinaryType.getClass.getName) + s""" + |if (${ev.value} == null) { + | throw QueryExecutionErrors.invalidInputInConversionError( + | $binaryType, + | $c, + | $format, + | "try_to_binary"); + |} + |""".stripMargin + } else { + s"${ev.isNull} = ${ev.value} == null;" + } + s""" ${ev.value} = $hex.unhex($c.getBytes()); - ${ev.isNull} = ${ev.value} == null; + $maybeFailOnErrorCode """ }) } - override protected def withNewChildInternal(newChild: Expression): Unhex = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): Unhex = + copy(child = newChild, failOnError) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index dffe0d56f33..1bc79f23846 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2300,24 +2300,105 @@ case class Base64(child: Expression) """, since = "1.5.0", group = "string_funcs") -case class UnBase64(child: Expression) +case class UnBase64(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) - protected override def nullSafeEval(string: Any): Any = + def this(expr: Expression) = this(expr, false) + + protected override def nullSafeEval(string: Any): Any = { + if (failOnError && !UnBase64.isValidBase64(string.asInstanceOf[UTF8String])) { + // The failOnError is set only from `ToBinary` function - hence we might safely set `hint` + // parameter to `try_to_binary`. + throw QueryExecutionErrors.invalidInputInConversionError( + BinaryType, + string.asInstanceOf[UTF8String], + UTF8String.fromString("BASE64"), + "try_to_binary") + } JBase64.getMimeDecoder.decode(string.asInstanceOf[UTF8String].toString) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (child) => { + nullSafeCodeGen(ctx, ev, child => { + val maybeValidateInputCode = if (failOnError) { + val unbase64 = UnBase64.getClass.getName.stripSuffix("$") + val format = UTF8String.fromString("BASE64"); + val binaryType = ctx.addReferenceObj("to", BinaryType, BinaryType.getClass.getName) + s""" + |if (!$unbase64.isValidBase64($child)) { + | throw QueryExecutionErrors.invalidInputInConversionError( + | $binaryType, + | $child, + | $format, + | "try_to_binary"); + |} + """.stripMargin + } else { + "" + } s""" + $maybeValidateInputCode ${ev.value} = ${classOf[JBase64].getName}.getMimeDecoder().decode($child.toString()); """}) } override protected def withNewChildInternal(newChild: Expression): UnBase64 = - copy(child = newChild) + copy(child = newChild, failOnError) +} + 
+object UnBase64 { + def isValidBase64(srcString: UTF8String) : Boolean = { + // We use RFC4648. The valid base64 string should contain zero or more groups of 4 symbols plus + // last group consisting of 2-4 valid symbols and optional padding. + // Last group should contain at least 2 valid symbols and up to 2 padding characters `=`. + // Valid symbols are A-Za-z0-9+/. Each group may contain an arbitrary number of + // whitespace characters, which are ignored; whitespace must not follow padding. + // If padding is present - last group should include exactly 4 symbols. + // Examples: + // "abcd" - Valid, single group of 4 valid symbols + // "abc d" - Valid, single group of 4 valid symbols, whitespace is skipped + // "abc?" - Invalid, group contains invalid symbol `?` + // "abcdA" - Invalid, last group should contain at least 2 valid symbols + // "abcdAE" - Valid, a group of 4 valid symbols and a group of 2 valid symbols + // "abcdAE==" - Valid, last group includes 2 padding symbols and total number of symbols + // in a group is 4. + // "abcdAE=" - Invalid, last group includes padding symbols, therefore it should have + // exactly 4 symbols but contains only 3. + // "ab==tm+1" - Invalid, nothing should be after padding. + var position = 0 + var padSize = 0 + for (c: Char <- srcString.toString) { + c match { + case a + if (a >= '0' && a <= '9') + || (a >= 'A' && a <= 'Z') + || (a >= 'a' && a <= 'z') + || a == '/' || a == '+' => + if (padSize != 0) return false // Padding symbols should conclude the string. + position += 1 + case '=' => + padSize += 1 + // Last group preceding padding should have 2 or more symbols. Padding may consist of + // at most 2 symbols. + if (padSize > 2 || position % 4 < 2) { + return false + } + case ws if Character.isWhitespace(ws) => + if (padSize != 0) { // Padding symbols should conclude the string. + return false + } + case _ => return false + } + } + if (padSize > 0) { // When padding is present last group should have exactly 4 symbols.
+ (position + padSize) % 4 == 0 + } else { // When padding is absent last group should include 2 or more symbols. + position % 4 != 1 + } + } } object Decode { @@ -2473,11 +2554,10 @@ case class Encode(value: Expression, charset: Expression) /** * Converts the input expression to a binary value based on the supplied format. */ -// scalastyle:off line.size.limit @ExpressionDescription( usage = """ _FUNC_(str[, fmt]) - Converts the input `str` to a binary value based on the supplied `fmt`. - `fmt` can be a case-insensitive string literal of "hex", "utf-8", or "base64". + `fmt` can be a case-insensitive string literal of "hex", "utf-8", "utf8", or "base64". By default, the binary format for conversion is "hex" if `fmt` is omitted. The function returns NULL if at least one of the input parameters is NULL. """, @@ -2488,12 +2568,11 @@ case class Encode(value: Expression, charset: Expression) """, since = "3.3.0", group = "string_funcs") -// scalastyle:on line.size.limit case class ToBinary( expr: Expression, format: Option[Expression], nullOnInvalidFormat: Boolean = false) extends RuntimeReplaceable - with ImplicitCastInputTypes { + with ImplicitCastInputTypes { override lazy val replacement: Expression = format.map { f => assert(f.foldable && (f.dataType == StringType || f.dataType == NullType)) @@ -2502,30 +2581,32 @@ case class ToBinary( Literal(null, BinaryType) } else { value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) match { - case "hex" => Unhex(expr) - case "utf-8" => Encode(expr, Literal("UTF-8")) - case "base64" => UnBase64(expr) + case "hex" => Unhex(expr, failOnError = true) + case "utf-8" | "utf8" => Encode(expr, Literal("UTF-8")) + case "base64" => UnBase64(expr, failOnError = true) case _ if nullOnInvalidFormat => Literal(null, BinaryType) case other => throw QueryCompilationErrors.invalidStringLiteralParameter( - "to_binary", "format", other, - Some("The value has to be a case-insensitive string literal of " + - "'hex', 'utf-8', or 
'base64'.")) + "to_binary", + "format", + other, + Some( + "The value has to be a case-insensitive string literal of " + + "'hex', 'utf-8', 'utf8', or 'base64'.")) } } - }.getOrElse(Unhex(expr)) + }.getOrElse(Unhex(expr, failOnError = true)) def this(expr: Expression) = this(expr, None, false) - def this(expr: Expression, format: Expression) = this(expr, Some({ + def this(expr: Expression, format: Expression) = + this(expr, Some({ // We perform this check in the constructor to make it eager and not go through type coercion. if (format.foldable && (format.dataType == StringType || format.dataType == NullType)) { format } else { throw QueryCompilationErrors.requireLiteralParameter("to_binary", "format", "string") } - }), - false - ) + }), false) override def prettyName: String = "to_binary" @@ -2535,11 +2616,11 @@ case class ToBinary( override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): Expression = { - if (format.isDefined) { - copy(expr = newChildren.head, format = Some(newChildren.last)) - } else { - copy(expr = newChildren.head) - } + if (format.isDefined) { + copy(expr = newChildren.head, format = Some(newChildren.last)) + } else { + copy(expr = newChildren.head) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ef4321a4fc7..d142be68b52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import 
org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window} -import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} +import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -795,6 +795,19 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { s"The '$argName' parameter of function '$funcName' needs to be a $requiredType literal.") } + def invalidFormatInConversion( + argName: String, + funcName: String, + expected: String, + context: SQLQueryContext): Throwable = { + new AnalysisException( + errorClass = "INVALID_PARAMETER_VALUE", + messageParameters = + Array(toSQLId(argName), toSQLId(funcName), expected), + context = getQueryContext(context), + summary = getSummary(context)) + } + def invalidStringLiteralParameter( funcName: String, argName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 8cb31f45c25..3dcefcc5368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -168,6 +168,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { summary = getSummary(context)) } + def invalidInputInConversionError( + to: DataType, + s: UTF8String, + fmt: UTF8String, + hint: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "CONVERSION_INVALID_INPUT", + messageParameters = Array( + toSQLValue(s, StringType), + toSQLValue(fmt, StringType), + toSQLType(to), + toSQLId(hint))) + } + def 
cannotCastFromNullTypeError(to: DataType): Throwable = { new SparkException(errorClass = "CANNOT_CAST_DATATYPE", messageParameters = Array(NullType.typeName, to.typeName), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index c8e99112a15..c741b685a34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -595,7 +595,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8)) checkEvaluation(Unhex(Literal("三重的")), null) // scalastyle:on - checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) + checkConsistencyBetweenInterpretedAndCodegen((e: Expression) => Unhex(e), StringType) } test("hypot") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index efbef2ab449..8af82efeab3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -174,12 +174,41 @@ select to_number('00,454.8-', '00,000.9MI'); select to_number('<00,454.8>', '00,000.9PR'); -- to_binary -select to_binary('abc'); -select to_binary('abc', 'utf-8'); -select to_binary('abc', 'base64'); -select to_binary('abc', 'hex'); +-- base64 valid +select to_binary('', 'base64'); +select to_binary(' ', 'base64'); +select to_binary(' ab cd ', 'base64'); +select to_binary(' ab c=', 'base64'); +select to_binary(' ab cdef= = ', 'base64'); +select to_binary( + concat(' b25lIHR3byB0aHJlZSBmb3VyIGZpdmUgc2l4IHNldmVuIGVpZ2h0IG5pbmUgdGVuIGVsZXZlbiB0', + 
'd2VsdmUgdGhpcnRlZW4gZm91cnRlZW4gZml2dGVlbiBzaXh0ZWVuIHNldmVudGVlbiBlaWdodGVl'), 'base64'); +-- base64 invalid +select to_binary('a', 'base64'); +select to_binary('a?', 'base64'); +select to_binary('abcde', 'base64'); +select to_binary('abcd=', 'base64'); +select to_binary('a===', 'base64'); +select to_binary('ab==f', 'base64'); +-- utf-8 +select to_binary( + '∮ E⋅da = Q, n → ∞, ∑ f(i) = ∏ g(i), ∀x∈ℝ: ⌈x⌉ = −⌊−x⌋, α ∧ ¬β = ¬(¬α ∨ β)', 'utf-8'); +select to_binary('大千世界', 'utf8'); +select to_binary('', 'utf-8'); +select to_binary(' ', 'utf8'); +-- hex valid +select to_binary('737472696E67'); +select to_binary('737472696E67', 'hex'); +select to_binary(''); +select to_binary('1', 'hex'); +select to_binary('FF'); +-- hex invalid +select to_binary('GG'); +select to_binary('01 AF', 'hex'); -- 'format' parameter can be any foldable string value, not just literal. select to_binary('abc', concat('utf', '-8')); +select to_binary(' ab cdef= = ', substr('base64whynot', 0, 6)); +select to_binary(' ab cdef= = ', replace('HEX0', '0')); -- 'format' parameter is case insensitive. select to_binary('abc', 'Hex'); -- null inputs lead to null result. @@ -187,10 +216,6 @@ select to_binary('abc', null); select to_binary(null, 'utf-8'); select to_binary(null, null); select to_binary(null, cast(null as string)); --- 'format' parameter must be string type or void type. 
-select to_binary(null, cast(null as int)); -select to_binary('abc', 1); -- invalid format +select to_binary('abc', 1); select to_binary('abc', 'invalidFormat'); --- invalid string input -select to_binary('a!', 'base64'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql index 20f02374e78..d21a80d482a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql @@ -1,10 +1,39 @@ -- try_to_binary -select try_to_binary('abc'); -select try_to_binary('abc', 'utf-8'); -select try_to_binary('abc', 'base64'); -select try_to_binary('abc', 'hex'); +-- base64 valid +select try_to_binary('', 'base64'); +select try_to_binary(' ', 'base64'); +select try_to_binary(' ab cd ', 'base64'); +select try_to_binary(' ab c=', 'base64'); +select try_to_binary(' ab cdef= = ', 'base64'); +select try_to_binary( + concat(' b25lIHR3byB0aHJlZSBmb3VyIGZpdmUgc2l4IHNldmVuIGVpZ2h0IG5pbmUgdGVuIGVsZXZlbiB0', + 'd2VsdmUgdGhpcnRlZW4gZm91cnRlZW4gZml2dGVlbiBzaXh0ZWVuIHNldmVudGVlbiBlaWdodGVl'), 'base64'); +-- base64 invalid +select try_to_binary('a', 'base64'); +select try_to_binary('a?', 'base64'); +select try_to_binary('abcde', 'base64'); +select try_to_binary('abcd=', 'base64'); +select try_to_binary('a===', 'base64'); +select try_to_binary('ab==f', 'base64'); +-- utf-8 +select try_to_binary( + '∮ E⋅da = Q, n → ∞, ∑ f(i) = ∏ g(i), ∀x∈ℝ: ⌈x⌉ = −⌊−x⌋, α ∧ ¬β = ¬(¬α ∨ β)', 'utf-8'); +select try_to_binary('大千世界', 'utf8'); +select try_to_binary('', 'utf-8'); +select try_to_binary(' ', 'utf8'); +-- hex valid +select try_to_binary('737472696E67'); +select try_to_binary('737472696E67', 'hex'); +select try_to_binary(''); +select try_to_binary('1', 'hex'); +select try_to_binary('FF'); +-- hex invalid +select try_to_binary('GG'); +select try_to_binary('01 AF', 'hex'); -- 'format' parameter can be any foldable 
string value, not just literal. select try_to_binary('abc', concat('utf', '-8')); +select try_to_binary(' ab cdef= = ', substr('base64whynot', 0, 6)); +select try_to_binary(' ab cdef= = ', replace('HEX0', '0')); -- 'format' parameter is case insensitive. select try_to_binary('abc', 'Hex'); -- null inputs lead to null result. @@ -12,10 +41,6 @@ select try_to_binary('abc', null); select try_to_binary(null, 'utf-8'); select try_to_binary(null, null); select try_to_binary(null, cast(null as string)); --- 'format' parameter must be string type or void type. -select try_to_binary(null, cast(null as int)); -select try_to_binary('abc', 1); -- invalid format -select try_to_binary('abc', 'invalidFormat'); --- invalid string input -select try_to_binary('a!', 'base64'); +select try_to_binary('abc', 1); +select try_to_binary('abc', 'invalidFormat'); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index d08084a39c3..810f1942be2 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1147,35 +1147,271 @@ struct<to_number(<00,454.8>, 00,000.9PR):decimal(6,1)> -- !query -select to_binary('abc') +select to_binary('', 'base64') -- !query schema -struct<to_binary(abc):binary> +struct<to_binary(, base64):binary> -- !query output -�tion -The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7 - - -- !query select to_binary('abc', 1) -- !query schema @@ -1250,13 +1511,4 @@ select to_binary('abc', 'invalidFormat') struct<> -- !query output org.apache.spark.sql.AnalysisException -Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', or 'base64'. 
- - --- !query -select to_binary('a!', 'base64') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Last unit does not have enough valid bits +Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'. diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index d96000c2dff..a8ad802dd98 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1079,35 +1079,271 @@ struct<to_number(<00,454.8>, 00,000.9PR):decimal(6,1)> -- !query -select to_binary('abc') +select to_binary('', 'base64') -- !query schema -struct<to_binary(abc):binary> +struct<to_binary(, base64):binary> -- !query output -�ULL AS STRING)):binary> NULL --- !query -select to_binary(null, cast(null as int)) --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7 - - -- !query select to_binary('abc', 1) -- !query schema @@ -1182,13 +1443,4 @@ select to_binary('abc', 'invalidFormat') struct<> -- !query output org.apache.spark.sql.AnalysisException -Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', or 'base64'. - - --- !query -select to_binary('a!', 'base64') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Last unit does not have enough valid bits +Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'. 
diff --git a/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out index b3d3197ee7d..dacbc08a103 100644 Binary files a/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out and b/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index abb64f0f4a7..1b5fa2aa890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.permission.FsPermission import org.mockito.Mockito.{mock, when} import test.org.apache.spark.sql.connector.JavaSimpleWritableDataSource -import org.apache.spark.{SparkArithmeticException, SparkClassNotFoundException, SparkException, SparkFileNotFoundException, SparkIllegalArgumentException, SparkRuntimeException, SparkSecurityException, SparkSQLException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark._ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.util.BadRecordException import org.apache.spark.sql.connector.SimpleWritableDataSource @@ -53,6 +53,19 @@ class QueryExecutionErrorsSuite import testImplicits._ + test("CONVERSION_INVALID_INPUT: to_binary conversion function") { + checkError( + exception = intercept[SparkIllegalArgumentException] { + sql("select to_binary('???', 'base64')").collect() + }, + errorClass = "CONVERSION_INVALID_INPUT", + parameters = Map( + "str" -> "'???'", + "fmt" -> "'BASE64'", + "targetType" -> "\"BINARY\"", + "suggestion" -> "`try_to_binary`")) + } + private def getAesInputs(): (DataFrame, DataFrame) = { val 
encryptedText16 = "4Hv0UKCx6nfUeAoPZo1z+w==" val encryptedText24 = "NeTYNgA+PCQBN50DA//O2w==" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org