You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2021/07/08 06:42:36 UTC
[spark] branch branch-3.2 updated: [SPARK-36022][SQL] Respect
interval fields in extract
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 429d178 [SPARK-36022][SQL] Respect interval fields in extract
429d178 is described below
commit 429d1780b3ab3267a5e22a857cf51458713bc208
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Thu Jul 8 09:40:57 2021 +0300
[SPARK-36022][SQL] Respect interval fields in extract
### What changes were proposed in this pull request?
This PR fixes an issue with `extract`.
`Extract` should process only existing fields of interval types. For example:
```
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH);
11
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021' YEAR);
0
```
The last command should fail because the month field is not present in INTERVAL YEAR.
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests.
Closes #33247 from sarutak/fix-extract-interval.
Authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
(cherry picked from commit 39002cb99514010f6d6cc2e575b9eab1694f04ef)
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../catalyst/expressions/intervalExpressions.scala | 24 ++++++--
.../apache/spark/sql/IntervalFunctionsSuite.scala | 64 ++++++++++++++++++++++
2 files changed, 82 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 5d49007..5b111d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
+import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.CalendarInterval
abstract class ExtractIntervalPart[T](
@@ -125,33 +127,43 @@ object ExtractIntervalPart {
source: Expression,
errorHandleFunc: => Nothing): Expression = {
(extractField.toUpperCase(Locale.ROOT), source.dataType) match {
- case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: YearMonthIntervalType) =>
+ case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end))
+ if isUnitInIntervalRange(YEAR, start, end) =>
ExtractANSIIntervalYears(source)
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
ExtractIntervalYears(source)
- case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) =>
+ case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType(start, end))
+ if isUnitInIntervalRange(MONTH, start, end) =>
ExtractANSIIntervalMonths(source)
case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
ExtractIntervalMonths(source)
- case ("DAY" | "D" | "DAYS", _: DayTimeIntervalType) =>
+ case ("DAY" | "D" | "DAYS", DayTimeIntervalType(start, end))
+ if isUnitInIntervalRange(DAY, start, end) =>
ExtractANSIIntervalDays(source)
case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
ExtractIntervalDays(source)
- case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", _: DayTimeIntervalType) =>
+ case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType(start, end))
+ if isUnitInIntervalRange(HOUR, start, end) =>
ExtractANSIIntervalHours(source)
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
ExtractIntervalHours(source)
- case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", _: DayTimeIntervalType) =>
+ case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType(start, end))
+ if isUnitInIntervalRange(MINUTE, start, end) =>
ExtractANSIIntervalMinutes(source)
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
ExtractIntervalMinutes(source)
- case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", _: DayTimeIntervalType) =>
+ case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType(start, end))
+ if isUnitInIntervalRange(SECOND, start, end) =>
ExtractANSIIntervalSeconds(source)
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
ExtractIntervalSeconds(source)
case _ => errorHandleFunc
}
}
+
+ /**
+ * Returns true when `unit` lies within the inclusive range [`start`, `end`].
+ * The pattern guards in this object call it so that an ANSI interval source
+ * only matches an extract field that falls between its start and end fields;
+ * otherwise the match falls through to the error handler.
+ */
+ private def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean = {
+ start <= unit && unit <= end
+ }
}
abstract class IntervalNumOperation(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala
new file mode 100644
index 0000000..c7e307b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
+import org.apache.spark.sql.types.DataTypeTestUtils._
+
+/**
+ * Verifies that `extract` only accepts fields that are part of the ANSI
+ * interval type it is applied to.
+ */
+class IntervalFunctionsSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("SPARK-36022: Respect interval fields in extract") {
+    // Year-month intervals: for each type, a field inside [startField, endField]
+    // must be extracted with the expected value; any other field must fail analysis.
+    yearMonthIntervalTypes.foreach { dtype =>
+      // A single cast to the type under test is enough; casting twice to the
+      // same type adds nothing and differed from the day-time case below.
+      val ymDF = Seq(Period.of(1, 2, 0)).toDF
+        .select($"value" cast dtype as "value")
+      val expectedMap = Map("year" -> 1, "month" -> 2)
+      YM.yearMonthFields.foreach { field =>
+        val extractUnit = YM.fieldToString(field)
+        val extractExpr = s"extract($extractUnit FROM value)"
+        if (dtype.startField <= field && field <= dtype.endField) {
+          checkAnswer(ymDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
+        } else {
+          intercept[AnalysisException] {
+            ymDF.selectExpr(extractExpr)
+          }
+        }
+      }
+    }
+
+    // Day-time intervals: same contract as above for day, hour, minute and second.
+    dayTimeIntervalTypes.foreach { dtype =>
+      val dtDF = Seq(Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)).toDF
+        .select($"value" cast dtype as "value")
+      val expectedMap = Map("day" -> 1, "hour" -> 2, "minute" -> 3, "second" -> 4)
+      DT.dayTimeFields.foreach { field =>
+        val extractUnit = DT.fieldToString(field)
+        val extractExpr = s"extract($extractUnit FROM value)"
+        if (dtype.startField <= field && field <= dtype.endField) {
+          checkAnswer(dtDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
+        } else {
+          intercept[AnalysisException] {
+            dtDF.selectExpr(extractExpr)
+          }
+        }
+      }
+    }
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org