Posted to commits@spark.apache.org by ma...@apache.org on 2021/07/08 06:42:36 UTC

[spark] branch branch-3.2 updated: [SPARK-36022][SQL] Respect interval fields in extract

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 429d178  [SPARK-36022][SQL] Respect interval fields in extract
429d178 is described below

commit 429d1780b3ab3267a5e22a857cf51458713bc208
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Thu Jul 8 09:40:57 2021 +0300

    [SPARK-36022][SQL] Respect interval fields in extract
    
    ### What changes were proposed in this pull request?
    
    This PR fixes an issue with `extract`.
    `extract` should process only the fields that exist in the given interval type. For example:
    
    ```
    spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH);
    11
    spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021' YEAR);
    0
    ```
    The last command should fail because the MONTH field is not present in INTERVAL YEAR.
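    
    The fix guards each `extract` pattern with a check that the requested unit lies
    within the interval type's `[startField, endField]` range. The check itself is
    the `isUnitInIntervalRange` helper added by this patch; the surrounding object,
    constants, and `main` below are illustrative only, a minimal standalone sketch:
    
    ```
    object ExtractGuardSketch {
      // Field positions mirror Spark's YearMonthIntervalType: YEAR = 0, MONTH = 1.
      val YEAR: Byte = 0
      val MONTH: Byte = 1
    
      // A unit is extractable only if it lies within [start, end].
      def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean =
        start <= unit && unit <= end
    
      def main(args: Array[String]): Unit = {
        // INTERVAL '2021-11' YEAR TO MONTH spans YEAR..MONTH: MONTH is extractable.
        assert(isUnitInIntervalRange(MONTH, start = YEAR, end = MONTH))
        // INTERVAL '2021' YEAR spans only YEAR: MONTH is out of range.
        assert(!isUnitInIntervalRange(MONTH, start = YEAR, end = YEAR))
      }
    }
    ```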
    
    ### Why are the changes needed?
    
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests.
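    
    Beyond the new suite, the behavior change can be checked interactively. A
    sketch, assuming a `spark-shell` session (with its implicit `spark`
    SparkSession) on a build that includes this patch:
    
    ```
    import org.apache.spark.sql.AnalysisException
    
    // MONTH lies within YEAR TO MONTH, so extraction still succeeds (returns 11).
    spark.sql("SELECT EXTRACT(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH)").show()
    
    // MONTH is not a field of INTERVAL YEAR: analysis now fails instead of
    // silently returning 0.
    try {
      spark.sql("SELECT EXTRACT(MONTH FROM INTERVAL '2021' YEAR)")
    } catch {
      case e: AnalysisException => println(s"Expected failure: ${e.getMessage}")
    }
    ```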
    
    Closes #33247 from sarutak/fix-extract-interval.
    
    Authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
    (cherry picked from commit 39002cb99514010f6d6cc2e575b9eab1694f04ef)
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../catalyst/expressions/intervalExpressions.scala | 24 ++++++--
 .../apache/spark/sql/IntervalFunctionsSuite.scala  | 64 ++++++++++++++++++++++
 2 files changed, 82 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 5d49007..5b111d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
+import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 abstract class ExtractIntervalPart[T](
@@ -125,33 +127,43 @@ object ExtractIntervalPart {
       source: Expression,
       errorHandleFunc: => Nothing): Expression = {
     (extractField.toUpperCase(Locale.ROOT), source.dataType) match {
-      case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: YearMonthIntervalType) =>
+      case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end))
+        if isUnitInIntervalRange(YEAR, start, end) =>
         ExtractANSIIntervalYears(source)
       case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
         ExtractIntervalYears(source)
-      case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) =>
+      case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType(start, end))
+        if isUnitInIntervalRange(MONTH, start, end) =>
         ExtractANSIIntervalMonths(source)
       case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
         ExtractIntervalMonths(source)
-      case ("DAY" | "D" | "DAYS", _: DayTimeIntervalType) =>
+      case ("DAY" | "D" | "DAYS", DayTimeIntervalType(start, end))
+        if isUnitInIntervalRange(DAY, start, end) =>
         ExtractANSIIntervalDays(source)
       case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
         ExtractIntervalDays(source)
-      case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", _: DayTimeIntervalType) =>
+      case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType(start, end))
+        if isUnitInIntervalRange(HOUR, start, end) =>
         ExtractANSIIntervalHours(source)
       case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
         ExtractIntervalHours(source)
-      case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", _: DayTimeIntervalType) =>
+      case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType(start, end))
+        if isUnitInIntervalRange(MINUTE, start, end) =>
         ExtractANSIIntervalMinutes(source)
       case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
         ExtractIntervalMinutes(source)
-      case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", _: DayTimeIntervalType) =>
+      case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType(start, end))
+        if isUnitInIntervalRange(SECOND, start, end) =>
         ExtractANSIIntervalSeconds(source)
       case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
         ExtractIntervalSeconds(source)
       case _ => errorHandleFunc
     }
   }
+
+  private def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean = {
+    start <= unit && unit <= end
+  }
 }
 
 abstract class IntervalNumOperation(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala
new file mode 100644
index 0000000..c7e307b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
+import org.apache.spark.sql.types.DataTypeTestUtils._
+
+class IntervalFunctionsSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("SPARK-36022: Respect interval fields in extract") {
+    yearMonthIntervalTypes.foreach { dtype =>
+      val ymDF = Seq(Period.of(1, 2, 0)).toDF
+        .select($"value" cast dtype as "value")
+      val expectedMap = Map("year" -> 1, "month" -> 2)
+      YM.yearMonthFields.foreach { field =>
+        val extractUnit = YM.fieldToString(field)
+        val extractExpr = s"extract($extractUnit FROM value)"
+        if (dtype.startField <= field && field <= dtype.endField) {
+          checkAnswer(ymDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
+        } else {
+          intercept[AnalysisException] {
+            ymDF.selectExpr(extractExpr)
+          }
+        }
+      }
+    }
+
+    dayTimeIntervalTypes.foreach { dtype =>
+      val dtDF = Seq(Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)).toDF
+        .select($"value" cast dtype as "value")
+      val expectedMap = Map("day" -> 1, "hour" -> 2, "minute" -> 3, "second" -> 4)
+      DT.dayTimeFields.foreach { field =>
+        val extractUnit = DT.fieldToString(field)
+        val extractExpr = s"extract($extractUnit FROM value)"
+        if (dtype.startField <= field && field <= dtype.endField) {
+          checkAnswer(dtDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
+        } else {
+          intercept[AnalysisException] {
+            dtDF.selectExpr(extractExpr)
+          }
+        }
+      }
+    }
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org