Posted to commits@spark.apache.org by ma...@apache.org on 2022/06/16 09:08:00 UTC
[spark] branch master updated: [SPARK-39470][SQL] Support cast of ANSI intervals to decimals
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d2332d9c47e [SPARK-39470][SQL] Support cast of ANSI intervals to decimals
d2332d9c47e is described below
commit d2332d9c47e8f250a015d6dc5edb028b334aa905
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Thu Jun 16 12:07:43 2022 +0300
[SPARK-39470][SQL] Support cast of ANSI intervals to decimals
### What changes were proposed in this pull request?
In this PR, I propose to support casts of ANSI intervals to decimals, following the SQL standard:
Screenshot of the SQL standard's rules for casting intervals to exact numerics (2022-06-12): https://user-images.githubusercontent.com/1580697/173663908-71945980-5638-4b46-9020-4d2e4badef0c.png
### Why are the changes needed?
To improve user experience with Spark SQL, and to conform to the SQL standard.
### Does this PR introduce _any_ user-facing change?
No, it just extends the existing behavior of casts.
Before:
```sql
spark-sql> SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6));
Error in query: cannot resolve 'CAST(INTERVAL '01.001002' SECOND AS DECIMAL(10,6))' due to data type mismatch: cannot cast interval second to decimal(10,6); line 1 pos 7;
'Project [unresolvedalias(cast(INTERVAL '01.001002' SECOND as decimal(10,6)), None)]
+- OneRowRelation
```
After:
```sql
spark-sql> SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6));
1.001002
```
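The same cast also works through the DataFrame API. A minimal sketch (assuming an active `SparkSession` named `spark`; the column alias `i` is illustrative):
```scala
import org.apache.spark.sql.types.DecimalType

// Build a one-row dataset with an INTERVAL SECOND column and cast it.
val df = spark.sql("SELECT INTERVAL '1.001002' SECOND AS i")
df.select(df("i").cast(DecimalType(10, 6))).show()
// Expected single value: 1.001002
```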
### How was this patch tested?
By running the new tests:
```
$ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
$ build/sbt "test:testOnly *CastWithAnsiOnSuite"
$ build/sbt "test:testOnly *CastWithAnsiOffSuite"
```
Closes #36857 from MaxGekk/cast-ansi-intervals-to-decimal.
Authored-by: Max Gekk <ma...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../spark/sql/catalyst/expressions/Cast.scala | 59 ++++++++++++++++---
.../spark/sql/catalyst/util/IntervalUtils.scala | 9 +++
.../sql/catalyst/expressions/CastSuiteBase.scala | 33 +++++++++++
.../src/test/resources/sql-tests/inputs/cast.sql | 10 ++++
.../resources/sql-tests/results/ansi/cast.sql.out | 68 ++++++++++++++++++++++
.../test/resources/sql-tests/results/cast.sql.out | 65 +++++++++++++++++++++
6 files changed, 237 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0746bc0fcd0..45950607e0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
-import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -110,7 +110,7 @@ object Cast {
case (StringType, _: CalendarIntervalType) => true
case (StringType, _: AnsiIntervalType) => true
- case (_: AnsiIntervalType, _: IntegralType) => true
+ case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
@@ -194,8 +194,7 @@ object Cast {
case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
- case (_: DayTimeIntervalType, _: IntegralType) => true
- case (_: YearMonthIntervalType, _: IntegralType) => true
+ case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
@@ -967,10 +966,17 @@ case class Cast(
* NOTE: this modifies `value` in-place, so don't call it on external data.
*/
private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+ changePrecision(value, decimalType, !ansiEnabled)
+ }
+
+ private[this] def changePrecision(
+ value: Decimal,
+ decimalType: DecimalType,
+ nullOnOverflow: Boolean): Decimal = {
if (value.changePrecision(decimalType.precision, decimalType.scale)) {
value
} else {
- if (!ansiEnabled) {
+ if (nullOnOverflow) {
null
} else {
throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
@@ -1015,6 +1021,18 @@ case class Cast(
} catch {
case _: NumberFormatException => null
}
+ case x: DayTimeIntervalType =>
+ buildCast[Long](_, dt =>
+ changePrecision(
+ value = dayTimeIntervalToDecimal(dt, x.endField),
+ decimalType = target,
+ nullOnOverflow = false))
+ case x: YearMonthIntervalType =>
+ buildCast[Int](_, ym =>
+ changePrecision(
+ value = Decimal(yearMonthIntervalToInt(ym, x.startField, x.endField)),
+ decimalType = target,
+ nullOnOverflow = false))
}
// DoubleConverter
@@ -1515,14 +1533,15 @@ case class Cast(
evPrim: ExprValue,
evNull: ExprValue,
canNullSafeCast: Boolean,
- ctx: CodegenContext): Block = {
+ ctx: CodegenContext,
+ nullOnOverflow: Boolean): Block = {
if (canNullSafeCast) {
code"""
|$d.changePrecision(${decimalType.precision}, ${decimalType.scale});
|$evPrim = $d;
""".stripMargin
} else {
- val overflowCode = if (!ansiEnabled) {
+ val overflowCode = if (nullOnOverflow) {
s"$evNull = true;"
} else {
s"""
@@ -1540,6 +1559,16 @@ case class Cast(
}
}
+ private[this] def changePrecision(
+ d: ExprValue,
+ decimalType: DecimalType,
+ evPrim: ExprValue,
+ evNull: ExprValue,
+ canNullSafeCast: Boolean,
+ ctx: CodegenContext): Block = {
+ changePrecision(d, decimalType, evPrim, evNull, canNullSafeCast, ctx, !ansiEnabled)
+ }
+
private[this] def castToDecimalCode(
from: DataType,
target: DecimalType,
@@ -1605,6 +1634,22 @@ case class Cast(
$evNull = true;
}
"""
+ case x: DayTimeIntervalType =>
+ (c, evPrim, evNull) =>
+ val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+ code"""
+ Decimal $tmp = $u.dayTimeIntervalToDecimal($c, (byte)${x.endField});
+ ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
+ """
+ case x: YearMonthIntervalType =>
+ (c, evPrim, evNull) =>
+ val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+ val tmpYm = ctx.freshVariable("tmpYm", classOf[Int])
+ code"""
+ int $tmpYm = $u.yearMonthIntervalToInt($c, (byte)${x.startField}, (byte)${x.endField});
+ Decimal $tmp = Decimal.apply($tmpYm);
+ ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
+ """
}
}
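The refactoring above threads an explicit `nullOnOverflow` flag through `changePrecision`, and the interval branches pass `false`: an interval value that does not fit the target decimal always raises an error instead of silently becoming NULL, even with ANSI mode off. A minimal standalone sketch of that policy (the helper name `fitOrFail` is hypothetical, not Spark's API):
```scala
import org.apache.spark.sql.types.{Decimal, DecimalType}

def fitOrFail(value: Decimal, target: DecimalType, nullOnOverflow: Boolean): Decimal = {
  // Decimal.changePrecision mutates `value` in place and reports whether it fits.
  if (value.changePrecision(target.precision, target.scale)) {
    value
  } else if (nullOnOverflow) {
    null
  } else {
    throw new ArithmeticException(
      s"value cannot be represented as Decimal(${target.precision}, ${target.scale})")
  }
}

fitOrFail(Decimal(10.123), DecimalType(4, 2), nullOnOverflow = false) // 10.12
```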
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index dad58b7ae45..721f50208ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -1346,6 +1346,15 @@ object IntervalUtils {
}
}
+ def dayTimeIntervalToDecimal(v: Long, endField: Byte): Decimal = {
+ endField match {
+ case DAY => Decimal(v / MICROS_PER_DAY)
+ case HOUR => Decimal(v / MICROS_PER_HOUR)
+ case MINUTE => Decimal(v / MICROS_PER_MINUTE)
+ case SECOND => Decimal(v, Decimal.MAX_LONG_DIGITS, 6)
+ }
+ }
+
def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
val vLong = dayTimeIntervalToLong(v, startField, endField)
val vInt = vLong.toInt
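For intuition on the new `dayTimeIntervalToDecimal`: a day-time interval is stored as a `Long` count of microseconds, so for an end field of `SECOND` the value can be reinterpreted directly as an unscaled decimal with scale 6, while coarser end fields truncate via integer division. A minimal sketch (constants inlined to keep it self-contained):
```scala
import org.apache.spark.sql.types.Decimal

// INTERVAL '1.001002' SECOND is stored as 1001002 microseconds.
val micros = 1001002L

// SECOND: reinterpret the microseconds as an unscaled value with scale 6.
val asSeconds = Decimal(micros, Decimal.MAX_LONG_DIGITS, 6) // 1.001002

// MINUTE (and coarser): integer division truncates toward zero.
val microsPerMinute = 60L * 1000 * 1000
val asMinutes = Decimal(micros / microsPerMinute) // 0
```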
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ca492e11226..97cbc781829 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -1272,4 +1272,37 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
"to restore the behavior before Spark 3.0."))
}
}
+
+ test("cast ANSI intervals to decimals") {
+ Seq(
+ (Duration.ZERO, DayTimeIntervalType(DAY), DecimalType(10, 3)) -> Decimal(0, 10, 3),
+ (Duration.ofHours(-1), DayTimeIntervalType(HOUR), DecimalType(10, 1)) -> Decimal(-10, 10, 1),
+ (Duration.ofMinutes(1), DayTimeIntervalType(MINUTE), DecimalType(8, 2)) -> Decimal(100, 8, 2),
+ (Duration.ofSeconds(59), DayTimeIntervalType(SECOND), DecimalType(6, 0)) -> Decimal(59, 6, 0),
+ (Duration.ofSeconds(-60).minusMillis(1), DayTimeIntervalType(SECOND),
+ DecimalType(10, 3)) -> Decimal(-60.001, 10, 3),
+ (Duration.ZERO, DayTimeIntervalType(DAY, SECOND), DecimalType(10, 6)) -> Decimal(0, 10, 6),
+ (Duration.ofHours(-23).minusMinutes(59).minusSeconds(59).minusNanos(123456000),
+ DayTimeIntervalType(HOUR, SECOND), DecimalType(18, 6)) -> Decimal(-86399.123456, 18, 6),
+ (Period.ZERO, YearMonthIntervalType(YEAR), DecimalType(5, 2)) -> Decimal(0, 5, 2),
+ (Period.ofMonths(-1), YearMonthIntervalType(MONTH),
+ DecimalType(8, 0)) -> Decimal(-1, 8, 0),
+ (Period.ofYears(-1).minusMonths(1), YearMonthIntervalType(YEAR, MONTH),
+ DecimalType(8, 3)) -> Decimal(-13000, 8, 3)
+ ).foreach { case ((duration, intervalType, targetType), expected) =>
+ checkEvaluation(
+ Cast(Literal.create(duration, intervalType), targetType),
+ expected)
+ }
+
+ dayTimeIntervalTypes.foreach { it =>
+ checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+ Cast(child, DecimalType.USER_DEFAULT), it)
+ }
+
+ yearMonthIntervalTypes.foreach { it =>
+ checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+ Cast(child, DecimalType.USER_DEFAULT), it)
+ }
+ }
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index 5198611a2b3..66a78ec9473 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -116,3 +116,13 @@ select cast(interval '10' day as bigint);
select cast(interval '-1000' month as tinyint);
select cast(interval '1000000' second as smallint);
+
+-- cast ANSI intervals to decimals
+select cast(interval '-1' year as decimal(10, 0));
+select cast(interval '1.000001' second as decimal(10, 6));
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4));
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1));
+select cast(interval '10.123' second as decimal(4, 2));
+select cast(interval '10.005' second as decimal(4, 2));
+select cast(interval '10.123' second as decimal(5, 2));
+select cast(interval '10.123' second as decimal(1, 0));
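A note on the expected results below: `Decimal.changePrecision` rounds half-up by default, which is why `10.005` seconds cast to `decimal(4, 2)` yields `10.01` rather than truncating to `10.00`. A minimal sketch of that rounding:
```scala
import org.apache.spark.sql.types.Decimal

val d = Decimal(10005000L, 18, 6) // 10.005000 (10.005 seconds in microseconds)
d.changePrecision(4, 2)           // true: fits after half-up rounding
assert(d.toString == "10.01")
```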
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
index 11753f2b5ca..470a6081c46 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
@@ -838,3 +838,71 @@ struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(interval '-1' year as decimal(10, 0))
+-- !query schema
+struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
+-- !query output
+-1
+
+
+-- !query
+select cast(interval '1.000001' second as decimal(10, 6))
+-- !query schema
+struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
+-- !query output
+1.000001
+
+
+-- !query
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
+-- !query schema
+struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
+-- !query output
+29470.0010
+
+
+-- !query
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
+-- !query schema
+struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
+-- !query output
+90123.1
+
+
+-- !query
+select cast(interval '10.123' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.005' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.01
+
+
+-- !query
+select cast(interval '10.123' second as decimal(5, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.123' second as decimal(1, 0))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+== SQL(line 1, position 8) ==
+select cast(interval '10.123' second as decimal(1, 0))
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
index 9c00e1b985e..911eaff30b9 100644
--- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -666,3 +666,68 @@ struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(interval '-1' year as decimal(10, 0))
+-- !query schema
+struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
+-- !query output
+-1
+
+
+-- !query
+select cast(interval '1.000001' second as decimal(10, 6))
+-- !query schema
+struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
+-- !query output
+1.000001
+
+
+-- !query
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
+-- !query schema
+struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
+-- !query output
+29470.0010
+
+
+-- !query
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
+-- !query schema
+struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
+-- !query output
+90123.1
+
+
+-- !query
+select cast(interval '10.123' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.005' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.01
+
+
+-- !query
+select cast(interval '10.123' second as decimal(5, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.123' second as decimal(1, 0))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.