This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 8bd7d886e05 [SPARK-38796][SQL] Update to_number and try_to_number 
functions to restrict S and MI sequence to start or end only
8bd7d886e05 is described below

commit 8bd7d886e0570ed6d01ebbadca83c77821aee93f
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Tue Apr 19 11:18:56 2022 +0800

    [SPARK-38796][SQL] Update to_number and try_to_number functions to restrict 
S and MI sequence to start or end only
    
    ### What changes were proposed in this pull request?
    
    Update the `to_number` and `try_to_number` functions to restrict the `S` and `MI` sequences
    to the start or end of the format string only.
    
    This satisfies the following specification:
    
    ```
    to_number(expr, fmt)
    fmt
      { ' [ MI | S ] [ L | $ ]
          [ 0 | 9 | G | , ] [...]
          [ . | D ]
          [ 0 | 9 ] [...]
          [ L | $ ] [ PR | MI | S ] ' }
    ```
    
    ### Why are the changes needed?
    
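    After reviewing the specification, we concluded that this behavior makes the most sense. For
    example, a format string such as `$9MI9`, which places the sign between digits, is now rejected.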
    
    ### Does this PR introduce _any_ user-facing change?
    
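    Yes, a slight change in the behavior of the format string: `S` and `MI` are now only accepted at
    the beginning or end of the format string (for example, `$9MI9` is now rejected), and an
    invalid format string produces an error describing the structure it must match.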
    
    ### How was this patch tested?
    
    Existing and updated unit test coverage in `StringExpressionsSuite`.
    
    Closes #36154 from dtenedor/mi-anywhere.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 242ee22c00394c29e21bc3de0a93cb6d9746d93c)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/numberFormatExpressions.scala      |   4 +-
 .../spark/sql/catalyst/util/ToNumberParser.scala   | 163 ++++++++++++---------
 .../expressions/StringExpressionsSuite.scala       |  20 +--
 3 files changed, 106 insertions(+), 81 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
index 88947c5c87a..c866bb9af9e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
@@ -46,8 +46,8 @@ import org.apache.spark.unsafe.types.UTF8String
            grouping separator relevant for the size of the number.
          '$': Specifies the location of the $ currency sign. This character 
may only be specified
            once.
-         'S': Specifies the position of a '-' or '+' sign (optional, only 
allowed once).
-         'MI': Specifies that 'expr' has an optional '-' sign, but no '+' 
(only allowed once).
+         'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, 
only allowed once at
+           the beginning or end of the format string). Note that 'S' allows 
'-' but 'MI' does not.
          'PR': Only allowed at the end of the format string; specifies that 
'expr' indicates a
            negative number with wrapping angled brackets.
            ('<1>').
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
index afba683efad..716224983e0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
@@ -49,33 +49,56 @@ object ToNumberParser {
   final val WRAPPING_ANGLE_BRACKETS_TO_NEGATIVE_NUMBER_END = 'R'
 
   // This class represents one or more characters that we expect to be present 
in the input string
-  // based on the format string.
+  // based on the format string. The toString method returns a representation 
of each token suitable
+  // for use in error messages.
   abstract class InputToken()
   // Represents some number of digits (0-9).
   abstract class Digits extends InputToken
   // Represents exactly 'num' digits (0-9).
-  case class ExactlyAsManyDigits(num: Int) extends Digits
+  case class ExactlyAsManyDigits(num: Int) extends Digits {
+    override def toString: String = "digit sequence"
+  }
   // Represents at most 'num' digits (0-9).
-  case class AtMostAsManyDigits(num: Int) extends Digits
+  case class AtMostAsManyDigits(num: Int) extends Digits {
+    override def toString: String = "digit sequence"
+  }
   // Represents one decimal point (.).
-  case class DecimalPoint() extends InputToken
+  case class DecimalPoint() extends InputToken {
+    override def toString: String = ". or D"
+  }
   // Represents one thousands separator (,).
-  case class ThousandsSeparator() extends InputToken
+  case class ThousandsSeparator() extends InputToken {
+    override def toString: String = ", or G"
+  }
   // Represents one or more groups of Digits (0-9) with ThousandsSeparators 
(,) between each group.
   // The 'tokens' are the Digits and ThousandsSeparators in order; the 
'digits' are just the Digits.
-  case class DigitGroups(tokens: Seq[InputToken], digits: Seq[Digits]) extends 
InputToken
+  case class DigitGroups(tokens: Seq[InputToken], digits: Seq[Digits]) extends 
InputToken {
+    override def toString: String = "digit sequence"
+  }
   // Represents one dollar sign ($).
-  case class DollarSign() extends InputToken
+  case class DollarSign() extends InputToken {
+    override def toString: String = "$"
+  }
   // Represents one optional plus sign (+) or minus sign (-).
-  case class OptionalPlusOrMinusSign() extends InputToken
+  case class OptionalPlusOrMinusSign() extends InputToken {
+    override def toString: String = "S"
+  }
   // Represents one optional minus sign (-).
-  case class OptionalMinusSign() extends InputToken
+  case class OptionalMinusSign() extends InputToken {
+    override def toString: String = "MI"
+  }
   // Represents one opening angle bracket (<).
-  case class OpeningAngleBracket() extends InputToken
+  case class OpeningAngleBracket() extends InputToken {
+    override def toString: String = "PR"
+  }
   // Represents one closing angle bracket (>).
-  case class ClosingAngleBracket() extends InputToken
+  case class ClosingAngleBracket() extends InputToken {
+    override def toString: String = "PR"
+  }
   // Represents any unrecognized character other than the above.
-  case class InvalidUnrecognizedCharacter(char: Char) extends InputToken
+  case class InvalidUnrecognizedCharacter(char: Char) extends InputToken {
+    override def toString: String = s"character '$char''"
+  }
 }
 
 /**
@@ -241,16 +264,6 @@ class ToNumberParser(numberFormat: String, errorOnFail: 
Boolean) extends Seriali
    * This implementation of the [[check]] method returns any error, or the 
empty string on success.
    */
   private def validateFormatString: String = {
-    def multipleSignInNumberFormatError(message: String) = {
-      s"At most one $message is allowed in the number format: '$numberFormat'"
-    }
-
-    def notAtEndOfNumberFormatError(message: String) = {
-      s"$message must be at the end of the number format: '$numberFormat'"
-    }
-
-    val inputTokenCounts = formatTokens.groupBy(identity).mapValues(_.size)
-
     val firstDollarSignIndex: Int = formatTokens.indexOf(DollarSign())
     val firstDigitIndex: Int = formatTokens.indexWhere {
       case _: DigitGroups => true
@@ -276,58 +289,25 @@ class ToNumberParser(numberFormat: String, errorOnFail: 
Boolean) extends Seriali
 
     // Make sure the format string contains at least one token.
     if (numberFormat.isEmpty) {
-      "The format string cannot be empty"
-    }
-    // Make sure the format string does not contain any unrecognized 
characters.
-    else if 
(formatTokens.exists(_.isInstanceOf[InvalidUnrecognizedCharacter])) {
-      val unrecognizedChars =
-        formatTokens.filter {
-          _.isInstanceOf[InvalidUnrecognizedCharacter]
-        }.map {
-          case i: InvalidUnrecognizedCharacter => i.char
-        }
-      val char: Char = unrecognizedChars.head
-      s"Encountered invalid character $char in the number format: 
'$numberFormat'"
+      return "The format string cannot be empty"
     }
     // Make sure the format string contains at least one digit.
-    else if (!formatTokens.exists(
+    if (!formatTokens.exists(
       token => token.isInstanceOf[DigitGroups])) {
-      "The format string requires at least one number digit"
-    }
-    // Make sure the format string contains at most one decimal point.
-    else if (inputTokenCounts.getOrElse(DecimalPoint(), 0) > 1) {
-      multipleSignInNumberFormatError(s"'$POINT_LETTER' or '$POINT_SIGN'")
-    }
-    // Make sure the format string contains at most one plus or minus sign.
-    else if (inputTokenCounts.getOrElse(OptionalPlusOrMinusSign(), 0) > 1) {
-      multipleSignInNumberFormatError(s"'$OPTIONAL_PLUS_OR_MINUS_LETTER'")
-    }
-    // Make sure the format string contains at most one dollar sign.
-    else if (inputTokenCounts.getOrElse(DollarSign(), 0) > 1) {
-      multipleSignInNumberFormatError(s"'$DOLLAR_SIGN'")
-    }
-    // Make sure the format string contains at most one "MI" sequence.
-    else if (inputTokenCounts.getOrElse(OptionalMinusSign(), 0) > 1) {
-      multipleSignInNumberFormatError(s"'$OPTIONAL_MINUS_STRING'")
-    }
-    // Make sure the format string contains at most one closing angle bracket 
at the end.
-    else if (inputTokenCounts.getOrElse(ClosingAngleBracket(), 0) > 1 ||
-      (inputTokenCounts.getOrElse(ClosingAngleBracket(), 0) == 1 &&
-        formatTokens.last != ClosingAngleBracket())) {
-      
notAtEndOfNumberFormatError(s"'$WRAPPING_ANGLE_BRACKETS_TO_NEGATIVE_NUMBER'")
+      return "The format string requires at least one number digit"
     }
     // Make sure that any dollar sign in the format string occurs before any 
digits.
-    else if (firstDigitIndex < firstDollarSignIndex) {
-      s"Currency characters must appear before digits in the number format: 
'$numberFormat'"
+    if (firstDigitIndex < firstDollarSignIndex) {
+      return s"Currency characters must appear before digits in the number 
format: '$numberFormat'"
     }
     // Make sure that any dollar sign in the format string occurs before any 
decimal point.
-    else if (firstDecimalPointIndex != -1 &&
+    if (firstDecimalPointIndex != -1 &&
       firstDecimalPointIndex < firstDollarSignIndex) {
-      "Currency characters must appear before any decimal point in the " +
+      return "Currency characters must appear before any decimal point in the 
" +
         s"number format: '$numberFormat'"
     }
     // Make sure that any thousands separators in the format string have 
digits before and after.
-    else if (digitGroupsBeforeDecimalPoint.exists {
+    if (digitGroupsBeforeDecimalPoint.exists {
       case DigitGroups(tokens, _) =>
         tokens.zipWithIndex.exists({
           case (_: ThousandsSeparator, j: Int) if j == 0 || j == tokens.length 
- 1 =>
@@ -340,21 +320,64 @@ class ToNumberParser(numberFormat: String, errorOnFail: 
Boolean) extends Seriali
             false
         })
     }) {
-      "Thousands separators (,) must have digits in between them " +
+      return "Thousands separators (,) must have digits in between them " +
         s"in the number format: '$numberFormat'"
     }
-    // Thousands separators are not allowed after the decimal point, if any.
-    else if (digitGroupsAfterDecimalPoint.exists {
+    // Make sure that thousands separators does not appear after the decimal 
point, if any.
+    if (digitGroupsAfterDecimalPoint.exists {
       case DigitGroups(tokens, digits) =>
         tokens.length > digits.length
     }) {
-      "Thousands separators (,) may not appear after the decimal point " +
+      return "Thousands separators (,) may not appear after the decimal point 
" +
         s"in the number format: '$numberFormat'"
     }
-    // Validation of the format string finished successfully.
-    else {
-      ""
+    // Make sure that the format string does not contain any prohibited 
duplicate tokens.
+    val inputTokenCounts = formatTokens.groupBy(identity).mapValues(_.size)
+    Seq(DecimalPoint(),
+      OptionalPlusOrMinusSign(),
+      OptionalMinusSign(),
+      DollarSign(),
+      ClosingAngleBracket()).foreach {
+      token => if (inputTokenCounts.getOrElse(token, 0) > 1) {
+        return s"At most one ${token.toString} is allowed in the number 
format: '$numberFormat'"
+      }
+    }
+    // Enforce the ordering of tokens in the format string according to this 
specification:
+    // [ MI | S ] [ $ ]
+    // [ 0 | 9 | G | , ] [...]
+    // [ . | D ]
+    // [ 0 | 9 ] [...]
+    // [ $ ] [ PR | MI | S ]
+    val allowedFormatTokens: Seq[Seq[InputToken]] = Seq(
+      Seq(OpeningAngleBracket()),
+      Seq(OptionalMinusSign(), OptionalPlusOrMinusSign()),
+      Seq(DollarSign()),
+      Seq(DigitGroups(Seq(), Seq())),
+      Seq(DecimalPoint()),
+      Seq(DigitGroups(Seq(), Seq())),
+      Seq(DollarSign()),
+      Seq(OptionalMinusSign(), OptionalPlusOrMinusSign(), 
ClosingAngleBracket())
+    )
+    var formatTokenIndex = 0
+    for (allowedTokens: Seq[InputToken] <- allowedFormatTokens) {
+      def tokensMatch(lhs: InputToken, rhs: InputToken): Boolean = {
+        lhs match {
+          case _: DigitGroups => rhs.isInstanceOf[DigitGroups]
+          case _ => lhs == rhs
+        }
+      }
+      if (formatTokenIndex < formatTokens.length &&
+        allowedTokens.exists(tokensMatch(_, formatTokens(formatTokenIndex)))) {
+        formatTokenIndex += 1
+      }
     }
+    if (formatTokenIndex < formatTokens.length) {
+      return s"Unexpected ${formatTokens(formatTokenIndex).toString} found in 
the format string " +
+        s"'$numberFormat'; the structure of the format string must match: " +
+        "[MI|S] [$] [0|9|G|,]* [.|D] [0|9]* [$] [PR|MI|S]"
+    }
+    // Validation of the format string finished successfully.
+    ""
   }
 
   /**
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index afb05dd4d77..91b3d0c69b8 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -972,7 +972,6 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       ("<454>", "999PR") -> Decimal(-454),
       ("454-", "999MI") -> Decimal(-454),
       ("-$54", "MI$99") -> Decimal(-54),
-      ("$4-4", "$9MI9") -> Decimal(-44),
       // The input string contains more digits than fit in a long integer.
       ("123,456,789,123,456,789,123", "999,999,999,999,999,999,999") ->
         Decimal(new JavaBigDecimal("123456789123456789123"))
@@ -1009,7 +1008,8 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   }
 
   test("ToNumber: negative tests (the format string is invalid)") {
-    val invalidCharacter = "Encountered invalid character"
+    val unexpectedCharacter = "the structure of the format string must match: 
" +
+      "[MI|S] [$] [0|9|G|,]* [.|D] [0|9]* [$] [PR|MI|S]"
     val thousandsSeparatorDigitsBetween =
       "Thousands separators (,) must have digits in between them"
     val mustBeAtEnd = "must be at the end of the number format"
@@ -1018,23 +1018,25 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       // The format string must not be empty.
       ("454", "") -> "The format string cannot be empty",
       // Make sure the format string does not contain any unrecognized 
characters.
-      ("454", "999@") -> invalidCharacter,
-      ("454", "999M") -> invalidCharacter,
-      ("454", "999P") -> invalidCharacter,
+      ("454", "999@") -> unexpectedCharacter,
+      ("454", "999M") -> unexpectedCharacter,
+      ("454", "999P") -> unexpectedCharacter,
       // Make sure the format string contains at least one digit.
       ("454", "$") -> "The format string requires at least one number digit",
       // Make sure the format string contains at most one decimal point.
       ("454", "99.99.99") -> atMostOne,
       // Make sure the format string contains at most one dollar sign.
       ("454", "$$99") -> atMostOne,
-      // Make sure the format string contains at most one minus sign at the 
end.
+      // Make sure the format string contains at most one minus sign at the 
beginning or end.
+      ("$4-4", "$9MI9") -> unexpectedCharacter,
+      ("--4", "SMI9") -> unexpectedCharacter,
       ("--$54", "SS$99") -> atMostOne,
       ("-$54", "MI$99MI") -> atMostOne,
       ("$4-4", "$9MI9MI") -> atMostOne,
       // Make sure the format string contains at most one closing angle 
bracket at the end.
-      ("<$45>", "PR$99") -> mustBeAtEnd,
-      ("$4<4>", "$9PR9") -> mustBeAtEnd,
-      ("<<454>>", "999PRPR") -> mustBeAtEnd,
+      ("<$45>", "PR$99") -> unexpectedCharacter,
+      ("$4<4>", "$9PR9") -> unexpectedCharacter,
+      ("<<454>>", "999PRPR") -> atMostOne,
       // Make sure that any dollar sign in the format string occurs before any 
digits.
       ("4$54", "9$99") -> "Currency characters must appear before digits",
       // Make sure that any dollar sign in the format string occurs before any 
decimal point.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to