Posted to commits@spark.apache.org by ma...@apache.org on 2022/06/16 09:08:00 UTC

[spark] branch master updated: [SPARK-39470][SQL] Support cast of ANSI intervals to decimals

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2332d9c47e [SPARK-39470][SQL] Support cast of ANSI intervals to decimals
d2332d9c47e is described below

commit d2332d9c47e8f250a015d6dc5edb028b334aa905
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Thu Jun 16 12:07:43 2022 +0300

    [SPARK-39470][SQL] Support cast of ANSI intervals to decimals
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to support casts of ANSI intervals to decimals, following the SQL standard:
    <img width="801" alt="Screenshot 2022-06-12 at 13 04 44" src="https://user-images.githubusercontent.com/1580697/173663908-71945980-5638-4b46-9020-4d2e4badef0c.png">
    
    ### Why are the changes needed?
    To improve user experience with Spark SQL, and to conform to the SQL standard.
    
    ### Does this PR introduce _any_ user-facing change?
    No, it just extends the existing behavior of casts: previously failing casts now succeed.
    
    Before:
    ```sql
    spark-sql> SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6));
    Error in query: cannot resolve 'CAST(INTERVAL '01.001002' SECOND AS DECIMAL(10,6))' due to data type mismatch: cannot cast interval second to decimal(10,6); line 1 pos 7;
    'Project [unresolvedalias(cast(INTERVAL '01.001002' SECOND as decimal(10,6)), None)]
    +- OneRowRelation
    ```
    
    After:
    ```sql
    spark-sql> SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6));
    1.001002
    ```
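
    The converted value is then fitted to the target decimal type via `changePrecision` with
    `nullOnOverflow = false`, so an interval-to-decimal cast that overflows the target
    precision throws even when `spark.sql.ansi.enabled` is false (see the non-ANSI
    `cast.sql.out` below). A sketch of that fitting step, assuming HALF_UP rounding
    (consistent with the `10.005` -> `10.01` golden result):
    ```scala
    import java.math.{BigDecimal => JBigDecimal, RoundingMode}

    // Returns None on precision overflow; the real cast then raises
    // CANNOT_CHANGE_DECIMAL_PRECISION instead of producing NULL.
    def fitTo(v: JBigDecimal, precision: Int, scale: Int): Option[JBigDecimal] = {
      val rounded = v.setScale(scale, RoundingMode.HALF_UP)
      if (rounded.precision <= precision) Some(rounded) else None
    }

    println(fitTo(new JBigDecimal("10.123"), 4, 2)) // Some(10.12)
    println(fitTo(new JBigDecimal("10.005"), 4, 2)) // Some(10.01)
    println(fitTo(new JBigDecimal("10.123"), 1, 0)) // None -> error in Spark
    ```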
    
    ### How was this patch tested?
    By running the new tests:
    ```
    $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
    $ build/sbt "test:testOnly *CastWithAnsiOnSuite"
    $ build/sbt "test:testOnly *CastWithAnsiOffSuite"
    ```
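
    The `cast.sql.out` golden files can be regenerated via the standard `SQLQueryTestSuite`
    mechanism (assuming the usual `SPARK_GENERATE_GOLDEN_FILES` switch):
    ```
    $ SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
    ```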
    
    Closes #36857 from MaxGekk/cast-ansi-intervals-to-decimal.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 59 ++++++++++++++++---
 .../spark/sql/catalyst/util/IntervalUtils.scala    |  9 +++
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 33 +++++++++++
 .../src/test/resources/sql-tests/inputs/cast.sql   | 10 ++++
 .../resources/sql-tests/results/ansi/cast.sql.out  | 68 ++++++++++++++++++++++
 .../test/resources/sql-tests/results/cast.sql.out  | 65 +++++++++++++++++++++
 6 files changed, 237 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0746bc0fcd0..45950607e0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
-import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -110,7 +110,7 @@ object Cast {
     case (StringType, _: CalendarIntervalType) => true
     case (StringType, _: AnsiIntervalType) => true
 
-    case (_: AnsiIntervalType, _: IntegralType) => true
+    case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
 
     case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
     case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
@@ -194,8 +194,7 @@ object Cast {
 
     case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
     case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
-    case (_: DayTimeIntervalType, _: IntegralType) => true
-    case (_: YearMonthIntervalType, _: IntegralType) => true
+    case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
 
     case (StringType, _: NumericType) => true
     case (BooleanType, _: NumericType) => true
@@ -967,10 +966,17 @@ case class Cast(
    * NOTE: this modifies `value` in-place, so don't call it on external data.
    */
   private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+    changePrecision(value, decimalType, !ansiEnabled)
+  }
+
+  private[this] def changePrecision(
+      value: Decimal,
+      decimalType: DecimalType,
+      nullOnOverflow: Boolean): Decimal = {
     if (value.changePrecision(decimalType.precision, decimalType.scale)) {
       value
     } else {
-      if (!ansiEnabled) {
+      if (nullOnOverflow) {
         null
       } else {
         throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
@@ -1015,6 +1021,18 @@ case class Cast(
       } catch {
         case _: NumberFormatException => null
       }
+    case x: DayTimeIntervalType =>
+      buildCast[Long](_, dt =>
+        changePrecision(
+          value = dayTimeIntervalToDecimal(dt, x.endField),
+          decimalType = target,
+          nullOnOverflow = false))
+    case x: YearMonthIntervalType =>
+      buildCast[Int](_, ym =>
+        changePrecision(
+          value = Decimal(yearMonthIntervalToInt(ym, x.startField, x.endField)),
+          decimalType = target,
+          nullOnOverflow = false))
   }
 
   // DoubleConverter
@@ -1515,14 +1533,15 @@ case class Cast(
       evPrim: ExprValue,
       evNull: ExprValue,
       canNullSafeCast: Boolean,
-      ctx: CodegenContext): Block = {
+      ctx: CodegenContext,
+      nullOnOverflow: Boolean): Block = {
     if (canNullSafeCast) {
       code"""
          |$d.changePrecision(${decimalType.precision}, ${decimalType.scale});
          |$evPrim = $d;
        """.stripMargin
     } else {
-      val overflowCode = if (!ansiEnabled) {
+      val overflowCode = if (nullOnOverflow) {
         s"$evNull = true;"
       } else {
         s"""
@@ -1540,6 +1559,16 @@ case class Cast(
     }
   }
 
+  private[this] def changePrecision(
+      d: ExprValue,
+      decimalType: DecimalType,
+      evPrim: ExprValue,
+      evNull: ExprValue,
+      canNullSafeCast: Boolean,
+      ctx: CodegenContext): Block = {
+    changePrecision(d, decimalType, evPrim, evNull, canNullSafeCast, ctx, !ansiEnabled)
+  }
+
   private[this] def castToDecimalCode(
       from: DataType,
       target: DecimalType,
@@ -1605,6 +1634,22 @@ case class Cast(
               $evNull = true;
             }
           """
+      case x: DayTimeIntervalType =>
+        (c, evPrim, evNull) =>
+          val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+          code"""
+            Decimal $tmp = $u.dayTimeIntervalToDecimal($c, (byte)${x.endField});
+            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
+          """
+      case x: YearMonthIntervalType =>
+        (c, evPrim, evNull) =>
+          val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+          val tmpYm = ctx.freshVariable("tmpYm", classOf[Int])
+          code"""
+            int $tmpYm = $u.yearMonthIntervalToInt($c, (byte)${x.startField}, (byte)${x.endField});
+            Decimal $tmp = Decimal.apply($tmpYm);
+            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
+          """
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index dad58b7ae45..721f50208ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -1346,6 +1346,15 @@ object IntervalUtils {
     }
   }
 
+  def dayTimeIntervalToDecimal(v: Long, endField: Byte): Decimal = {
+    endField match {
+      case DAY => Decimal(v / MICROS_PER_DAY)
+      case HOUR => Decimal(v / MICROS_PER_HOUR)
+      case MINUTE => Decimal(v / MICROS_PER_MINUTE)
+      case SECOND => Decimal(v, Decimal.MAX_LONG_DIGITS, 6)
+    }
+  }
+
   def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
     val vLong = dayTimeIntervalToLong(v, startField, endField)
     val vInt = vLong.toInt
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ca492e11226..97cbc781829 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -1272,4 +1272,37 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
           "to restore the behavior before Spark 3.0."))
       }
   }
+
+  test("cast ANSI intervals to decimals") {
+    Seq(
+      (Duration.ZERO, DayTimeIntervalType(DAY), DecimalType(10, 3)) -> Decimal(0, 10, 3),
+      (Duration.ofHours(-1), DayTimeIntervalType(HOUR), DecimalType(10, 1)) -> Decimal(-10, 10, 1),
+      (Duration.ofMinutes(1), DayTimeIntervalType(MINUTE), DecimalType(8, 2)) -> Decimal(100, 8, 2),
+      (Duration.ofSeconds(59), DayTimeIntervalType(SECOND), DecimalType(6, 0)) -> Decimal(59, 6, 0),
+      (Duration.ofSeconds(-60).minusMillis(1), DayTimeIntervalType(SECOND),
+        DecimalType(10, 3)) -> Decimal(-60.001, 10, 3),
+      (Duration.ZERO, DayTimeIntervalType(DAY, SECOND), DecimalType(10, 6)) -> Decimal(0, 10, 6),
+      (Duration.ofHours(-23).minusMinutes(59).minusSeconds(59).minusNanos(123456000),
+        DayTimeIntervalType(HOUR, SECOND), DecimalType(18, 6)) -> Decimal(-86399.123456, 18, 6),
+      (Period.ZERO, YearMonthIntervalType(YEAR), DecimalType(5, 2)) -> Decimal(0, 5, 2),
+      (Period.ofMonths(-1), YearMonthIntervalType(MONTH),
+        DecimalType(8, 0)) -> Decimal(-1, 8, 0),
+      (Period.ofYears(-1).minusMonths(1), YearMonthIntervalType(YEAR, MONTH),
+        DecimalType(8, 3)) -> Decimal(-13000, 8, 3)
+    ).foreach { case ((duration, intervalType, targetType), expected) =>
+      checkEvaluation(
+        Cast(Literal.create(duration, intervalType), targetType),
+        expected)
+    }
+
+    dayTimeIntervalTypes.foreach { it =>
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+        Cast(child, DecimalType.USER_DEFAULT), it)
+    }
+
+    yearMonthIntervalTypes.foreach { it =>
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+        Cast(child, DecimalType.USER_DEFAULT), it)
+    }
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index 5198611a2b3..66a78ec9473 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -116,3 +116,13 @@ select cast(interval '10' day as bigint);
 
 select cast(interval '-1000' month as tinyint);
 select cast(interval '1000000' second as smallint);
+
+-- cast ANSI intervals to decimals
+select cast(interval '-1' year as decimal(10, 0));
+select cast(interval '1.000001' second as decimal(10, 6));
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4));
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1));
+select cast(interval '10.123' second as decimal(4, 2));
+select cast(interval '10.005' second as decimal(4, 2));
+select cast(interval '10.123' second as decimal(5, 2));
+select cast(interval '10.123' second as decimal(1, 0));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
index 11753f2b5ca..470a6081c46 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
@@ -838,3 +838,71 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 [CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(interval '-1' year as decimal(10, 0))
+-- !query schema
+struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
+-- !query output
+-1
+
+
+-- !query
+select cast(interval '1.000001' second as decimal(10, 6))
+-- !query schema
+struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
+-- !query output
+1.000001
+
+
+-- !query
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
+-- !query schema
+struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
+-- !query output
+29470.0010
+
+
+-- !query
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
+-- !query schema
+struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
+-- !query output
+90123.1
+
+
+-- !query
+select cast(interval '10.123' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.005' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.01
+
+
+-- !query
+select cast(interval '10.123' second as decimal(5, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.123' second as decimal(1, 0))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+== SQL(line 1, position 8) ==
+select cast(interval '10.123' second as decimal(1, 0))
+       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
index 9c00e1b985e..911eaff30b9 100644
--- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -666,3 +666,68 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 [CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(interval '-1' year as decimal(10, 0))
+-- !query schema
+struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
+-- !query output
+-1
+
+
+-- !query
+select cast(interval '1.000001' second as decimal(10, 6))
+-- !query schema
+struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
+-- !query output
+1.000001
+
+
+-- !query
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
+-- !query schema
+struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
+-- !query output
+29470.0010
+
+
+-- !query
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
+-- !query schema
+struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
+-- !query output
+90123.1
+
+
+-- !query
+select cast(interval '10.123' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.005' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.01
+
+
+-- !query
+select cast(interval '10.123' second as decimal(5, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.123' second as decimal(1, 0))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.

