Posted to commits@spark.apache.org by ma...@apache.org on 2022/08/12 11:24:03 UTC
[spark] branch master updated: [SPARK-40014][SQL] Support cast of decimals to ANSI intervals
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3f84c7d56a8 [SPARK-40014][SQL] Support cast of decimals to ANSI intervals
3f84c7d56a8 is described below
commit 3f84c7d56a8465f4b7a8ff92f5b4439e29254456
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Fri Aug 12 14:23:20 2022 +0300
[SPARK-40014][SQL] Support cast of decimals to ANSI intervals
### What changes were proposed in this pull request?
In this PR, I propose to support casts of decimals to ANSI intervals, following the SQL standard:
[Screenshot: the SQL standard's rules for casting exact numerics to interval types — https://user-images.githubusercontent.com/1580697/184117139-d70c972b-3dce-4ee7-9ced-989be956d7cc.png]
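As a concrete illustration of that rule (the examples below are taken from the golden-file tests added by this patch), the decimal value is interpreted in the unit of the target interval type's least significant field:
```sql
-- 10.123456 is read as 10.123456 seconds, the least significant
-- field of INTERVAL DAY TO SECOND:
spark-sql> SELECT CAST(10.123456BD AS INTERVAL DAY TO SECOND);
0 00:00:10.123456000
```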
### Why are the changes needed?
To improve user experience with Spark SQL, and to conform to the SQL standard.
### Does this PR introduce _any_ user-facing change?
No, it just extends the existing behavior of casts.
Before:
```sql
spark-sql> SELECT CAST(1.001002BD AS INTERVAL SECOND);
Error in query: cannot resolve 'CAST(1.001002BD AS INTERVAL SECOND)' due to data type mismatch: cannot cast decimal(7,6) to interval second; line 1 pos 7;
```
After:
```sql
spark-sql> SELECT CAST(1.001002BD AS INTERVAL SECOND);
0 00:00:01.001002000
```
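The fractional part is rounded half-up in the unit of the least significant field, and a decimal that does not fit into the interval's internal representation raises a CAST_OVERFLOW error (see the updated QueryExecutionErrorsSuite). Again taken from the new tests:
```sql
-- 10.654321 months rounds half-up to 11 months:
spark-sql> SELECT CAST(10.654321BD AS INTERVAL MONTH);
0-11
-- the sign is preserved: -10.123456 months rounds to -10 months:
spark-sql> SELECT CAST(-10.123456BD AS INTERVAL YEAR TO MONTH);
-0-10
```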
### How was this patch tested?
By running the new tests:
```
$ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
$ build/sbt "test:testOnly *CastWithAnsiOffSuite"
$ build/sbt "test:testOnly *CastWithAnsiOnSuite"
```
Closes #37466 from MaxGekk/cast-decimal-to-ansi-intervals.
Authored-by: Max Gekk <ma...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../spark/sql/catalyst/expressions/Cast.scala | 24 ++++++++++++++--
.../spark/sql/catalyst/util/IntervalUtils.scala | 31 +++++++++++++++++++++
.../sql/catalyst/expressions/CastSuiteBase.scala | 9 +++++-
.../src/test/resources/sql-tests/inputs/cast.sql | 6 ++++
.../resources/sql-tests/results/ansi/cast.sql.out | 32 ++++++++++++++++++++++
.../test/resources/sql-tests/results/cast.sql.out | 32 ++++++++++++++++++++++
.../sql/errors/QueryExecutionErrorsSuite.scala | 30 +++++++++++---------
7 files changed, 148 insertions(+), 16 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6cbb189c2b4..0b0547aa621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
case (StringType, _: AnsiIntervalType) => true
case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
- case (_: IntegralType, _: AnsiIntervalType) => true
+ case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true
case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
@@ -197,7 +197,7 @@ object Cast {
case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
- case (_: IntegralType, _: AnsiIntervalType) => true
+ case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true
case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
@@ -795,6 +795,9 @@ case class Cast(
b => IntervalUtils.intToDayTimeInterval(
x.integral.asInstanceOf[Integral[Any]].toInt(b), it.startField, it.endField)
}
+ case DecimalType.Fixed(p, s) =>
+ buildCast[Decimal](_, d =>
+ IntervalUtils.decimalToDayTimeInterval(d, p, s, it.startField, it.endField))
}
private[this] def castToYearMonthInterval(
@@ -812,6 +815,9 @@ case class Cast(
b => IntervalUtils.intToYearMonthInterval(
x.integral.asInstanceOf[Integral[Any]].toInt(b), it.startField, it.endField)
}
+ case DecimalType.Fixed(p, s) =>
+ buildCast[Decimal](_, d =>
+ IntervalUtils.decimalToYearMonthInterval(d, p, s, it.startField, it.endField))
}
// LongConverter
@@ -1815,6 +1821,13 @@ case class Cast(
$evPrim = $iu.intToDayTimeInterval($c, (byte)${it.startField}, (byte)${it.endField});
"""
}
+ case DecimalType.Fixed(p, s) =>
+ val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+ (c, evPrim, _) =>
+ code"""
+ $evPrim = $iu.decimalToDayTimeInterval(
+ $c, $p, $s, (byte)${it.startField}, (byte)${it.endField});
+ """
}
private[this] def castToYearMonthIntervalCode(
@@ -1845,6 +1858,13 @@ case class Cast(
$evPrim = $iu.intToYearMonthInterval($c, (byte)${it.startField}, (byte)${it.endField});
"""
}
+ case DecimalType.Fixed(p, s) =>
+ val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+ (c, evPrim, _) =>
+ code"""
+ $evPrim = $iu.decimalToYearMonthInterval(
+ $c, $p, $s, (byte)${it.startField}, (byte)${it.endField});
+ """
}
private[this] def decimalToTimestampCode(d: ExprValue): Block = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 0bcc0ae0e3d..7b574e987d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -1283,6 +1283,20 @@ object IntervalUtils {
intToYearMonthInterval(vInt, startField, endField)
}
+ def decimalToYearMonthInterval(
+ d: Decimal, p: Int, s: Int, startField: Byte, endField: Byte): Int = {
+ try {
+ val months = if (endField == YEAR) d.toBigDecimal * MONTHS_PER_YEAR else d.toBigDecimal
+ months.setScale(0, BigDecimal.RoundingMode.HALF_UP).toIntExact
+ } catch {
+ case _: ArithmeticException =>
+ throw QueryExecutionErrors.castingCauseOverflowError(
+ d,
+ DecimalType(p, s),
+ YearMonthIntervalType(startField, endField))
+ }
+ }
+
def yearMonthIntervalToInt(v: Int, startField: Byte, endField: Byte): Int = {
endField match {
case YEAR => v / MONTHS_PER_YEAR
@@ -1367,6 +1381,23 @@ object IntervalUtils {
}
}
+ def decimalToDayTimeInterval(
+ d: Decimal, p: Int, s: Int, startField: Byte, endField: Byte): Long = {
+ try {
+ val micros = endField match {
+ case DAY => d.toBigDecimal * MICROS_PER_DAY
+ case HOUR => d.toBigDecimal * MICROS_PER_HOUR
+ case MINUTE => d.toBigDecimal * MICROS_PER_MINUTE
+ case SECOND => d.toBigDecimal * MICROS_PER_SECOND
+ }
+ micros.setScale(0, BigDecimal.RoundingMode.HALF_UP).toLongExact
+ } catch {
+ case _: ArithmeticException =>
+ throw QueryExecutionErrors.castingCauseOverflowError(
+ d, DecimalType(p, s), DT(startField, endField))
+ }
+ }
+
def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
val vLong = dayTimeIntervalToLong(v, startField, endField)
val vInt = vLong.toInt
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 03467308d14..cfe47ee4322 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -1273,7 +1273,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("cast ANSI intervals to decimals") {
+ test("cast ANSI intervals to/from decimals") {
Seq(
(Duration.ZERO, DayTimeIntervalType(DAY), DecimalType(10, 3)) -> Decimal(0, 10, 3),
(Duration.ofHours(-1), DayTimeIntervalType(HOUR), DecimalType(10, 1)) -> Decimal(-10, 10, 1),
@@ -1293,16 +1293,23 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
Cast(Literal.create(duration, intervalType), targetType),
expected)
+ checkEvaluation(
+ Cast(Literal.create(expected, targetType), intervalType),
+ duration)
}
dayTimeIntervalTypes.foreach { it =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
Cast(child, DecimalType.USER_DEFAULT), it)
+ checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+ Cast(child, it), DecimalType.USER_DEFAULT)
}
yearMonthIntervalTypes.foreach { it =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
Cast(child, DecimalType.USER_DEFAULT), it)
+ checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+ Cast(child, it), DecimalType.USER_DEFAULT)
}
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index fb92825c001..46ce9fb9aac 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -140,3 +140,9 @@ select cast(interval '10.123' second as decimal(4, 2));
select cast(interval '10.005' second as decimal(4, 2));
select cast(interval '10.123' second as decimal(5, 2));
select cast(interval '10.123' second as decimal(1, 0));
+
+-- cast decimals to ANSI intervals
+select cast(10.123456BD as interval day to second);
+select cast(80.654321BD as interval hour to minute);
+select cast(-10.123456BD as interval year to month);
+select cast(10.654321BD as interval month);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
index 373e8c7b362..1e3ed4c4f8b 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
@@ -996,3 +996,35 @@ org.apache.spark.SparkArithmeticException
== SQL(line 1, position 8) ==
select cast(interval '10.123' second as decimal(1, 0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+
+-- !query
+select cast(10.123456BD as interval day to second)
+-- !query schema
+struct<CAST(10.123456 AS INTERVAL DAY TO SECOND):interval day to second>
+-- !query output
+0 00:00:10.123456000
+
+
+-- !query
+select cast(80.654321BD as interval hour to minute)
+-- !query schema
+struct<CAST(80.654321 AS INTERVAL HOUR TO MINUTE):interval hour to minute>
+-- !query output
+0 01:20:00.000000000
+
+
+-- !query
+select cast(-10.123456BD as interval year to month)
+-- !query schema
+struct<CAST(-10.123456 AS INTERVAL YEAR TO MONTH):interval year to month>
+-- !query output
+-0-10
+
+
+-- !query
+select cast(10.654321BD as interval month)
+-- !query schema
+struct<CAST(10.654321 AS INTERVAL MONTH):interval month>
+-- !query output
+0-11
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
index 5e216045112..ef1f28ed746 100644
--- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -821,3 +821,35 @@ struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(10.123456BD as interval day to second)
+-- !query schema
+struct<CAST(10.123456 AS INTERVAL DAY TO SECOND):interval day to second>
+-- !query output
+0 00:00:10.123456000
+
+
+-- !query
+select cast(80.654321BD as interval hour to minute)
+-- !query schema
+struct<CAST(80.654321 AS INTERVAL HOUR TO MINUTE):interval hour to minute>
+-- !query output
+0 01:20:00.000000000
+
+
+-- !query
+select cast(-10.123456BD as interval year to month)
+-- !query schema
+struct<CAST(-10.123456 AS INTERVAL YEAR TO MONTH):interval year to month>
+-- !query output
+-0-10
+
+
+-- !query
+select cast(10.654321BD as interval month)
+-- !query schema
+struct<CAST(10.654321 AS INTERVAL MONTH):interval month>
+-- !query output
+0-11
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index fb5a05c26de..c2723ba4c1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.functions.{lit, lower, struct, sum, udf}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.EXCEPTION
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
-import org.apache.spark.sql.types.{DataType, DecimalType, MetadataBuilder, StructType}
+import org.apache.spark.sql.types.{DataType, DecimalType, LongType, MetadataBuilder, StructType}
import org.apache.spark.util.Utils
class QueryExecutionErrorsSuite
@@ -655,18 +655,22 @@ class QueryExecutionErrorsSuite
}
test("CAST_OVERFLOW: from long to ANSI intervals") {
- Seq("INTERVAL YEAR TO MONTH", "INTERVAL HOUR TO MINUTE").foreach { it =>
- checkError(
- exception = intercept[SparkArithmeticException] {
- sql(s"select CAST(9223372036854775807L AS $it)").collect()
- },
- errorClass = "CAST_OVERFLOW",
- parameters = Map(
- "value" -> "9223372036854775807L",
- "sourceType" -> "\"BIGINT\"",
- "targetType" -> s""""$it"""",
- "ansiConfig" -> s""""${SQLConf.ANSI_ENABLED.key}""""),
- sqlState = "22005")
+ Seq(
+ LongType -> "9223372036854775807L",
+ DecimalType(19, 0) -> "9223372036854775807BD").foreach { case (sourceType, sourceValue) =>
+ Seq("INTERVAL YEAR TO MONTH", "INTERVAL HOUR TO MINUTE").foreach { it =>
+ checkError(
+ exception = intercept[SparkArithmeticException] {
+ sql(s"select CAST($sourceValue AS $it)").collect()
+ },
+ errorClass = "CAST_OVERFLOW",
+ parameters = Map(
+ "value" -> sourceValue,
+ "sourceType" -> s""""${sourceType.sql}"""",
+ "targetType" -> s""""$it"""",
+ "ansiConfig" -> s""""${SQLConf.ANSI_ENABLED.key}""""),
+ sqlState = "22005")
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org