You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2021/07/06 10:52:12 UTC
[spark] branch branch-3.2 updated:
[SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse
string to DT/YM
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new b53d285 [SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM
b53d285 is described below
commit b53d285f72a918abeafaf7517281d08cf57beb64
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Tue Jul 6 13:51:06 2021 +0300
[SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM
### What changes were proposed in this pull request?
Refactor the code that parses strings into day-time (DT) and year-month (YM) intervals.
### Why are the changes needed?
Extracting the common code for parsing strings into DT/YM intervals should improve code maintainability.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing unit tests.
Closes #33217 from AngersZhuuuu/SPARK-35735-35768.
Authored-by: Angerszhuuuu <an...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
(cherry picked from commit 26d1bb16bc565dbcb1a3f536dc78cd87be6c2468)
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../spark/sql/catalyst/util/IntervalUtils.scala | 201 ++++++++++-----------
.../sql/catalyst/expressions/CastSuiteBase.scala | 28 ++-
2 files changed, 123 insertions(+), 106 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 30a2fa5..b174165 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
// The style of textual representation of intervals
@@ -110,7 +110,7 @@ object IntervalUtils {
private val yearMonthIndividualLiteralRegex =
(s"(?i)^INTERVAL\\s+([+|-])?'$yearMonthIndividualPatternString'\\s+(YEAR|MONTH)$$").r
- private def getSign(firstSign: String, secondSign: String): Int = {
+ private def finalSign(firstSign: String, secondSign: String = null): Int = {
(firstSign, secondSign) match {
case ("-", "-") => 1
case ("-", _) => -1
@@ -119,6 +119,39 @@ object IntervalUtils {
}
}
+ private def throwIllegalIntervalFormatException(
+ input: UTF8String,
+ startFiled: Byte,
+ endField: Byte,
+ intervalStr: String,
+ typeName: String,
+ fallBackNotice: Option[String] = None) = {
+ throw new IllegalArgumentException(
+ s"Interval string does not match $intervalStr format of " +
+ s"${supportedFormat((startFiled, endField)).map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to $typeName: ${input.toString}" +
+ s"${fallBackNotice.map(s => s", $s").getOrElse("")}")
+ }
+
+ private def checkIntervalStringDataType(
+ input: UTF8String,
+ targetStartField: Byte,
+ targetEndField: Byte,
+ inputIntervalType: DataType,
+ fallBackNotice: Option[String] = None): Unit = {
+ val (intervalStr, typeName, inputStartField, inputEndField) = inputIntervalType match {
+ case DT(startField, endField) =>
+ ("day-time", DT(targetStartField, targetEndField).typeName, startField, endField)
+ case YM(startField, endField) =>
+ ("year-month", YM(targetStartField, targetEndField).typeName, startField, endField)
+ }
+ if (targetStartField != inputStartField || targetEndField != inputEndField) {
+ throwIllegalIntervalFormatException(
+ input, targetStartField, targetEndField, intervalStr, typeName, fallBackNotice)
+ }
+ }
+
+
val supportedFormat = Map(
(YM.YEAR, YM.MONTH) -> Seq("[+|-]y-m", "INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH"),
(YM.YEAR, YM.YEAR) -> Seq("[+|-]y", "INTERVAL [+|-]'[+|-]y' YEAR"),
@@ -140,56 +173,41 @@ object IntervalUtils {
startField: Byte,
endField: Byte): Int = {
- def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = {
- if (startField != targetStartField || endField != targetEndField) {
- throw new IllegalArgumentException(s"Interval string does not match year-month format of " +
- s"${supportedFormat((targetStartField, targetStartField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
- }
- }
+ def checkYMIntervalStringDataType(ym: YM): Unit =
+ checkIntervalStringDataType(input, startField, endField, ym)
input.trimAll().toString match {
- case yearMonthRegex("-", year, month) =>
- checkStringIntervalType(YM.YEAR, YM.MONTH)
- toYMInterval(year, month, -1)
- case yearMonthRegex(_, year, month) =>
- checkStringIntervalType(YM.YEAR, YM.MONTH)
- toYMInterval(year, month, 1)
+ case yearMonthRegex(sign, year, month) =>
+ checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH))
+ toYMInterval(year, month, finalSign(sign))
case yearMonthLiteralRegex(firstSign, secondSign, year, month) =>
- checkStringIntervalType(YM.YEAR, YM.MONTH)
- toYMInterval(year, month, getSign(firstSign, secondSign))
- case yearMonthIndividualRegex(secondSign, value) =>
- safeToInterval {
- val sign = getSign("+", secondSign)
+ checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH))
+ toYMInterval(year, month, finalSign(firstSign, secondSign))
+ case yearMonthIndividualRegex(firstSign, value) =>
+ safeToInterval("year-month") {
+ val sign = finalSign(firstSign)
if (endField == YM.YEAR) {
sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR)
} else if (startField == YM.MONTH) {
Math.toIntExact(sign * value.toLong)
} else {
- throw new IllegalArgumentException(
- s"Interval string does not match year-month format of " +
- s"${supportedFormat((YM.YEAR, YM.MONTH))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
+ throwIllegalIntervalFormatException(
+ input, startField, endField, "year-month", YM(startField, endField).typeName)
}
}
case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, suffix) =>
- safeToInterval {
- val sign = getSign(firstSign, secondSign)
+ safeToInterval("year-month") {
+ val sign = finalSign(firstSign, secondSign)
if ("YEAR".equalsIgnoreCase(suffix)) {
- checkStringIntervalType(YM.YEAR, YM.YEAR)
+ checkYMIntervalStringDataType(YM(YM.YEAR, YM.YEAR))
sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR)
} else {
- checkStringIntervalType(YM.MONTH, YM.MONTH)
+ checkYMIntervalStringDataType(YM(YM.MONTH, YM.MONTH))
Math.toIntExact(sign * value.toLong)
}
}
- case _ => throw new IllegalArgumentException(
- s"Interval string does not match year-month format of " +
- s"${supportedFormat((YM.YEAR, YM.MONTH))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
+ case _ => throwIllegalIntervalFormatException(input, startField, endField,
+ "year-month", YM(startField, endField).typeName)
}
}
@@ -201,28 +219,26 @@ object IntervalUtils {
def fromYearMonthString(input: String): CalendarInterval = {
require(input != null, "Interval year-month string must be not null")
input.trim match {
- case yearMonthRegex("-", yearStr, monthStr) =>
- new CalendarInterval(toYMInterval(yearStr, monthStr, -1), 0, 0)
- case yearMonthRegex(_, yearStr, monthStr) =>
- new CalendarInterval(toYMInterval(yearStr, monthStr, 1), 0, 0)
+ case yearMonthRegex(sign, yearStr, monthStr) =>
+ new CalendarInterval(toYMInterval(yearStr, monthStr, finalSign(sign)), 0, 0)
case _ =>
throw new IllegalArgumentException(
s"Interval string does not match year-month format of 'y-m': $input")
}
}
- private def safeToInterval[T](f: => T): T = {
+ private def safeToInterval[T](interval: String)(f: => T): T = {
try {
f
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(
- s"Error parsing interval year-month string: ${e.getMessage}", e)
+ s"Error parsing interval $interval string: ${e.getMessage}", e)
}
}
private def toYMInterval(yearStr: String, monthStr: String, sign: Int): Int = {
- safeToInterval {
+ safeToInterval("year-month") {
val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR)
val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11))
Math.toIntExact(totalMonths)
@@ -279,15 +295,6 @@ object IntervalUtils {
startField: Byte,
endField: Byte): Long = {
- def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = {
- if (startField != targetStartField || endField != targetEndField) {
- throw new IllegalArgumentException(s"Interval string does not match day-time format of " +
- s"${supportedFormat((targetStartField, targetStartField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
- }
- }
-
def secondAndMicro(second: String, micro: String): String = {
if (micro != null) {
s"$second$micro"
@@ -296,50 +303,53 @@ object IntervalUtils {
}
}
+ def checkDTIntervalStringDataType(dt: DT): Unit =
+ checkIntervalStringDataType(input, startField, endField, dt, Some(fallbackNotice))
+
input.trimAll().toString match {
case dayHourRegex(sign, day, hour) =>
- checkStringIntervalType(DT.DAY, DT.HOUR)
- toDTInterval(day, hour, "0", "0", getSign(null, sign))
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR))
+ toDTInterval(day, hour, "0", "0", finalSign(sign))
case dayHourLiteralRegex(firstSign, secondSign, day, hour) =>
- checkStringIntervalType(DT.DAY, DT.HOUR)
- toDTInterval(day, hour, "0", "0", getSign(firstSign, secondSign))
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR))
+ toDTInterval(day, hour, "0", "0", finalSign(firstSign, secondSign))
case dayMinuteRegex(sign, day, hour, minute) =>
- checkStringIntervalType(DT.DAY, DT.MINUTE)
- toDTInterval(day, hour, minute, "0", getSign(null, sign))
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE))
+ toDTInterval(day, hour, minute, "0", finalSign(sign))
case dayMinuteLiteralRegex(firstSign, secondSign, day, hour, minute) =>
- checkStringIntervalType(DT.DAY, DT.MINUTE)
- toDTInterval(day, hour, minute, "0", getSign(firstSign, secondSign))
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE))
+ toDTInterval(day, hour, minute, "0", finalSign(firstSign, secondSign))
case daySecondRegex(sign, day, hour, minute, second, micro) =>
- checkStringIntervalType(DT.DAY, DT.SECOND)
- toDTInterval(day, hour, minute, secondAndMicro(second, micro), getSign(null, sign))
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND))
+ toDTInterval(day, hour, minute, secondAndMicro(second, micro), finalSign(sign))
case daySecondLiteralRegex(firstSign, secondSign, day, hour, minute, second, micro) =>
- checkStringIntervalType(DT.DAY, DT.SECOND)
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND))
toDTInterval(day, hour, minute, secondAndMicro(second, micro),
- getSign(firstSign, secondSign))
+ finalSign(firstSign, secondSign))
case hourMinuteRegex(sign, hour, minute) =>
- checkStringIntervalType(DT.HOUR, DT.MINUTE)
- toDTInterval(hour, minute, "0", getSign(null, sign))
+ checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE))
+ toDTInterval(hour, minute, "0", finalSign(sign))
case hourMinuteLiteralRegex(firstSign, secondSign, hour, minute) =>
- checkStringIntervalType(DT.HOUR, DT.MINUTE)
- toDTInterval(hour, minute, "0", getSign(firstSign, secondSign))
+ checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE))
+ toDTInterval(hour, minute, "0", finalSign(firstSign, secondSign))
case hourSecondRegex(sign, hour, minute, second, micro) =>
- checkStringIntervalType(DT.HOUR, DT.SECOND)
- toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(null, sign))
+ checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND))
+ toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(sign))
case hourSecondLiteralRegex(firstSign, secondSign, hour, minute, second, micro) =>
- checkStringIntervalType(DT.HOUR, DT.SECOND)
- toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(firstSign, secondSign))
+ checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND))
+ toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign))
case minuteSecondRegex(sign, minute, second, micro) =>
- checkStringIntervalType(DT.MINUTE, DT.SECOND)
- toDTInterval(minute, secondAndMicro(second, micro), getSign(null, sign))
+ checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND))
+ toDTInterval(minute, secondAndMicro(second, micro), finalSign(sign))
case minuteSecondLiteralRegex(firstSign, secondSign, minute, second, micro) =>
- checkStringIntervalType(DT.MINUTE, DT.SECOND)
- toDTInterval(minute, secondAndMicro(second, micro), getSign(firstSign, secondSign))
+ checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND))
+ toDTInterval(minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign))
- case dayTimeIndividualRegex(secondSign, value, suffix) =>
- safeToInterval {
- val sign = getSign("+", secondSign)
+ case dayTimeIndividualRegex(firstSign, value, suffix) =>
+ safeToInterval("day-time") {
+ val sign = finalSign(firstSign)
(startField, endField) match {
case (DT.DAY, DT.DAY) if suffix == null && value.length <= 9 =>
sign * value.toLong * MICROS_PER_DAY
@@ -352,46 +362,35 @@ object IntervalUtils {
case 1 => parseSecondNano(secondAndMicro(value, suffix))
case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}")
}
- case (_, _) => throw new IllegalArgumentException(
- s"Interval string does not match day-time format of " +
- s"${supportedFormat((startField, endField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
+ case (_, _) => throwIllegalIntervalFormatException(input, startField, endField,
+ "day-time", DT(startField, endField).typeName, Some(fallbackNotice))
}
}
case dayTimeIndividualLiteralRegex(firstSign, secondSign, value, suffix, unit) =>
- safeToInterval {
- val sign = getSign(firstSign, secondSign)
+ safeToInterval("day-time") {
+ val sign = finalSign(firstSign, secondSign)
unit match {
case "DAY" if suffix == null && value.length <= 9 =>
- checkStringIntervalType(DT.DAY, DT.DAY)
+ checkDTIntervalStringDataType(DT(DT.DAY, DT.DAY))
sign * value.toLong * MICROS_PER_DAY
case "HOUR" if suffix == null && value.length <= 10 =>
- checkStringIntervalType(DT.HOUR, DT.HOUR)
+ checkDTIntervalStringDataType(DT(DT.HOUR, DT.HOUR))
sign * value.toLong * MICROS_PER_HOUR
case "MINUTE" if suffix == null && value.length <= 12 =>
- checkStringIntervalType(DT.MINUTE, DT.MINUTE)
+ checkDTIntervalStringDataType(DT(DT.MINUTE, DT.MINUTE))
sign * value.toLong * MICROS_PER_MINUTE
case "SECOND" if value.length <= 13 =>
- checkStringIntervalType(DT.SECOND, DT.SECOND)
+ checkDTIntervalStringDataType(DT(DT.SECOND, DT.SECOND))
sign match {
case 1 => parseSecondNano(secondAndMicro(value, suffix))
case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}")
}
- case _ => throw new IllegalArgumentException(
- s"Interval string does not match day-time format of " +
- s"${supportedFormat((startField, endField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
+ case _ => throwIllegalIntervalFormatException(input, startField, endField,
+ "day-time", DT(startField, endField).typeName, Some(fallbackNotice))
}
}
- case _ =>
- throw new IllegalArgumentException(
- s"Interval string does not match day-time format of " +
- s"${supportedFormat((startField, endField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${DT(startField, endField).typeName}: ${input.toString}, " +
- s"$fallbackNotice")
+ case _ => throwIllegalIntervalFormatException(input, startField, endField,
+ "day-time", DT(startField, endField).typeName, Some(fallbackNotice))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 66f5b50..8313242 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -1113,10 +1114,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
if (!isTryCast) {
Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
+ val dataType = YearMonthIntervalType()
val e = intercept[IllegalArgumentException] {
- cast(Literal.create(interval), YearMonthIntervalType()).eval()
+ cast(Literal.create(interval), dataType).eval()
}.getMessage
- assert(e.contains("Interval string does not match year-month format"))
+ assert(e.contains(s"Interval string does not match year-month format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval"))
}
Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
("1", YearMonthIntervalType(YEAR, MONTH)),
@@ -1132,7 +1137,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
- assert(e.contains("Interval string does not match year-month format"))
+ assert(e.contains(s"Interval string does not match year-month format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval"))
}
}
}
@@ -1249,7 +1257,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
- assert(e.contains("Interval string does not match day-time format"))
+ assert(e.contains(s"Interval string does not match day-time format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval, " +
+ s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+ "to restore the behavior before Spark 3.0."))
}
// Check first field outof bound
@@ -1267,7 +1280,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
- assert(e.contains("Interval string does not match day-time format"))
+ assert(e.contains(s"Interval string does not match day-time format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval, " +
+ s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+ "to restore the behavior before Spark 3.0."))
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org