This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new 8bd7d886e05 [SPARK-38796][SQL] Update to_number and try_to_number functions to restrict S and MI sequence to start or end only 8bd7d886e05 is described below commit 8bd7d886e0570ed6d01ebbadca83c77821aee93f Author: Daniel Tenedorio <daniel.tenedo...@databricks.com> AuthorDate: Tue Apr 19 11:18:56 2022 +0800 [SPARK-38796][SQL] Update to_number and try_to_number functions to restrict S and MI sequence to start or end only ### What changes were proposed in this pull request? Update `to_number` and `try_to_number` functions to restrict MI sequence to start or end only. This satisfies the following specification: ``` to_number(expr, fmt) fmt { ' [ MI | S ] [ L | $ ] [ 0 | 9 | G | , ] [...] [ . | D ] [ 0 | 9 ] [...] [ L | $ ] [ PR | MI | S ] ' } ``` ### Why are the changes needed? After reviewing the specification, this behavior makes the most sense. ### Does this PR introduce _any_ user-facing change? Yes, a slight change in the behavior of the format string. ### How was this patch tested? Existing and updated unit test coverage. Closes #36154 from dtenedor/mi-anywhere. 
Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 242ee22c00394c29e21bc3de0a93cb6d9746d93c) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/numberFormatExpressions.scala | 4 +- .../spark/sql/catalyst/util/ToNumberParser.scala | 163 ++++++++++++--------- .../expressions/StringExpressionsSuite.scala | 20 +-- 3 files changed, 106 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 88947c5c87a..c866bb9af9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -46,8 +46,8 @@ import org.apache.spark.unsafe.types.UTF8String grouping separator relevant for the size of the number. '$': Specifies the location of the $ currency sign. This character may only be specified once. - 'S': Specifies the position of a '-' or '+' sign (optional, only allowed once). - 'MI': Specifies that 'expr' has an optional '-' sign, but no '+' (only allowed once). + 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at + the beginning or end of the format string). Note that 'S' allows '+' but 'MI' does not. 'PR': Only allowed at the end of the format string; specifies that 'expr' indicates a negative number with wrapping angled brackets. ('<1>'). 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala index afba683efad..716224983e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala @@ -49,33 +49,56 @@ object ToNumberParser { final val WRAPPING_ANGLE_BRACKETS_TO_NEGATIVE_NUMBER_END = 'R' // This class represents one or more characters that we expect to be present in the input string - // based on the format string. + // based on the format string. The toString method returns a representation of each token suitable + // for use in error messages. abstract class InputToken() // Represents some number of digits (0-9). abstract class Digits extends InputToken // Represents exactly 'num' digits (0-9). - case class ExactlyAsManyDigits(num: Int) extends Digits + case class ExactlyAsManyDigits(num: Int) extends Digits { + override def toString: String = "digit sequence" + } // Represents at most 'num' digits (0-9). - case class AtMostAsManyDigits(num: Int) extends Digits + case class AtMostAsManyDigits(num: Int) extends Digits { + override def toString: String = "digit sequence" + } // Represents one decimal point (.). - case class DecimalPoint() extends InputToken + case class DecimalPoint() extends InputToken { + override def toString: String = ". or D" + } // Represents one thousands separator (,). - case class ThousandsSeparator() extends InputToken + case class ThousandsSeparator() extends InputToken { + override def toString: String = ", or G" + } // Represents one or more groups of Digits (0-9) with ThousandsSeparators (,) between each group. // The 'tokens' are the Digits and ThousandsSeparators in order; the 'digits' are just the Digits. 
- case class DigitGroups(tokens: Seq[InputToken], digits: Seq[Digits]) extends InputToken + case class DigitGroups(tokens: Seq[InputToken], digits: Seq[Digits]) extends InputToken { + override def toString: String = "digit sequence" + } // Represents one dollar sign ($). - case class DollarSign() extends InputToken + case class DollarSign() extends InputToken { + override def toString: String = "$" + } // Represents one optional plus sign (+) or minus sign (-). - case class OptionalPlusOrMinusSign() extends InputToken + case class OptionalPlusOrMinusSign() extends InputToken { + override def toString: String = "S" + } // Represents one optional minus sign (-). - case class OptionalMinusSign() extends InputToken + case class OptionalMinusSign() extends InputToken { + override def toString: String = "MI" + } // Represents one opening angle bracket (<). - case class OpeningAngleBracket() extends InputToken + case class OpeningAngleBracket() extends InputToken { + override def toString: String = "PR" + } // Represents one closing angle bracket (>). - case class ClosingAngleBracket() extends InputToken + case class ClosingAngleBracket() extends InputToken { + override def toString: String = "PR" + } // Represents any unrecognized character other than the above. - case class InvalidUnrecognizedCharacter(char: Char) extends InputToken + case class InvalidUnrecognizedCharacter(char: Char) extends InputToken { + override def toString: String = s"character '$char''" + } } /** @@ -241,16 +264,6 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali * This implementation of the [[check]] method returns any error, or the empty string on success. 
*/ private def validateFormatString: String = { - def multipleSignInNumberFormatError(message: String) = { - s"At most one $message is allowed in the number format: '$numberFormat'" - } - - def notAtEndOfNumberFormatError(message: String) = { - s"$message must be at the end of the number format: '$numberFormat'" - } - - val inputTokenCounts = formatTokens.groupBy(identity).mapValues(_.size) - val firstDollarSignIndex: Int = formatTokens.indexOf(DollarSign()) val firstDigitIndex: Int = formatTokens.indexWhere { case _: DigitGroups => true @@ -276,58 +289,25 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali // Make sure the format string contains at least one token. if (numberFormat.isEmpty) { - "The format string cannot be empty" - } - // Make sure the format string does not contain any unrecognized characters. - else if (formatTokens.exists(_.isInstanceOf[InvalidUnrecognizedCharacter])) { - val unrecognizedChars = - formatTokens.filter { - _.isInstanceOf[InvalidUnrecognizedCharacter] - }.map { - case i: InvalidUnrecognizedCharacter => i.char - } - val char: Char = unrecognizedChars.head - s"Encountered invalid character $char in the number format: '$numberFormat'" + return "The format string cannot be empty" } // Make sure the format string contains at least one digit. - else if (!formatTokens.exists( + if (!formatTokens.exists( token => token.isInstanceOf[DigitGroups])) { - "The format string requires at least one number digit" - } - // Make sure the format string contains at most one decimal point. - else if (inputTokenCounts.getOrElse(DecimalPoint(), 0) > 1) { - multipleSignInNumberFormatError(s"'$POINT_LETTER' or '$POINT_SIGN'") - } - // Make sure the format string contains at most one plus or minus sign. - else if (inputTokenCounts.getOrElse(OptionalPlusOrMinusSign(), 0) > 1) { - multipleSignInNumberFormatError(s"'$OPTIONAL_PLUS_OR_MINUS_LETTER'") - } - // Make sure the format string contains at most one dollar sign. 
- else if (inputTokenCounts.getOrElse(DollarSign(), 0) > 1) { - multipleSignInNumberFormatError(s"'$DOLLAR_SIGN'") - } - // Make sure the format string contains at most one "MI" sequence. - else if (inputTokenCounts.getOrElse(OptionalMinusSign(), 0) > 1) { - multipleSignInNumberFormatError(s"'$OPTIONAL_MINUS_STRING'") - } - // Make sure the format string contains at most one closing angle bracket at the end. - else if (inputTokenCounts.getOrElse(ClosingAngleBracket(), 0) > 1 || - (inputTokenCounts.getOrElse(ClosingAngleBracket(), 0) == 1 && - formatTokens.last != ClosingAngleBracket())) { - notAtEndOfNumberFormatError(s"'$WRAPPING_ANGLE_BRACKETS_TO_NEGATIVE_NUMBER'") + return "The format string requires at least one number digit" } // Make sure that any dollar sign in the format string occurs before any digits. - else if (firstDigitIndex < firstDollarSignIndex) { - s"Currency characters must appear before digits in the number format: '$numberFormat'" + if (firstDigitIndex < firstDollarSignIndex) { + return s"Currency characters must appear before digits in the number format: '$numberFormat'" } // Make sure that any dollar sign in the format string occurs before any decimal point. - else if (firstDecimalPointIndex != -1 && + if (firstDecimalPointIndex != -1 && firstDecimalPointIndex < firstDollarSignIndex) { - "Currency characters must appear before any decimal point in the " + + return "Currency characters must appear before any decimal point in the " + s"number format: '$numberFormat'" } // Make sure that any thousands separators in the format string have digits before and after. 
- else if (digitGroupsBeforeDecimalPoint.exists { + if (digitGroupsBeforeDecimalPoint.exists { case DigitGroups(tokens, _) => tokens.zipWithIndex.exists({ case (_: ThousandsSeparator, j: Int) if j == 0 || j == tokens.length - 1 => @@ -340,21 +320,64 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali false }) }) { - "Thousands separators (,) must have digits in between them " + + return "Thousands separators (,) must have digits in between them " + s"in the number format: '$numberFormat'" } - // Thousands separators are not allowed after the decimal point, if any. - else if (digitGroupsAfterDecimalPoint.exists { + // Make sure that thousands separators do not appear after the decimal point, if any. + if (digitGroupsAfterDecimalPoint.exists { case DigitGroups(tokens, digits) => tokens.length > digits.length }) { - "Thousands separators (,) may not appear after the decimal point " + + return "Thousands separators (,) may not appear after the decimal point " + s"in the number format: '$numberFormat'" } - // Validation of the format string finished successfully. - else { - "" + // Make sure that the format string does not contain any prohibited duplicate tokens. + val inputTokenCounts = formatTokens.groupBy(identity).mapValues(_.size) + Seq(DecimalPoint(), + OptionalPlusOrMinusSign(), + OptionalMinusSign(), + DollarSign(), + ClosingAngleBracket()).foreach { + token => if (inputTokenCounts.getOrElse(token, 0) > 1) { + return s"At most one ${token.toString} is allowed in the number format: '$numberFormat'" + } + } + // Enforce the ordering of tokens in the format string according to this specification: + // [ MI | S ] [ $ ] + // [ 0 | 9 | G | , ] [...] + // [ . | D ] + // [ 0 | 9 ] [...] 
+ // [ $ ] [ PR | MI | S ] + val allowedFormatTokens: Seq[Seq[InputToken]] = Seq( + Seq(OpeningAngleBracket()), + Seq(OptionalMinusSign(), OptionalPlusOrMinusSign()), + Seq(DollarSign()), + Seq(DigitGroups(Seq(), Seq())), + Seq(DecimalPoint()), + Seq(DigitGroups(Seq(), Seq())), + Seq(DollarSign()), + Seq(OptionalMinusSign(), OptionalPlusOrMinusSign(), ClosingAngleBracket()) + ) + var formatTokenIndex = 0 + for (allowedTokens: Seq[InputToken] <- allowedFormatTokens) { + def tokensMatch(lhs: InputToken, rhs: InputToken): Boolean = { + lhs match { + case _: DigitGroups => rhs.isInstanceOf[DigitGroups] + case _ => lhs == rhs + } + } + if (formatTokenIndex < formatTokens.length && + allowedTokens.exists(tokensMatch(_, formatTokens(formatTokenIndex)))) { + formatTokenIndex += 1 + } } + if (formatTokenIndex < formatTokens.length) { + return s"Unexpected ${formatTokens(formatTokenIndex).toString} found in the format string " + + s"'$numberFormat'; the structure of the format string must match: " + + "[MI|S] [$] [0|9|G|,]* [.|D] [0|9]* [$] [PR|MI|S]" + } + // Validation of the format string finished successfully. + "" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index afb05dd4d77..91b3d0c69b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -972,7 +972,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ("<454>", "999PR") -> Decimal(-454), ("454-", "999MI") -> Decimal(-454), ("-$54", "MI$99") -> Decimal(-54), - ("$4-4", "$9MI9") -> Decimal(-44), // The input string contains more digits than fit in a long integer. 
("123,456,789,123,456,789,123", "999,999,999,999,999,999,999") -> Decimal(new JavaBigDecimal("123456789123456789123")) @@ -1009,7 +1008,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("ToNumber: negative tests (the format string is invalid)") { - val invalidCharacter = "Encountered invalid character" + val unexpectedCharacter = "the structure of the format string must match: " + + "[MI|S] [$] [0|9|G|,]* [.|D] [0|9]* [$] [PR|MI|S]" val thousandsSeparatorDigitsBetween = "Thousands separators (,) must have digits in between them" val mustBeAtEnd = "must be at the end of the number format" @@ -1018,23 +1018,25 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // The format string must not be empty. ("454", "") -> "The format string cannot be empty", // Make sure the format string does not contain any unrecognized characters. - ("454", "999@") -> invalidCharacter, - ("454", "999M") -> invalidCharacter, - ("454", "999P") -> invalidCharacter, + ("454", "999@") -> unexpectedCharacter, + ("454", "999M") -> unexpectedCharacter, + ("454", "999P") -> unexpectedCharacter, // Make sure the format string contains at least one digit. ("454", "$") -> "The format string requires at least one number digit", // Make sure the format string contains at most one decimal point. ("454", "99.99.99") -> atMostOne, // Make sure the format string contains at most one dollar sign. ("454", "$$99") -> atMostOne, - // Make sure the format string contains at most one minus sign at the end. + // Make sure the format string contains at most one minus sign at the beginning or end. + ("$4-4", "$9MI9") -> unexpectedCharacter, + ("--4", "SMI9") -> unexpectedCharacter, ("--$54", "SS$99") -> atMostOne, ("-$54", "MI$99MI") -> atMostOne, ("$4-4", "$9MI9MI") -> atMostOne, // Make sure the format string contains at most one closing angle bracket at the end. 
- ("<$45>", "PR$99") -> mustBeAtEnd, - ("$4<4>", "$9PR9") -> mustBeAtEnd, - ("<<454>>", "999PRPR") -> mustBeAtEnd, + ("<$45>", "PR$99") -> unexpectedCharacter, + ("$4<4>", "$9PR9") -> unexpectedCharacter, + ("<<454>>", "999PRPR") -> atMostOne, // Make sure that any dollar sign in the format string occurs before any digits. ("4$54", "9$99") -> "Currency characters must appear before digits", // Make sure that any dollar sign in the format string occurs before any decimal point. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org