You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2021/07/06 10:52:12 UTC

[spark] branch branch-3.2 updated: [SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new b53d285  [SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM
b53d285 is described below

commit b53d285f72a918abeafaf7517281d08cf57beb64
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Tue Jul 6 13:51:06 2021 +0300

    [SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM
    
    ### What changes were proposed in this pull request?
     Refactor the code that parses strings into DT/YM intervals.
    
    ### Why are the changes needed?
    Extracting the common code for parsing strings into day-time (DT) and year-month (YM) intervals should improve code maintainability.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #33217 from AngersZhuuuu/SPARK-35735-35768.
    
    Authored-by: Angerszhuuuu <an...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
    (cherry picked from commit 26d1bb16bc565dbcb1a3f536dc78cd87be6c2468)
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../spark/sql/catalyst/util/IntervalUtils.scala    | 201 ++++++++++-----------
 .../sql/catalyst/expressions/CastSuiteBase.scala   |  28 ++-
 2 files changed, 123 insertions(+), 106 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 30a2fa5..b174165 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 // The style of textual representation of intervals
@@ -110,7 +110,7 @@ object IntervalUtils {
   private val yearMonthIndividualLiteralRegex =
     (s"(?i)^INTERVAL\\s+([+|-])?'$yearMonthIndividualPatternString'\\s+(YEAR|MONTH)$$").r
 
-  private def getSign(firstSign: String, secondSign: String): Int = {
+  private def finalSign(firstSign: String, secondSign: String = null): Int = {
     (firstSign, secondSign) match {
       case ("-", "-") => 1
       case ("-", _) => -1
@@ -119,6 +119,39 @@ object IntervalUtils {
     }
   }
 
+  private def throwIllegalIntervalFormatException(
+      input: UTF8String,
+      startFiled: Byte,
+      endField: Byte,
+      intervalStr: String,
+      typeName: String,
+      fallBackNotice: Option[String] = None) = {
+    throw new IllegalArgumentException(
+      s"Interval string does not match $intervalStr format of " +
+        s"${supportedFormat((startFiled, endField)).map(format => s"`$format`").mkString(", ")} " +
+        s"when cast to $typeName: ${input.toString}" +
+        s"${fallBackNotice.map(s => s", $s").getOrElse("")}")
+  }
+
+  private def checkIntervalStringDataType(
+      input: UTF8String,
+      targetStartField: Byte,
+      targetEndField: Byte,
+      inputIntervalType: DataType,
+      fallBackNotice: Option[String] = None): Unit = {
+    val (intervalStr, typeName, inputStartField, inputEndField) = inputIntervalType match {
+      case DT(startField, endField) =>
+        ("day-time", DT(targetStartField, targetEndField).typeName, startField, endField)
+      case YM(startField, endField) =>
+        ("year-month", YM(targetStartField, targetEndField).typeName, startField, endField)
+    }
+    if (targetStartField != inputStartField || targetEndField != inputEndField) {
+      throwIllegalIntervalFormatException(
+        input, targetStartField, targetEndField, intervalStr, typeName, fallBackNotice)
+    }
+  }
+
+
   val supportedFormat = Map(
     (YM.YEAR, YM.MONTH) -> Seq("[+|-]y-m", "INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH"),
     (YM.YEAR, YM.YEAR) -> Seq("[+|-]y", "INTERVAL [+|-]'[+|-]y' YEAR"),
@@ -140,56 +173,41 @@ object IntervalUtils {
       startField: Byte,
       endField: Byte): Int = {
 
-    def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = {
-      if (startField != targetStartField || endField != targetEndField) {
-        throw new IllegalArgumentException(s"Interval string does not match year-month format of " +
-          s"${supportedFormat((targetStartField, targetStartField))
-            .map(format => s"`$format`").mkString(", ")} " +
-          s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
-      }
-    }
+    def checkYMIntervalStringDataType(ym: YM): Unit =
+      checkIntervalStringDataType(input, startField, endField, ym)
 
     input.trimAll().toString match {
-      case yearMonthRegex("-", year, month) =>
-        checkStringIntervalType(YM.YEAR, YM.MONTH)
-        toYMInterval(year, month, -1)
-      case yearMonthRegex(_, year, month) =>
-        checkStringIntervalType(YM.YEAR, YM.MONTH)
-        toYMInterval(year, month, 1)
+      case yearMonthRegex(sign, year, month) =>
+        checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH))
+        toYMInterval(year, month, finalSign(sign))
       case yearMonthLiteralRegex(firstSign, secondSign, year, month) =>
-        checkStringIntervalType(YM.YEAR, YM.MONTH)
-        toYMInterval(year, month, getSign(firstSign, secondSign))
-      case yearMonthIndividualRegex(secondSign, value) =>
-        safeToInterval {
-          val sign = getSign("+", secondSign)
+        checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH))
+        toYMInterval(year, month, finalSign(firstSign, secondSign))
+      case yearMonthIndividualRegex(firstSign, value) =>
+        safeToInterval("year-month") {
+          val sign = finalSign(firstSign)
           if (endField == YM.YEAR) {
             sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR)
           } else if (startField == YM.MONTH) {
             Math.toIntExact(sign * value.toLong)
           } else {
-            throw new IllegalArgumentException(
-              s"Interval string does not match year-month format of " +
-                s"${supportedFormat((YM.YEAR, YM.MONTH))
-                  .map(format => s"`$format`").mkString(", ")} " +
-                s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
+            throwIllegalIntervalFormatException(
+              input, startField, endField, "year-month", YM(startField, endField).typeName)
           }
         }
       case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, suffix) =>
-        safeToInterval {
-          val sign = getSign(firstSign, secondSign)
+        safeToInterval("year-month") {
+          val sign = finalSign(firstSign, secondSign)
           if ("YEAR".equalsIgnoreCase(suffix)) {
-            checkStringIntervalType(YM.YEAR, YM.YEAR)
+            checkYMIntervalStringDataType(YM(YM.YEAR, YM.YEAR))
             sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR)
           } else {
-            checkStringIntervalType(YM.MONTH, YM.MONTH)
+            checkYMIntervalStringDataType(YM(YM.MONTH, YM.MONTH))
             Math.toIntExact(sign * value.toLong)
           }
         }
-      case _ => throw new IllegalArgumentException(
-        s"Interval string does not match year-month format of " +
-          s"${supportedFormat((YM.YEAR, YM.MONTH))
-            .map(format => s"`$format`").mkString(", ")} " +
-          s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
+      case _ => throwIllegalIntervalFormatException(input, startField, endField,
+        "year-month", YM(startField, endField).typeName)
     }
   }
 
@@ -201,28 +219,26 @@ object IntervalUtils {
   def fromYearMonthString(input: String): CalendarInterval = {
     require(input != null, "Interval year-month string must be not null")
     input.trim match {
-      case yearMonthRegex("-", yearStr, monthStr) =>
-        new CalendarInterval(toYMInterval(yearStr, monthStr, -1), 0, 0)
-      case yearMonthRegex(_, yearStr, monthStr) =>
-        new CalendarInterval(toYMInterval(yearStr, monthStr, 1), 0, 0)
+      case yearMonthRegex(sign, yearStr, monthStr) =>
+        new CalendarInterval(toYMInterval(yearStr, monthStr, finalSign(sign)), 0, 0)
       case _ =>
         throw new IllegalArgumentException(
           s"Interval string does not match year-month format of 'y-m': $input")
     }
   }
 
-  private def safeToInterval[T](f: => T): T = {
+  private def safeToInterval[T](interval: String)(f: => T): T = {
     try {
       f
     } catch {
       case NonFatal(e) =>
         throw new IllegalArgumentException(
-          s"Error parsing interval year-month string: ${e.getMessage}", e)
+          s"Error parsing interval $interval string: ${e.getMessage}", e)
     }
   }
 
   private def toYMInterval(yearStr: String, monthStr: String, sign: Int): Int = {
-    safeToInterval {
+    safeToInterval("year-month") {
       val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR)
       val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11))
       Math.toIntExact(totalMonths)
@@ -279,15 +295,6 @@ object IntervalUtils {
       startField: Byte,
       endField: Byte): Long = {
 
-    def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = {
-      if (startField != targetStartField || endField != targetEndField) {
-        throw new IllegalArgumentException(s"Interval string does not match day-time format of " +
-          s"${supportedFormat((targetStartField, targetStartField))
-            .map(format => s"`$format`").mkString(", ")} " +
-          s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
-      }
-    }
-
     def secondAndMicro(second: String, micro: String): String = {
       if (micro != null) {
         s"$second$micro"
@@ -296,50 +303,53 @@ object IntervalUtils {
       }
     }
 
+    def checkDTIntervalStringDataType(dt: DT): Unit =
+      checkIntervalStringDataType(input, startField, endField, dt, Some(fallbackNotice))
+
     input.trimAll().toString match {
       case dayHourRegex(sign, day, hour) =>
-        checkStringIntervalType(DT.DAY, DT.HOUR)
-        toDTInterval(day, hour, "0", "0", getSign(null, sign))
+        checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR))
+        toDTInterval(day, hour, "0", "0", finalSign(sign))
       case dayHourLiteralRegex(firstSign, secondSign, day, hour) =>
-        checkStringIntervalType(DT.DAY, DT.HOUR)
-        toDTInterval(day, hour, "0", "0", getSign(firstSign, secondSign))
+        checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR))
+        toDTInterval(day, hour, "0", "0", finalSign(firstSign, secondSign))
       case dayMinuteRegex(sign, day, hour, minute) =>
-        checkStringIntervalType(DT.DAY, DT.MINUTE)
-        toDTInterval(day, hour, minute, "0", getSign(null, sign))
+        checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE))
+        toDTInterval(day, hour, minute, "0", finalSign(sign))
       case dayMinuteLiteralRegex(firstSign, secondSign, day, hour, minute) =>
