This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 8bd7d886e05 [SPARK-38796][SQL] Update to_number and try_to_number
functions to restrict S and MI sequence to start or end only
8bd7d886e05 is described below
commit 8bd7d886e0570ed6d01ebbadca83c77821aee93f
Author: Daniel Tenedorio <[email protected]>
AuthorDate: Tue Apr 19 11:18:56 2022 +0800
[SPARK-38796][SQL] Update to_number and try_to_number functions to restrict
S and MI sequence to start or end only
### What changes were proposed in this pull request?
Update the `to_number` and `try_to_number` functions to restrict the S and MI
sequences to the start or end of the format string only.
This satisfies the following specification:
```
to_number(expr, fmt)
fmt
{ ' [ MI | S ] [ L | $ ]
[ 0 | 9 | G | , ] [...]
[ . | D ]
[ 0 | 9 ] [...]
[ L | $ ] [ PR | MI | S ] ' }
```
### Why are the changes needed?
After reviewing the specification, this behavior makes the most sense.
### Does this PR introduce _any_ user-facing change?
Yes, a slight change in the behavior of the format string.
### How was this patch tested?
Existing and updated unit test coverage.
Closes #36154 from dtenedor/mi-anywhere.
Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 242ee22c00394c29e21bc3de0a93cb6d9746d93c)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/numberFormatExpressions.scala | 4 +-
.../spark/sql/catalyst/util/ToNumberParser.scala | 163 ++++++++++++---------
.../expressions/StringExpressionsSuite.scala | 20 +--
3 files changed, 106 insertions(+), 81 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
index 88947c5c87a..c866bb9af9e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
@@ -46,8 +46,8 @@ import org.apache.spark.unsafe.types.UTF8String
grouping separator relevant for the size of the number.
'$': Specifies the location of the $ currency sign. This character
may only be specified
once.
- 'S': Specifies the position of a '-' or '+' sign (optional, only
allowed once).
- 'MI': Specifies that 'expr' has an optional '-' sign, but no '+'
(only allowed once).
+ 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional,
only allowed once at
+ the beginning or end of the format string). Note that 'S' allows
'+' but 'MI' does not.
'PR': Only allowed at the end of the format string; specifies that
'expr' indicates a
negative number with wrapping angled brackets.
('<1>').
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
index afba683efad..716224983e0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
@@ -49,33 +49,56 @@ object ToNumberParser {
final val WRAPPING_ANGLE_BRACKETS_TO_NEGATIVE_NUMBER_END = 'R'
// This class represents one or more characters that we expect to be present
in the input string
- // based on the format string.
+ // based on the format string. The toString method returns a representation
of each token suitable
+ // for use in error messages.
abstract class InputToken()
// Represents some number of digits (0-9).
abstract class Digits extends InputToken
// Represents exactly 'num' digits (0-9).
- case class ExactlyAsManyDigits(num: Int) extends Digits
+ case class ExactlyAsManyDigits(num: Int) extends Digits {
+ override def toString: String = "digit sequence"
+ }
// Represents at most 'num' digits (0-9).
- case class AtMostAsManyDigits(num: Int) extends Digits
+ case class AtMostAsManyDigits(num: Int) extends Digits {
+ override def toString: String = "digit sequence"
+ }
// Represents one decimal point (.).
- case class DecimalPoint() extends InputToken
+ case class DecimalPoint() extends InputToken {
+ override def toString: String = ". or D"
+ }
// Represents one thousands separator (,).
- case class ThousandsSeparator() extends InputToken
+ case class ThousandsSeparator() extends InputToken {
+ override def toString: String = ", or G"
+ }
// Represents one or more groups of Digits (0-9) with ThousandsSeparators
(,) between each group.
// The 'tokens' are the Digits and ThousandsSeparators in order; the
'digits' are just the Digits.
- case class DigitGroups(tokens: Seq[InputToken], digits: Seq[Digits]) extends
InputToken
+ case class DigitGroups(tokens: Seq[InputToken], digits: Seq[Digits]) extends
InputToken {
+ override def toString: String = "digit sequence"
+ }
// Represents one dollar sign ($).
- case class DollarSign() extends InputToken
+ case class DollarSign() extends InputToken {
+ override def toString: String = "$"
+ }
// Represents one optional plus sign (+) or minus sign (-).
- case class OptionalPlusOrMinusSign() extends InputToken
+ case class OptionalPlusOrMinusSign() extends InputToken {
+ override def toString: String = "S"
+ }
// Represents one optional minus sign (-).
- case class OptionalMinusSign() extends InputToken
+ case class OptionalMinusSign() extends InputToken {
+ override def toString: String = "MI"
+ }
// Represents one opening angle bracket (<).
- case class OpeningAngleBracket() extends InputToken
+ case class OpeningAngleBracket() extends InputToken {
+ override def toString: String = "PR"
+ }
// Represents one closing angle bracket (>).
- case class ClosingAngleBracket() extends InputToken
+ case class ClosingAngleBracket() extends InputToken {
+ override def toString: String = "PR"
+ }
// Represents any unrecognized character other than the above.
- case class InvalidUnrecognizedCharacter(char: Char) extends InputToken
+ case class InvalidUnrecognizedCharacter(char: Char) extends InputToken {
+ override def toString: String = s"character '$char''"
+ }
}
/**
@@ -241,16 +264,6 @@ class ToNumberParser(numberFormat: String, errorOnFail:
Boolean) extends Seriali
* This implementation of the [[check]] method returns any error, or the
empty string on success.
*/
private def validateFormatString: String = {
- def multipleSignInNumberFormatError(message: String) = {
- s"At most one $message is allowed in the number format: '$numberFormat'"
- }
-
- def notAtEndOfNumberFormatError(message: String) = {
- s"$message must be at the end of the number format: '$numberFormat'"
- }
-
- val inputTokenCounts = formatTokens.groupBy(identity).mapValues(_.size)
-
val firstDollarSignIndex: Int = formatTokens.indexOf(DollarSign())
val firstDigitIndex: Int = formatTokens.indexWhere {
case _: DigitGroups => true
@@ -276,58 +289,25 @@ class ToNumberParser(numberFormat: String, errorOnFail:
Boolean) extends Seriali
// Make sure the format string contains at least one token.
if (numberFormat.isEmpty) {
- "The format string cannot be empty"
- }
- // Make sure the format string does not contain any unrecognized
characters.
- else if
(formatTokens.exists(_.isInstanceOf[InvalidUnrecognizedCharacter])) {
- val unrecognizedChars =
- formatTokens.filter {
- _.isInstanceOf[InvalidUnrecognizedCharacter]
- }.map {
- case i: InvalidUnrecognizedCharacter => i.char
- }
- val char: Char = unrecognizedChars.head
- s"Encountered invalid character $char in the number format:
'$numberFormat'"
+ return "The format string cannot be empty"
}
// Make sure the format string contains at least one digit.
- else if (!formatTokens.exists(
+ if (!formatTokens.exists(
token => token.isInstanceOf[DigitGroups])) {
- "The format string requires at least one number digit"
- }
- // Make sure the format string contains at most one decimal point.
- else if (inputTokenCounts.getOrElse(DecimalPoint(), 0) > 1) {
- multipleSignInNumberFormatError(s"'$POINT_LETTER' or '$POINT_SIGN'")
- }
- // Make sure the format string contains at most one plus or minus sign.
- else if (inputTokenCounts.getOrElse(OptionalPlusOrMinusSign(), 0) > 1) {
- multipleSignInNumberFormatError(s"'$OPTIONAL_PLUS_OR_MINUS_LETTER'")
- }
- // Make sure the format string contains at most one dollar sign.
- else if (inputTokenCounts.getOrElse(DollarSign(), 0) > 1) {
- multipleSignInNumberFormatError(s"'$DOLLAR_SIGN'")
- }
- // Make sure the format string contains at most one "MI" sequence.
- else if (inputTokenCounts.getOrElse(OptionalMinusSign(), 0) > 1) {
- multipleSignInNumberFormatError(s"'$OPTIONAL_MINUS_STRING'")
- }
- // Make sure the format string contains at most one closing angle bracket
at the end.
- else if (inputTokenCounts.getOrElse(ClosingAngleBracket(), 0) > 1 ||
- (inputTokenCounts.getOrElse(ClosingAngleBracket(), 0) == 1 &&
- formatTokens.last != ClosingAngleBracket())) {
-
notAtEndOfNumberFormatError(s"'$WRAPPING_ANGLE_BRACKETS_TO_NEGATIVE_NUMBER'")
+ return "The format string requires at least one number digit"
}
// Make sure that any dollar sign in the format string occurs before any
digits.
- else if (firstDigitIndex < firstDollarSignIndex) {
- s"Currency characters must appear before digits in the number format:
'$numberFormat'"
+ if (firstDigitIndex < firstDollarSignIndex) {
+ return s"Currency characters must appear before digits in the number
format: '$numberFormat'"
}
// Make sure that any dollar sign in the format string occurs before any
decimal point.
- else if (firstDecimalPointIndex != -1 &&
+ if (firstDecimalPointIndex != -1 &&
firstDecimalPointIndex < firstDollarSignIndex) {
- "Currency characters must appear before any decimal point in the " +
+ return "Currency characters must appear before any decimal point in the
" +
s"number format: '$numberFormat'"
}
// Make sure that any thousands separators in the format string have
digits before and after.
- else if (digitGroupsBeforeDecimalPoint.exists {
+ if (digitGroupsBeforeDecimalPoint.exists {
case DigitGroups(tokens, _) =>
tokens.zipWithIndex.exists({
case (_: ThousandsSeparator, j: Int) if j == 0 || j == tokens.length
- 1 =>
@@ -340,21 +320,64 @@ class ToNumberParser(numberFormat: String, errorOnFail:
Boolean) extends Seriali
false
})
}) {
- "Thousands separators (,) must have digits in between them " +
+ return "Thousands separators (,) must have digits in between them " +
s"in the number format: '$numberFormat'"
}
- // Thousands separators are not allowed after the decimal point, if any.
- else if (digitGroupsAfterDecimalPoint.exists {
+ // Make sure that thousands separators do not appear after the decimal
point, if any.
+ if (digitGroupsAfterDecimalPoint.exists {
case DigitGroups(tokens, digits) =>
tokens.length > digits.length
}) {
- "Thousands separators (,) may not appear after the decimal point " +
+ return "Thousands separators (,) may not appear after the decimal point
" +
s"in the number format: '$numberFormat'"
}
- // Validation of the format string finished successfully.
- else {
- ""
+ // Make sure that the format string does not contain any prohibited
duplicate tokens.
+ val inputTokenCounts = formatTokens.groupBy(identity).mapValues(_.size)
+ Seq(DecimalPoint(),
+ OptionalPlusOrMinusSign(),
+ OptionalMinusSign(),
+ DollarSign(),
+ ClosingAngleBracket()).foreach {
+ token => if (inputTokenCounts.getOrElse(token, 0) > 1) {
+ return s"At most one ${token.toString} is allowed in the number
format: '$numberFormat'"
+ }
+ }
+ // Enforce the ordering of tokens in the format string according to this
specification:
+ // [ MI | S ] [ $ ]
+ // [ 0 | 9 | G | , ] [...]
+ // [ . | D ]
+ // [ 0 | 9 ] [...]
+ // [ $ ] [ PR | MI | S ]
+ val allowedFormatTokens: Seq[Seq[InputToken]] = Seq(
+ Seq(OpeningAngleBracket()),
+ Seq(OptionalMinusSign(), OptionalPlusOrMinusSign()),
+ Seq(DollarSign()),
+ Seq(DigitGroups(Seq(), Seq())),
+ Seq(DecimalPoint()),
+ Seq(DigitGroups(Seq(), Seq())),
+ Seq(DollarSign()),
+ Seq(OptionalMinusSign(), OptionalPlusOrMinusSign(),
ClosingAngleBracket())
+ )
+ var formatTokenIndex = 0
+ for (allowedTokens: Seq[InputToken] <- allowedFormatTokens) {
+ def tokensMatch(lhs: InputToken, rhs: InputToken): Boolean = {
+ lhs match {
+ case _: DigitGroups => rhs.isInstanceOf[DigitGroups]
+ case _ => lhs == rhs
+ }
+ }
+ if (formatTokenIndex < formatTokens.length &&
+ allowedTokens.exists(tokensMatch(_, formatTokens(formatTokenIndex)))) {
+ formatTokenIndex += 1
+ }
}
+ if (formatTokenIndex < formatTokens.length) {
+ return s"Unexpected ${formatTokens(formatTokenIndex).toString} found in
the format string " +
+ s"'$numberFormat'; the structure of the format string must match: " +
+ "[MI|S] [$] [0|9|G|,]* [.|D] [0|9]* [$] [PR|MI|S]"
+ }
+ // Validation of the format string finished successfully.
+ ""
}
/**
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index afb05dd4d77..91b3d0c69b8 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -972,7 +972,6 @@ class StringExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
("<454>", "999PR") -> Decimal(-454),
("454-", "999MI") -> Decimal(-454),
("-$54", "MI$99") -> Decimal(-54),
- ("$4-4", "$9MI9") -> Decimal(-44),
// The input string contains more digits than fit in a long integer.
("123,456,789,123,456,789,123", "999,999,999,999,999,999,999") ->
Decimal(new JavaBigDecimal("123456789123456789123"))
@@ -1009,7 +1008,8 @@ class StringExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
}
test("ToNumber: negative tests (the format string is invalid)") {
- val invalidCharacter = "Encountered invalid character"
+ val unexpectedCharacter = "the structure of the format string must match:
" +
+ "[MI|S] [$] [0|9|G|,]* [.|D] [0|9]* [$] [PR|MI|S]"
val thousandsSeparatorDigitsBetween =
"Thousands separators (,) must have digits in between them"
val mustBeAtEnd = "must be at the end of the number format"
@@ -1018,23 +1018,25 @@ class StringExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
// The format string must not be empty.
("454", "") -> "The format string cannot be empty",
// Make sure the format string does not contain any unrecognized
characters.
- ("454", "999@") -> invalidCharacter,
- ("454", "999M") -> invalidCharacter,
- ("454", "999P") -> invalidCharacter,
+ ("454", "999@") -> unexpectedCharacter,
+ ("454", "999M") -> unexpectedCharacter,
+ ("454", "999P") -> unexpectedCharacter,
// Make sure the format string contains at least one digit.
("454", "$") -> "The format string requires at least one number digit",
// Make sure the format string contains at most one decimal point.
("454", "99.99.99") -> atMostOne,
// Make sure the format string contains at most one dollar sign.
("454", "$$99") -> atMostOne,
- // Make sure the format string contains at most one minus sign at the
end.
+ // Make sure the format string contains at most one minus sign at the
beginning or end.
+ ("$4-4", "$9MI9") -> unexpectedCharacter,
+ ("--4", "SMI9") -> unexpectedCharacter,
("--$54", "SS$99") -> atMostOne,
("-$54", "MI$99MI") -> atMostOne,
("$4-4", "$9MI9MI") -> atMostOne,
// Make sure the format string contains at most one closing angle
bracket at the end.
- ("<$45>", "PR$99") -> mustBeAtEnd,
- ("$4<4>", "$9PR9") -> mustBeAtEnd,
- ("<<454>>", "999PRPR") -> mustBeAtEnd,
+ ("<$45>", "PR$99") -> unexpectedCharacter,
+ ("$4<4>", "$9PR9") -> unexpectedCharacter,
+ ("<<454>>", "999PRPR") -> atMostOne,
// Make sure that any dollar sign in the format string occurs before any
digits.
("4$54", "9$99") -> "Currency characters must appear before digits",
// Make sure that any dollar sign in the format string occurs before any
decimal point.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]