Posted to commits@spark.apache.org by ma...@apache.org on 2022/08/12 11:24:03 UTC

[spark] branch master updated: [SPARK-40014][SQL] Support cast of decimals to ANSI intervals

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f84c7d56a8 [SPARK-40014][SQL] Support cast of decimals to ANSI intervals
3f84c7d56a8 is described below

commit 3f84c7d56a8465f4b7a8ff92f5b4439e29254456
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Fri Aug 12 14:23:20 2022 +0300

    [SPARK-40014][SQL] Support cast of decimals to ANSI intervals
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to support casts of decimals to ANSI intervals, following the SQL standard:
    (screenshot of the standard's table of supported casts: https://user-images.githubusercontent.com/1580697/184117139-d70c972b-3dce-4ee7-9ced-989be956d7cc.png)
    
    ### Why are the changes needed?
    To improve user experience with Spark SQL, and to conform to the SQL standard.
    
    ### Does this PR introduce _any_ user-facing change?
    No, it only extends the existing behavior of casts.
    
    Before:
    ```sql
    spark-sql> SELECT CAST(1.001002BD AS INTERVAL SECOND);
    Error in query: cannot resolve 'CAST(1.001002BD AS INTERVAL SECOND)' due to data type mismatch: cannot cast decimal(7,6) to interval second; line 1 pos 7;
    ```
    
    After:
    ```sql
    spark-sql> SELECT CAST(1.001002BD AS INTERVAL SECOND);
    0 00:00:01.001002000
    ```
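    
    For intuition, the new cast scales the decimal by the end field's unit and rounds half-up. A minimal plain-Scala sketch of the arithmetic behind the output above (no Spark dependencies; `MICROS_PER_SECOND` inlined):
    ```scala
    import scala.math.BigDecimal.RoundingMode
    
    val MICROS_PER_SECOND = 1000000L
    // 1.001002 seconds -> 1001002 microseconds, printed as 0 00:00:01.001002000.
    val micros = (BigDecimal("1.001002") * MICROS_PER_SECOND)
      .setScale(0, RoundingMode.HALF_UP)
      .toLongExact
    ```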
    
    ### How was this patch tested?
    By running new tests:
    ```
    $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
    $ build/sbt "test:testOnly *CastWithAnsiOffSuite"
    $ build/sbt "test:testOnly *CastWithAnsiOnSuite"
    ```
    
    Closes #37466 from MaxGekk/cast-decimal-to-ansi-intervals.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 24 ++++++++++++++--
 .../spark/sql/catalyst/util/IntervalUtils.scala    | 31 +++++++++++++++++++++
 .../sql/catalyst/expressions/CastSuiteBase.scala   |  9 +++++-
 .../src/test/resources/sql-tests/inputs/cast.sql   |  6 ++++
 .../resources/sql-tests/results/ansi/cast.sql.out  | 32 ++++++++++++++++++++++
 .../test/resources/sql-tests/results/cast.sql.out  | 32 ++++++++++++++++++++++
 .../sql/errors/QueryExecutionErrorsSuite.scala     | 30 +++++++++++---------
 7 files changed, 148 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6cbb189c2b4..0b0547aa621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
     case (StringType, _: AnsiIntervalType) => true
 
     case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
-    case (_: IntegralType, _: AnsiIntervalType) => true
+    case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true
 
     case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
     case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
@@ -197,7 +197,7 @@ object Cast {
     case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
     case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
     case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
-    case (_: IntegralType, _: AnsiIntervalType) => true
+    case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true
 
     case (StringType, _: NumericType) => true
     case (BooleanType, _: NumericType) => true
@@ -795,6 +795,9 @@ case class Cast(
         b => IntervalUtils.intToDayTimeInterval(
           x.integral.asInstanceOf[Integral[Any]].toInt(b), it.startField, it.endField)
       }
+    case DecimalType.Fixed(p, s) =>
+      buildCast[Decimal](_, d =>
+        IntervalUtils.decimalToDayTimeInterval(d, p, s, it.startField, it.endField))
   }
 
   private[this] def castToYearMonthInterval(
@@ -812,6 +815,9 @@ case class Cast(
         b => IntervalUtils.intToYearMonthInterval(
           x.integral.asInstanceOf[Integral[Any]].toInt(b), it.startField, it.endField)
       }
+    case DecimalType.Fixed(p, s) =>
+      buildCast[Decimal](_, d =>
+        IntervalUtils.decimalToYearMonthInterval(d, p, s, it.startField, it.endField))
   }
 
   // LongConverter
@@ -1815,6 +1821,13 @@ case class Cast(
             $evPrim = $iu.intToDayTimeInterval($c, (byte)${it.startField}, (byte)${it.endField});
           """
       }
+    case DecimalType.Fixed(p, s) =>
+      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+      (c, evPrim, _) =>
+        code"""
+          $evPrim = $iu.decimalToDayTimeInterval(
+            $c, $p, $s, (byte)${it.startField}, (byte)${it.endField});
+        """
   }
 
   private[this] def castToYearMonthIntervalCode(
@@ -1845,6 +1858,13 @@ case class Cast(
             $evPrim = $iu.intToYearMonthInterval($c, (byte)${it.startField}, (byte)${it.endField});
           """
       }
+    case DecimalType.Fixed(p, s) =>
+      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+      (c, evPrim, _) =>
+        code"""
+          $evPrim = $iu.decimalToYearMonthInterval(
+            $c, $p, $s, (byte)${it.startField}, (byte)${it.endField});
+        """
   }
 
   private[this] def decimalToTimestampCode(d: ExprValue): Block = {
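
Both hunks above follow Cast's two-path pattern: each rule needs an interpreted closure (buildCast) and a codegen template (code"""...""") that compute identical results; the CastSuiteBase changes below check that consistency. A standalone sketch of the pattern, with hypothetical names and the decimal-to-microseconds arithmetic inlined:

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object TwoPathSketch {
  // Interpreted path: a closure applied value by value, like buildCast.
  val interpreted: JBigDecimal => Long = d =>
    d.multiply(new JBigDecimal(1000000L))
      .setScale(0, RoundingMode.HALF_UP)
      .longValueExact()

  // Codegen path: Java source spliced into the generated class (Spark
  // compiles it with Janino); it must mirror the closure above exactly.
  val codegen: String =
    """|long micros = d.multiply(new java.math.BigDecimal(1000000L))
       |    .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();""".stripMargin
}
```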
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 0bcc0ae0e3d..7b574e987d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -1283,6 +1283,20 @@ object IntervalUtils {
     intToYearMonthInterval(vInt, startField, endField)
   }
 
+  def decimalToYearMonthInterval(
+      d: Decimal, p: Int, s: Int, startField: Byte, endField: Byte): Int = {
+    try {
+      val months = if (endField == YEAR) d.toBigDecimal * MONTHS_PER_YEAR else d.toBigDecimal
+      months.setScale(0, BigDecimal.RoundingMode.HALF_UP).toIntExact
+    } catch {
+      case _: ArithmeticException =>
+        throw QueryExecutionErrors.castingCauseOverflowError(
+          d,
+          DecimalType(p, s),
+          YearMonthIntervalType(startField, endField))
+    }
+  }
+
   def yearMonthIntervalToInt(v: Int, startField: Byte, endField: Byte): Int = {
     endField match {
       case YEAR => v / MONTHS_PER_YEAR
@@ -1367,6 +1381,23 @@ object IntervalUtils {
     }
   }
 
+  def decimalToDayTimeInterval(
+      d: Decimal, p: Int, s: Int, startField: Byte, endField: Byte): Long = {
+    try {
+      val micros = endField match {
+        case DAY => d.toBigDecimal * MICROS_PER_DAY
+        case HOUR => d.toBigDecimal * MICROS_PER_HOUR
+        case MINUTE => d.toBigDecimal * MICROS_PER_MINUTE
+        case SECOND => d.toBigDecimal * MICROS_PER_SECOND
+      }
+      micros.setScale(0, BigDecimal.RoundingMode.HALF_UP).toLongExact
+    } catch {
+      case _: ArithmeticException =>
+        throw QueryExecutionErrors.castingCauseOverflowError(
+          d, DecimalType(p, s), DT(startField, endField))
+    }
+  }
+
   def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
     val vLong = dayTimeIntervalToLong(v, startField, endField)
     val vInt = vLong.toInt
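
Both new helpers have the same shape: scale the decimal by the end field's unit, round half-up, and let toIntExact/toLongExact raise ArithmeticException on overflow, which is remapped to castingCauseOverflowError. A plain-Scala sketch of the year-month rounding (MONTHS_PER_YEAR = 12, as in IntervalUtils):

```scala
import scala.math.BigDecimal.RoundingMode

val MONTHS_PER_YEAR = 12

// endField = MONTH: the decimal counts months; 10.654321 rounds half-up
// to 11 months, printed as 0-11 (see the cast.sql results below).
val months = BigDecimal("10.654321").setScale(0, RoundingMode.HALF_UP).toIntExact

// endField = YEAR: the decimal counts years, so scale to months first:
// 1.5 years -> 18 months.
val yearsAsMonths = (BigDecimal("1.5") * MONTHS_PER_YEAR)
  .setScale(0, RoundingMode.HALF_UP).toIntExact
```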
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 03467308d14..cfe47ee4322 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -1273,7 +1273,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
       }
   }
 
-  test("cast ANSI intervals to decimals") {
+  test("cast ANSI intervals to/from decimals") {
     Seq(
       (Duration.ZERO, DayTimeIntervalType(DAY), DecimalType(10, 3)) -> Decimal(0, 10, 3),
       (Duration.ofHours(-1), DayTimeIntervalType(HOUR), DecimalType(10, 1)) -> Decimal(-10, 10, 1),
@@ -1293,16 +1293,23 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(
         Cast(Literal.create(duration, intervalType), targetType),
         expected)
+      checkEvaluation(
+        Cast(Literal.create(expected, targetType), intervalType),
+        duration)
     }
 
     dayTimeIntervalTypes.foreach { it =>
       checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
         Cast(child, DecimalType.USER_DEFAULT), it)
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+        Cast(child, it), DecimalType.USER_DEFAULT)
     }
 
     yearMonthIntervalTypes.foreach { it =>
       checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
         Cast(child, DecimalType.USER_DEFAULT), it)
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+        Cast(child, it), DecimalType.USER_DEFAULT)
     }
   }
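
The added checkEvaluation calls close the loop: every (duration, decimal) pair is now verified in both directions. A plain-Scala sketch of that round-trip property (hypothetical helper names; day-time intervals are stored as whole microseconds):

```scala
import java.time.Duration
import scala.math.BigDecimal.RoundingMode

def durationToDecimalSeconds(d: Duration): BigDecimal =
  BigDecimal(d.toNanos / 1000L) / 1000000L  // microseconds -> decimal seconds

def decimalSecondsToDuration(d: BigDecimal): Duration =
  Duration.ofNanos(
    (d * 1000000L).setScale(0, RoundingMode.HALF_UP).toLongExact * 1000L)

val d = Duration.ofHours(-1)
assert(decimalSecondsToDuration(durationToDecimalSeconds(d)) == d)
```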
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index fb92825c001..46ce9fb9aac 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -140,3 +140,9 @@ select cast(interval '10.123' second as decimal(4, 2));
 select cast(interval '10.005' second as decimal(4, 2));
 select cast(interval '10.123' second as decimal(5, 2));
 select cast(interval '10.123' second as decimal(1, 0));
+
+-- cast decimals to ANSI intervals
+select cast(10.123456BD as interval day to second);
+select cast(80.654321BD as interval hour to minute);
+select cast(-10.123456BD as interval year to month);
+select cast(10.654321BD as interval month);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
index 373e8c7b362..1e3ed4c4f8b 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
@@ -996,3 +996,35 @@ org.apache.spark.SparkArithmeticException
 == SQL(line 1, position 8) ==
 select cast(interval '10.123' second as decimal(1, 0))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+
+-- !query
+select cast(10.123456BD as interval day to second)
+-- !query schema
+struct<CAST(10.123456 AS INTERVAL DAY TO SECOND):interval day to second>
+-- !query output
+0 00:00:10.123456000
+
+
+-- !query
+select cast(80.654321BD as interval hour to minute)
+-- !query schema
+struct<CAST(80.654321 AS INTERVAL HOUR TO MINUTE):interval hour to minute>
+-- !query output
+0 01:20:00.000000000
+
+
+-- !query
+select cast(-10.123456BD as interval year to month)
+-- !query schema
+struct<CAST(-10.123456 AS INTERVAL YEAR TO MONTH):interval year to month>
+-- !query output
+-0-10
+
+
+-- !query
+select cast(10.654321BD as interval month)
+-- !query schema
+struct<CAST(10.654321 AS INTERVAL MONTH):interval month>
+-- !query output
+0-11
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
index 5e216045112..ef1f28ed746 100644
--- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -821,3 +821,35 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 [CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(10.123456BD as interval day to second)
+-- !query schema
+struct<CAST(10.123456 AS INTERVAL DAY TO SECOND):interval day to second>
+-- !query output
+0 00:00:10.123456000
+
+
+-- !query
+select cast(80.654321BD as interval hour to minute)
+-- !query schema
+struct<CAST(80.654321 AS INTERVAL HOUR TO MINUTE):interval hour to minute>
+-- !query output
+0 01:20:00.000000000
+
+
+-- !query
+select cast(-10.123456BD as interval year to month)
+-- !query schema
+struct<CAST(-10.123456 AS INTERVAL YEAR TO MONTH):interval year to month>
+-- !query output
+-0-10
+
+
+-- !query
+select cast(10.654321BD as interval month)
+-- !query schema
+struct<CAST(10.654321 AS INTERVAL MONTH):interval month>
+-- !query output
+0-11
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index fb5a05c26de..c2723ba4c1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.functions.{lit, lower, struct, sum, udf}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.EXCEPTION
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
-import org.apache.spark.sql.types.{DataType, DecimalType, MetadataBuilder, StructType}
+import org.apache.spark.sql.types.{DataType, DecimalType, LongType, MetadataBuilder, StructType}
 import org.apache.spark.util.Utils
 
 class QueryExecutionErrorsSuite
@@ -655,18 +655,22 @@ class QueryExecutionErrorsSuite
   }
 
   test("CAST_OVERFLOW: from long to ANSI intervals") {
-    Seq("INTERVAL YEAR TO MONTH", "INTERVAL HOUR TO MINUTE").foreach { it =>
-      checkError(
-        exception = intercept[SparkArithmeticException] {
-          sql(s"select CAST(9223372036854775807L AS $it)").collect()
-        },
-        errorClass = "CAST_OVERFLOW",
-        parameters = Map(
-          "value" -> "9223372036854775807L",
-          "sourceType" -> "\"BIGINT\"",
-          "targetType" -> s""""$it"""",
-          "ansiConfig" -> s""""${SQLConf.ANSI_ENABLED.key}""""),
-        sqlState = "22005")
+    Seq(
+      LongType -> "9223372036854775807L",
+      DecimalType(19, 0) -> "9223372036854775807BD").foreach { case (sourceType, sourceValue) =>
+      Seq("INTERVAL YEAR TO MONTH", "INTERVAL HOUR TO MINUTE").foreach { it =>
+        checkError(
+          exception = intercept[SparkArithmeticException] {
+            sql(s"select CAST($sourceValue AS $it)").collect()
+          },
+          errorClass = "CAST_OVERFLOW",
+          parameters = Map(
+            "value" -> sourceValue,
+            "sourceType" -> s""""${sourceType.sql}"""",
+            "targetType" -> s""""$it"""",
+            "ansiConfig" -> s""""${SQLConf.ANSI_ENABLED.key}""""),
+          sqlState = "22005")
+      }
     }
   }
 }
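
The reworked test covers the decimal source alongside the long one; the failure mode is identical, since scaling 9223372036854775807 by the end field's unit factor no longer fits in the target integral type. A plain-Scala sketch of why the exact conversion throws (MICROS_PER_MINUTE = 60000000, as in IntervalUtils):

```scala
import scala.math.BigDecimal.RoundingMode

val MICROS_PER_MINUTE = 60000000L

// The product exceeds Long.MaxValue, so toLongExact throws
// ArithmeticException; Cast reports it as CAST_OVERFLOW.
try {
  (BigDecimal("9223372036854775807") * MICROS_PER_MINUTE)
    .setScale(0, RoundingMode.HALF_UP).toLongExact
} catch {
  case _: ArithmeticException => println("overflow, as the test expects")
}
```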