-        checkStringIntervalType(DT.DAY, DT.MINUTE)
-        toDTInterval(day, hour, minute, "0", getSign(firstSign, secondSign))
+        checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE))
+        toDTInterval(day, hour, minute, "0", finalSign(firstSign, secondSign))
       case daySecondRegex(sign, day, hour, minute, second, micro) =>
-        checkStringIntervalType(DT.DAY, DT.SECOND)
-        toDTInterval(day, hour, minute, secondAndMicro(second, micro), getSign(null, sign))
+        checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND))
+        toDTInterval(day, hour, minute, secondAndMicro(second, micro), finalSign(sign))
       case daySecondLiteralRegex(firstSign, secondSign, day, hour, minute, second, micro) =>
-        checkStringIntervalType(DT.DAY, DT.SECOND)
+        checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND))
         toDTInterval(day, hour, minute, secondAndMicro(second, micro),
-          getSign(firstSign, secondSign))
+          finalSign(firstSign, secondSign))
 
       case hourMinuteRegex(sign, hour, minute) =>
-        checkStringIntervalType(DT.HOUR, DT.MINUTE)
-        toDTInterval(hour, minute, "0", getSign(null, sign))
+        checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE))
+        toDTInterval(hour, minute, "0", finalSign(sign))
       case hourMinuteLiteralRegex(firstSign, secondSign, hour, minute) =>
