You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/14 00:54:46 UTC

[spark] branch branch-3.3 updated: [SPARK-38796][SQL] Update to_number and try_to_number functions to allow PR with positive numbers

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new c6778bacd48 [SPARK-38796][SQL] Update to_number and try_to_number functions to allow PR with positive numbers
c6778bacd48 is described below

commit c6778bacd481e794ef013efd241303421f8400e4
Author: Daniel Tenedorio <da...@databricks.com>
AuthorDate: Tue Jun 14 08:54:04 2022 +0800

    [SPARK-38796][SQL] Update to_number and try_to_number functions to allow PR with positive numbers
    
    ### What changes were proposed in this pull request?
    
    Update `to_number` and `try_to_number` functions to allow the `PR` format token with input strings comprising positive numbers.
    
    Before this bug fix, function calls like `to_number(' 123 ', '999PR')` would fail. Now they succeed, which is helpful since `PR` should allow both positive and negative numbers.
    
    This satisfies the following specification:
    
    ```
    to_number(expr, fmt)
    fmt
      { ' [ MI | S ] [ L | $ ]
          [ 0 | 9 | G | , ] [...]
          [ . | D ]
          [ 0 | 9 ] [...]
          [ L | $ ] [ PR | MI | S ] ' }
    ```
    
    ### Why are the changes needed?
    
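    The specification above lists `PR` next to `MI` and `S` as an optional trailing token, so a format ending in `PR` should accept positive numbers (no angle brackets) as well as negative ones.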
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `to_number` and `try_to_number` now accept positive input strings when the format string ends with `PR`; previously such calls failed.
    
    ### How was this patch tested?
    
    Existing and updated unit test coverage.
    
    Closes #36861 from dtenedor/to-number-fix-pr.
    
    Authored-by: Daniel Tenedorio <da...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 4a803ca22a9a98f9bbbbd1a5a33b9ae394fb7c49)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/util/ToNumberParser.scala   | 98 +++++++++++++---------
 .../expressions/StringExpressionsSuite.scala       |  5 +-
 2 files changed, 61 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
index 716224983e0..22e655c4eb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
@@ -397,6 +397,9 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali
     beforeDecimalPoint.clear()
     afterDecimalPoint.clear()
     var reachedDecimalPoint = false
+    // Record whether we have consumed the opening and closing angle bracket characters in the
+    // input string.
+    var reachedOpeningAngleBracket = false
+    var reachedClosingAngleBracket = false
     // Record whether the input specified a negative result, such as with a minus sign.
     var negateResult = false
     // This is an index into the characters of the provided input string.
@@ -407,66 +410,79 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali
     // Iterate through the tokens representing the provided format string, in order.
     while (formatIndex < formatTokens.size) {
       val token: InputToken = formatTokens(formatIndex)
+      val inputChar: Option[Char] =
+        if (inputIndex < inputLength) {
+          Some(inputString(inputIndex))
+        } else {
+          Option.empty[Char]
+        }
       token match {
         case d: DigitGroups =>
           inputIndex = parseDigitGroups(d, inputString, inputIndex, reachedDecimalPoint).getOrElse(
             return formatMatchFailure(input, numberFormat))
         case DecimalPoint() =>
-          if (inputIndex < inputLength &&
-            inputString(inputIndex) == POINT_SIGN) {
-            reachedDecimalPoint = true
-            inputIndex += 1
-          } else {
-            // There is no decimal point. Consume the token and remain at the same character in the
-            // input string.
+          inputChar.foreach {
+            case POINT_SIGN =>
+              reachedDecimalPoint = true
+              inputIndex += 1
+            case _ =>
+              // There is no decimal point. Consume the token and remain at the same character in
+              // the input string.
           }
         case DollarSign() =>
-          if (inputIndex >= inputLength ||
-            inputString(inputIndex) != DOLLAR_SIGN) {
-            // The input string did not contain an expected dollar sign.
-            return formatMatchFailure(input, numberFormat)
+          inputChar.foreach {
+            case DOLLAR_SIGN =>
+              inputIndex += 1
+            case _ =>
+              // The input string did not contain an expected dollar sign.
+              return formatMatchFailure(input, numberFormat)
           }
-          inputIndex += 1
         case OptionalPlusOrMinusSign() =>
-          if (inputIndex < inputLength &&
-            inputString(inputIndex) == PLUS_SIGN) {
-            inputIndex += 1
-          } else if (inputIndex < inputLength &&
-            inputString(inputIndex) == MINUS_SIGN) {
-            negateResult = !negateResult
-            inputIndex += 1
-          } else {
-            // There is no plus or minus sign. Consume the token and remain at the same character in
-            // the input string.
+          inputChar.foreach {
+            case PLUS_SIGN =>
+              inputIndex += 1
+            case MINUS_SIGN =>
+              negateResult = !negateResult
+              inputIndex += 1
+            case _ =>
+              // There is no plus or minus sign. Consume the token and remain at the same character
+              // in the input string.
           }
         case OptionalMinusSign() =>
-          if (inputIndex < inputLength &&
-            inputString(inputIndex) == MINUS_SIGN) {
-            negateResult = !negateResult
-            inputIndex += 1
-          } else {
-            // There is no minus sign. Consume the token and remain at the same character in the
-            // input string.
+          inputChar.foreach {
+            case MINUS_SIGN =>
+              negateResult = !negateResult
+              inputIndex += 1
+            case _ =>
+              // There is no minus sign. Consume the token and remain at the same character in the
+              // input string.
           }
         case OpeningAngleBracket() =>
-          if (inputIndex >= inputLength ||
-            inputString(inputIndex) != ANGLE_BRACKET_OPEN) {
-            // The input string did not contain an expected opening angle bracket.
-            return formatMatchFailure(input, numberFormat)
+          inputChar.foreach {
+            case ANGLE_BRACKET_OPEN =>
+              if (reachedOpeningAngleBracket) {
+                return formatMatchFailure(input, numberFormat)
+              }
+              reachedOpeningAngleBracket = true
+              inputIndex += 1
+            case _ =>
           }
-          inputIndex += 1
         case ClosingAngleBracket() =>
-          if (inputIndex >= inputLength ||
-            inputString(inputIndex) != ANGLE_BRACKET_CLOSE) {
-            // The input string did not contain an expected closing angle bracket.
-            return formatMatchFailure(input, numberFormat)
+          inputChar.foreach {
+            case ANGLE_BRACKET_CLOSE =>
+              if (!reachedOpeningAngleBracket) {
+                return formatMatchFailure(input, numberFormat)
+              }
+              reachedClosingAngleBracket = true
+              negateResult = !negateResult
+              inputIndex += 1
+            case _ =>
           }
-          negateResult = !negateResult
-          inputIndex += 1
       }
       formatIndex += 1
     }
-    if (inputIndex < inputLength) {
+    if (inputIndex < inputLength ||
+      reachedOpeningAngleBracket != reachedClosingAngleBracket) {
       // If we have consumed all the tokens in the format string, but characters remain unconsumed
       // in the input string, then the input string does not match the format string.
       formatMatchFailure(input, numberFormat)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index db7aae99855..655e9b744bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -969,7 +969,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       ("+$89,1,2,3,45.123", "S$999,0,0,0,999.00000") -> Decimal(8912345.123),
       ("-454", "S999") -> Decimal(-454),
       ("+454", "S999") -> Decimal(454),
-      ("<454>", "999PR") -> Decimal(-454),
+      ("454", "999PR") -> Decimal(454),
+      (" 454 ", "999PR") -> Decimal(454),
       ("454-", "999MI") -> Decimal(-454),
       ("-$54", "MI$99") -> Decimal(-54),
       // The input string contains more digits than fit in a long integer.
@@ -1089,6 +1090,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       // The trailing PR required exactly one leading < and trailing >.
       ("<454", "999PR"),
       ("454>", "999PR"),
+      ("<454 ", "999PR"),
+      (" 454>", "999PR"),
       ("<<454>>", "999PR"),
       // At least three digits were required.
       ("45", "S$999,099.99"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org