Posted to commits@spark.apache.org by ma...@apache.org on 2021/10/09 13:56:41 UTC

[spark] branch master updated: [SPARK-36960][SQL] Pushdown filters with ANSI interval values to ORC

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ebfc6bb  [SPARK-36960][SQL] Pushdown filters with ANSI interval values to ORC
ebfc6bb is described below

commit ebfc6bbe0e9200f87ebb52fb71d009b2d71b956d
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Sat Oct 9 16:55:59 2021 +0300

    [SPARK-36960][SQL] Pushdown filters with ANSI interval values to ORC
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to push down filters with ANSI intervals to ORC.
    
    ### Why are the changes needed?
    
    After SPARK-36931 (#34184), the V1 and V2 ORC datasources support ANSI intervals, so it is worthwhile to also push down filters with ANSI interval values for better performance.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests.
    
    Closes #34224 from sarutak/orc-ansi-interval-pushdown.
    
    Lead-authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
    Co-authored-by: Kousuke Saruta <sa...@oss.nttdata.co.jp>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |  4 +-
 .../sql/execution/datasources/orc/OrcFilters.scala | 10 ++-
 .../execution/datasources/orc/OrcFilterSuite.scala | 97 ++++++++++++++++++++++
 3 files changed, 108 insertions(+), 3 deletions(-)

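For context, here is a minimal sketch (not part of this patch) of the kind of query the change targets, assuming a Spark 3.3+ shell where spark is the active session, the default spark.sql.orc.filterPushdown=true, and a made-up path and column name:

    import java.time.Duration
    import org.apache.spark.sql.functions.{col, lit}

    // ANSI interval columns can be written to ORC since SPARK-36931.
    val path = "/tmp/interval_orc"   // hypothetical location
    val df = spark.createDataFrame((1 to 3).map(i => Tuple1(Duration.ofDays(i)))).toDF("d")
    df.write.mode("overwrite").orc(path)

    // With this patch, the comparison against the Duration literal can be compiled
    // into an ORC SearchArgument and evaluated inside the ORC reader instead of
    // only as a post-scan filter in Spark.
    spark.read.orc(path).filter(col("d") > lit(Duration.ofDays(1))).show()
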
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 4a97a8d..979c280 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 
 import scala.language.implicitConversions
 
@@ -167,6 +167,8 @@ package object dsl {
     implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t)
     implicit def instantToLiteral(i: Instant): Literal = Literal(i)
     implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a)
+    implicit def periodToLiteral(p: Period): Literal = Literal(p)
+    implicit def durationToLiteral(d: Duration): Literal = Literal(d)
 
     implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute =
       analysis.UnresolvedAttribute(s.name)
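
The two implicits above let java.time.Period and java.time.Duration values be used directly as interval literals in the Catalyst expression DSL, which the new ORC filter tests exercise. A small sketch of the effect (attribute names are made up):

    import java.time.{Duration, Period}
    import org.apache.spark.sql.catalyst.dsl.expressions._

    // Period/Duration arguments are implicitly wrapped in Literal, producing
    // YearMonthIntervalType and DayTimeIntervalType literals respectively.
    val ymPred = Symbol("ym") > Period.ofMonths(13)
    val dtPred = Symbol("dt") <= Duration.ofHours(36)
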
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 5abfa4c..8e02fc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.orc
 
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 
 import org.apache.hadoop.hive.common.`type`.HiveDecimal
 import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
 
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
+import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
@@ -140,7 +141,8 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    */
   def getPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
-    case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+    case ByteType | ShortType | IntegerType | LongType |
+         _: AnsiIntervalType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
     case StringType => PredicateLeaf.Type.STRING
     case DateType => PredicateLeaf.Type.DATE
@@ -166,6 +168,10 @@ private[sql] object OrcFilters extends OrcFiltersBase {
       toJavaDate(localDateToDays(value.asInstanceOf[LocalDate]))
     case _: TimestampType if value.isInstanceOf[Instant] =>
       toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant]))
+    case _: YearMonthIntervalType =>
+      IntervalUtils.periodToMonths(value.asInstanceOf[Period]).longValue()
+    case _: DayTimeIntervalType =>
+      IntervalUtils.durationToMicros(value.asInstanceOf[Duration])
     case _ => value
   }
 
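
With the change above, ANSI interval columns map to ORC's LONG predicate leaf type, and interval literals in pushed-down filters are converted to the same integral encoding: total months for year-month intervals and total microseconds for day-time intervals. A rough sketch of what the conversion amounts to, assuming the IntervalUtils helpers behave as their names suggest (an illustration, not the code from the patch):

    import java.time.{Duration, Period}

    // Year-month interval -> total months (years * 12 + months), as a Long.
    def yearMonthToLong(p: Period): Long = p.toTotalMonths

    // Day-time interval -> total microseconds (seconds * 1_000_000 + nanos / 1_000).
    def dayTimeToLong(d: Duration): Long =
      Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), (d.getNano / 1000).toLong)
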
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index 681ed91..c53cc10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.math.MathContext
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
 
 import scala.collection.JavaConverters._
 
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
@@ -383,6 +385,101 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-36960: filter pushdown - year-month interval") {
+    DataTypeTestUtils.yearMonthIntervalTypes.foreach { ymIntervalType =>
+
+      def periods(i: Int): Expression = Literal(Period.of(i, i, 0)).cast(ymIntervalType)
+
+      val baseDF = spark.createDataFrame((1 to 4).map { i =>
+        Tuple1.apply(Period.of(i, i, 0))
+      }).select(col("_1").cast(ymIntervalType))
+
+      withNestedOrcDataFrame(baseDF) {
+        case (inputDF, colName, _) =>
+          implicit val df: DataFrame = inputDF
+
+          val ymIntervalAttr = df(colName).expr
+          assert(df(colName).expr.dataType === ymIntervalType)
+
+          checkFilterPredicate(ymIntervalAttr.isNull, PredicateLeaf.Operator.IS_NULL)
+
+          checkFilterPredicate(ymIntervalAttr === periods(1),
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(ymIntervalAttr <=> periods(1),
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(ymIntervalAttr < periods(2),
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(ymIntervalAttr > periods(3),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(ymIntervalAttr <= periods(1),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(ymIntervalAttr >= periods(4),
+            PredicateLeaf.Operator.LESS_THAN)
+
+          checkFilterPredicate(periods(1) === ymIntervalAttr,
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(periods(1) <=> ymIntervalAttr,
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(periods(2) > ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(periods(3) < ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(periods(1) >= ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(periods(4) <= ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+      }
+    }
+  }
+
+  test("SPARK-36960: filter pushdown - day-time interval") {
+    DataTypeTestUtils.dayTimeIntervalTypes.foreach { dtIntervalType =>
+
+      def durations(i: Int): Expression =
+        Literal(Duration.ofDays(i).plusHours(i).plusMinutes(i).plusSeconds(i)).cast(dtIntervalType)
+
+      val baseDF = spark.createDataFrame((1 to 4).map { i =>
+        Tuple1.apply(Duration.ofDays(i).plusHours(i).plusMinutes(i).plusSeconds(i))
+      }).select(col("_1").cast(dtIntervalType))
+
+      withNestedOrcDataFrame(baseDF) {
+        case (inputDF, colName, _) =>
+          implicit val df: DataFrame = inputDF
+
+          val dtIntervalAttr = df(colName).expr
+          assert(df(colName).expr.dataType === dtIntervalType)
+
+          checkFilterPredicate(dtIntervalAttr.isNull, PredicateLeaf.Operator.IS_NULL)
+
+          checkFilterPredicate(dtIntervalAttr === durations(1),
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(dtIntervalAttr <=> durations(1),
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(dtIntervalAttr < durations(2),
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(dtIntervalAttr > durations(3),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(dtIntervalAttr <= durations(1),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(dtIntervalAttr >= durations(4),
+            PredicateLeaf.Operator.LESS_THAN)
+
+          checkFilterPredicate(durations(1) === dtIntervalAttr,
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(durations(1) <=> dtIntervalAttr,
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(durations(2) > dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(durations(3) < dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(durations(1) >= dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(durations(4) <= dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+      }
+    }
+  }
+
   test("no filter pushdown - non-supported types") {
     implicit class IntToBinary(int: Int) {
       def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org