-        checkStringIntervalType(DT.HOUR, DT.MINUTE)
-        toDTInterval(hour, minute, "0", getSign(firstSign, secondSign))
+        checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE))
+        toDTInterval(hour, minute, "0", finalSign(firstSign, secondSign))
       case hourSecondRegex(sign, hour, minute, second, micro) =>
-        checkStringIntervalType(DT.HOUR, DT.SECOND)
-        toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(null, sign))
+        checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND))
+        toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(sign))
       case hourSecondLiteralRegex(firstSign, secondSign, hour, minute, second, micro) =>
-        checkStringIntervalType(DT.HOUR, DT.SECOND)
-        toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(firstSign, secondSign))
+        checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND))
+        toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign))
 
       case minuteSecondRegex(sign, minute, second, micro) =>
-        checkStringIntervalType(DT.MINUTE, DT.SECOND)
-        toDTInterval(minute, secondAndMicro(second, micro), getSign(null, sign))
+        checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND))
+        toDTInterval(minute, secondAndMicro(second, micro), finalSign(sign))
       case minuteSecondLiteralRegex(firstSign, secondSign, minute, second, micro) =>
-        checkStringIntervalType(DT.MINUTE, DT.SECOND)
-        toDTInterval(minute, secondAndMicro(second, micro), getSign(firstSign, secondSign))
+        checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND))
+        toDTInterval(minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign))
 
-      case dayTimeIndividualRegex(secondSign, value, suffix) =>
-        safeToInterval {
-          val sign = getSign("+", secondSign)
+      case dayTimeIndividualRegex(firstSign, value, suffix) =>
+        safeToInterval("day-time") {
+          val sign = finalSign(firstSign)
           (startField, endField) match {
             case (DT.DAY, DT.DAY) if suffix == null && value.length <= 9 =>
               sign * value.toLong * MICROS_PER_DAY
@@ -352,46 +362,35 @@ object IntervalUtils {
                 case 1 => parseSecondNano(secondAndMicro(value, suffix))
                 case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}")
               }
-            case (_, _) => throw new IllegalArgumentException(
-              s"Interval string does not match day-time format of " +
-                s"${supportedFormat((startField, endField))
-                    .map(format => s"`$format`").mkString(", ")} " +
-                s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
+            case (_, _) => throwIllegalIntervalFormatException(input, startField, endField,
+              "day-time", DT(startField, endField).typeName, Some(fallbackNotice))
           }
         }
       case dayTimeIndividualLiteralRegex(firstSign, secondSign, value, suffix, unit) =>
-        safeToInterval {
-          val sign = getSign(firstSign, secondSign)
+        safeToInterval("day-time") {
+          val sign = finalSign(firstSign, secondSign)
           unit match {
             case "DAY" if suffix == null && value.length <= 9 =>
-              checkStringIntervalType(DT.DAY, DT.DAY)
+              checkDTIntervalStringDataType(DT(DT.DAY, DT.DAY))
               sign * value.toLong * MICROS_PER_DAY
             case "HOUR" if suffix == null && value.length <= 10 =>
-              checkStringIntervalType(DT.HOUR, DT.HOUR)
+              checkDTIntervalStringDataType(DT(DT.HOUR, DT.HOUR))
               sign * value.toLong * MICROS_PER_HOUR
             case "MINUTE" if suffix == null && value.length <= 12 =>
-              checkStringIntervalType(DT.MINUTE, DT.MINUTE)
+              checkDTIntervalStringDataType(DT(DT.MINUTE, DT.MINUTE))
               sign * value.toLong * MICROS_PER_MINUTE
             case "SECOND" if value.length <= 13 =>
-              checkStringIntervalType(DT.SECOND, DT.SECOND)
+              checkDTIntervalStringDataType(DT(DT.SECOND, DT.SECOND))
               sign match {
                 case 1 => parseSecondNano(secondAndMicro(value, suffix))
                 case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}")
               }
-            case _ => throw new IllegalArgumentException(
-              s"Interval string does not match day-time format of " +
-                s"${supportedFormat((startField, endField))
-                  .map(format => s"`$format`").mkString(", ")} " +
-                s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
+            case _ => throwIllegalIntervalFormatException(input, startField, endField,
+              "day-time", DT(startField, endField).typeName, Some(fallbackNotice))
           }
         }
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Interval string does not match day-time format of " +
-            s"${supportedFormat((startField, endField))
-              .map(format => s"`$format`").mkString(", ")} " +
-            s"when cast to ${DT(startField, endField).typeName}: ${input.toString}, " +
-            s"$fallbackNotice")
+      case _ => throwIllegalIntervalFormatException(input, startField, endField,
+        "day-time", DT(startField, endField).typeName, Some(fallbackNotice))
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 66f5b50..8313242 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -1113,10 +1114,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
 
     if (!isTryCast) {
       Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
+        val dataType = YearMonthIntervalType()
         val e = intercept[IllegalArgumentException] {
-          cast(Literal.create(interval), YearMonthIntervalType()).eval()
+          cast(Literal.create(interval), dataType).eval()
         }.getMessage
-        assert(e.contains("Interval string does not match year-month format"))
+        assert(e.contains(s"Interval string does not match year-month format of " +
+          s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+            .map(format => s"`$format`").mkString(", ")} " +
+          s"when cast to ${dataType.typeName}: $interval"))
       }
       Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
         ("1", YearMonthIntervalType(YEAR, MONTH)),
@@ -1132,7 +1137,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
           val e = intercept[IllegalArgumentException] {
             cast(Literal.create(interval), dataType).eval()
           }.getMessage
-          assert(e.contains("Interval string does not match year-month format"))
+          assert(e.contains(s"Interval string does not match year-month format of " +
+            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+              .map(format => s"`$format`").mkString(", ")} " +
+            s"when cast to ${dataType.typeName}: $interval"))
         }
     }
   }
@@ -1249,7 +1257,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
           val e = intercept[IllegalArgumentException] {
             cast(Literal.create(interval), dataType).eval()
           }.getMessage
-          assert(e.contains("Interval string does not match day-time format"))
+          assert(e.contains(s"Interval string does not match day-time format of " +
+            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+              .map(format => s"`$format`").mkString(", ")} " +
+            s"when cast to ${dataType.typeName}: $interval, " +
+            s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+            "to restore the behavior before Spark 3.0."))
         }
 
       // Check first field outof bound
@@ -1267,7 +1280,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
           val e = intercept[IllegalArgumentException] {
             cast(Literal.create(interval), dataType).eval()
           }.getMessage
-          assert(e.contains("Interval string does not match day-time format"))
+          assert(e.contains(s"Interval string does not match day-time format of " +
+            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+              .map(format => s"`$format`").mkString(", ")} " +
+            s"when cast to ${dataType.typeName}: $interval, " +
+            s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+            "to restore the behavior before Spark 3.0."))
         }
     }
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org