Posted to commits@spark.apache.org by ma...@apache.org on 2021/10/30 17:04:35 UTC

[spark] branch master updated: [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 08123a3  [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile
08123a3 is described below

commit 08123a3795683238352e5bf55452de381349fdd9
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Sat Oct 30 20:03:20 2021 +0300

    [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile
    
    ### What changes were proposed in this pull request?
    
    Support ANSI interval types in the following aggregate expressions (see the usage examples after the list):
    - ApproxCountDistinctForIntervals
    - ApproximatePercentile
    - Percentile
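    
    For reference, the usage examples added to the built-in function docs in this patch show the new behavior (actual function names substituted for the `_FUNC_` placeholder). Note that `percentile_approx` returns an interval value, while the exact `percentile` returns the interval's underlying numeric value (months or microseconds) as a double:
    
      > SELECT percentile_approx(col, 0.5, 100) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '1' MONTH), (INTERVAL '2' MONTH), (INTERVAL '10' MONTH) AS tab(col);
       0-1
      > SELECT percentile(col, 0.5) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '10' MONTH) AS tab(col);
       5.0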
    
    ### Why are the changes needed?
    To improve user experience with Spark SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added new UT.
    
    Closes #34412 from AngersZhuuuu/SPARK-37138.
    
    Authored-by: Angerszhuuuu <an...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../ApproxCountDistinctForIntervals.scala          | 13 +++---
 .../aggregate/ApproximatePercentile.scala          | 32 ++++++++------
 .../expressions/aggregate/Percentile.scala         | 26 +++++++++---
 .../ApproxCountDistinctForIntervalsSuite.scala     |  6 ++-
 .../expressions/aggregate/PercentileSuite.scala    |  8 ++--
 ...ApproxCountDistinctForIntervalsQuerySuite.scala | 28 +++++++++++++
 .../sql/ApproximatePercentileQuerySuite.scala      | 22 +++++++++-
 .../apache/spark/sql/PercentileQuerySuite.scala    | 49 ++++++++++++++++++++++
 8 files changed, 153 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index a7e9a22..f3bf251 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -61,7 +61,8 @@ case class ApproxCountDistinctForIntervals(
   }
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType), ArrayType)
+    Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType,
+      YearMonthIntervalType, DayTimeIntervalType), ArrayType)
   }
 
   // Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
@@ -79,14 +80,16 @@ case class ApproxCountDistinctForIntervals(
       TypeCheckFailure("The endpoints provided must be constant literals")
     } else {
       endpointsExpression.dataType match {
-        case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType, _) =>
+        case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
+           _: AnsiIntervalType, _) =>
           if (endpoints.length < 2) {
             TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
           } else {
             TypeCheckSuccess
           }
         case _ =>
-          TypeCheckFailure("Endpoints require (numeric or timestamp or date) type")
+          TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
+            "interval year to month or interval day to second) type")
       }
     }
   }
@@ -120,9 +123,9 @@ case class ApproxCountDistinctForIntervals(
       val doubleValue = child.dataType match {
         case n: NumericType =>
           n.numeric.toDouble(value.asInstanceOf[n.InternalType])
-        case _: DateType =>
+        case _: DateType | _: YearMonthIntervalType =>
           value.asInstanceOf[Int].toDouble
-        case TimestampType | TimestampNTZType =>
+        case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
           value.asInstanceOf[Long].toDouble
       }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 8cce79c..0dcb906 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -49,15 +49,16 @@ import org.apache.spark.sql.types._
  *                           yields better accuracy, the default value is
  *                           DEFAULT_PERCENTILE_ACCURACY.
  */
+// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric
-      column `col` which is the smallest value in the ordered `col` values (sorted from least to
-      greatest) such that no more than `percentage` of `col` values is less than the value
-      or equal to that value. The value of percentage must be between 0.0 and 1.0. The `accuracy`
-      parameter (default: 10000) is a positive numeric literal which controls approximation accuracy
-      at the cost of memory. Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is
-      the relative error of the approximation.
+    _FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric or
+      ansi interval column `col` which is the smallest value in the ordered `col` values (sorted
+      from least to greatest) such that no more than `percentage` of `col` values is less than
+      the value or equal to that value. The value of percentage must be between 0.0 and 1.0.
+      The `accuracy` parameter (default: 10000) is a positive numeric literal which controls
+      approximation accuracy at the cost of memory. Higher value of `accuracy` yields better
+      accuracy, `1.0/accuracy` is the relative error of the approximation.
       When `percentage` is an array, each value of the percentage array must be between 0.0 and 1.0.
       In this case, returns the approximate percentile array of column `col` at the given
       percentage array.
@@ -68,9 +69,14 @@ import org.apache.spark.sql.types._
        [1,1,0]
       > SELECT _FUNC_(col, 0.5, 100) FROM VALUES (0), (6), (7), (9), (10) AS tab(col);
        7
+      > SELECT _FUNC_(col, 0.5, 100) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '1' MONTH), (INTERVAL '2' MONTH), (INTERVAL '10' MONTH) AS tab(col);
+       0-1
+      > SELECT _FUNC_(col, array(0.5, 0.7), 100) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '1' SECOND), (INTERVAL '2' SECOND), (INTERVAL '10' SECOND) AS tab(col);
+       [0 00:00:01.000000000,0 00:00:02.000000000]
   """,
   group = "agg_funcs",
   since = "2.1.0")
+// scalastyle:on line.size.limit
 case class ApproximatePercentile(
     child: Expression,
     percentageExpression: Expression,
@@ -94,7 +100,8 @@ case class ApproximatePercentile(
   override def inputTypes: Seq[AbstractDataType] = {
     // Support NumericType, DateType, TimestampType and TimestampNTZType since their internal types
     // are all numeric, and can be easily cast to double for processing.
-    Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType),
+    Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType,
+      YearMonthIntervalType, DayTimeIntervalType),
       TypeCollection(DoubleType, ArrayType(DoubleType, containsNull = false)), IntegralType)
   }
 
@@ -138,8 +145,9 @@ case class ApproximatePercentile(
     if (value != null) {
       // Convert the value to a double value
       val doubleValue = child.dataType match {
-        case DateType => value.asInstanceOf[Int].toDouble
-        case TimestampType | TimestampNTZType => value.asInstanceOf[Long].toDouble
+        case DateType | _: YearMonthIntervalType => value.asInstanceOf[Int].toDouble
+        case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
+          value.asInstanceOf[Long].toDouble
         case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
         case other: DataType =>
           throw QueryExecutionErrors.dataTypeUnexpectedError(other)
@@ -157,8 +165,8 @@ case class ApproximatePercentile(
   override def eval(buffer: PercentileDigest): Any = {
     val doubleResult = buffer.getPercentiles(percentages)
     val result = child.dataType match {
-      case DateType => doubleResult.map(_.toInt)
-      case TimestampType | TimestampNTZType => doubleResult.map(_.toLong)
+      case DateType | _: YearMonthIntervalType => doubleResult.map(_.toInt)
+      case TimestampType | TimestampNTZType | _: DayTimeIntervalType => doubleResult.map(_.toLong)
       case ByteType => doubleResult.map(_.toByte)
       case ShortType => doubleResult.map(_.toShort)
       case IntegerType => doubleResult.map(_.toInt)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 5bce4d3..7d3dd0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -43,12 +43,13 @@ import org.apache.spark.util.collection.OpenHashMap
  *                             percentage values. Each percentage value must be in the range
  *                             [0.0, 1.0].
  */
+// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage =
     """
-      _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
-       `col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
-       value of frequency should be positive integral
+      _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric
+       or ansi interval column `col` at the given percentage. The value of percentage must be
+       between 0.0 and 1.0. The value of frequency should be positive integral
 
       _FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
       percentile value array of numeric column `col` at the given percentage(s). Each value
@@ -62,9 +63,14 @@ import org.apache.spark.util.collection.OpenHashMap
        3.0
       > SELECT _FUNC_(col, array(0.25, 0.75)) FROM VALUES (0), (10) AS tab(col);
        [2.5,7.5]
+      > SELECT _FUNC_(col, 0.5) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '10' MONTH) AS tab(col);
+       5.0
+      > SELECT _FUNC_(col, array(0.2, 0.5)) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '10' SECOND) AS tab(col);
+       [2000000.0,5000000.0]
   """,
   group = "agg_funcs",
   since = "2.1.0")
+// scalastyle:on line.size.limit
 case class Percentile(
     child: Expression,
     percentageExpression: Expression,
@@ -118,7 +124,8 @@ case class Percentile(
       case _: ArrayType => ArrayType(DoubleType, false)
       case _ => DoubleType
     }
-    Seq(NumericType, percentageExpType, IntegralType)
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType),
+      percentageExpType, IntegralType)
   }
 
   // Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -191,8 +198,15 @@ case class Percentile(
       return Seq.empty
     }
 
-    val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
+    val ordering =
+      if (child.dataType.isInstanceOf[NumericType]) {
+        child.dataType.asInstanceOf[NumericType].ordering
+      } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
+        child.dataType.asInstanceOf[YearMonthIntervalType].ordering
+      } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
+        child.dataType.asInstanceOf[DayTimeIntervalType].ordering
+      }
+    val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
     val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index 9d53673..a017e5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -39,7 +39,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       assert(
         wrongColumn.checkInputDataTypes() match {
           case TypeCheckFailure(msg)
-            if msg.contains("requires (numeric or timestamp or date or timestamp_ntz) type") => true
+            if msg.contains("requires (numeric or timestamp or date or timestamp_ntz or " +
+              "interval year to month or interval day to second) type") => true
           case _ => false
         })
     }
@@ -69,7 +70,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
-        TypeCheckFailure("Endpoints require (numeric or timestamp or date) type"))
+      TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
+        "interval year to month or interval day to second) type"))
   }
 
   /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index fa87407..b5882b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -170,8 +170,8 @@ class PercentileSuite extends SparkFunSuite {
       val child = AttributeReference("a", dataType)()
       val percentile = new Percentile(child, percentage)
       assertEqual(percentile.checkInputDataTypes(),
-        TypeCheckFailure(s"argument 1 requires numeric type, however, " +
-            s"'a' is of ${dataType.simpleString} type."))
+        TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
+          s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
     }
 
     val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
@@ -184,8 +184,8 @@ class PercentileSuite extends SparkFunSuite {
       val frq = AttributeReference("frq", frequencyType)()
       val percentile = new Percentile(child, percentage, frq)
       assertEqual(percentile.checkInputDataTypes(),
-        TypeCheckFailure(s"argument 1 requires numeric type, however, " +
-            s"'a' is of ${dataType.simpleString} type."))
+        TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
+          s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
     }
 
     for(dataType <- validDataTypes;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
index 171e93c..53662c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.time.{Duration, Period}
+
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -58,4 +60,30 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa
       }
     }
   }
+
+  test("SPARK-37138: Support Ansi Interval type in ApproxCountDistinctForIntervals") {
+    val table = "approx_count_distinct_for_ansi_intervals_tbl"
+    withTable(table) {
+      Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+        (Period.ofMonths(200), Duration.ofSeconds(200L)),
+        (Period.ofMonths(300), Duration.ofSeconds(300L)))
+        .toDF("col1", "col2").createOrReplaceTempView(table)
+      val endpoints = (0 to 5).map(_ / 10)
+
+      val relation = spark.table(table).logicalPlan
+      val ymAttr = relation.output.find(_.name == "col1").get
+      val ymAggFunc =
+        ApproxCountDistinctForIntervals(ymAttr, CreateArray(endpoints.map(Literal(_))))
+      val ymAggExpr = ymAggFunc.toAggregateExpression()
+      val ymNamedExpr = Alias(ymAggExpr, ymAggExpr.toString)()
+
+      val dtAttr = relation.output.find(_.name == "col2").get
+      val dtAggFunc =
+        ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_))))
+      val dtAggExpr = dtAggFunc.toAggregateExpression()
+      val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)()
+      val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation))
+      checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1)))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 5ff15c9..9237c9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
-import java.time.LocalDateTime
+import java.time.{Duration, LocalDateTime, Period}
 
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
@@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
-  private val table = "percentile_test"
+  private val table = "percentile_approx"
 
   test("percentile_approx, single percentile value") {
     withTempView(table) {
@@ -319,4 +319,22 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
         Row(18, 17, 17, 17))
     }
   }
+
+  test("SPARK-37138: Support Ansi Interval type in ApproximatePercentile") {
+    withTempView(table) {
+      Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+        (Period.ofMonths(200), Duration.ofSeconds(200L)),
+        (Period.ofMonths(300), Duration.ofSeconds(300L)))
+        .toDF("col1", "col2").createOrReplaceTempView(table)
+        checkAnswer(
+          spark.sql(
+            s"""SELECT
+               |  percentile_approx(col1, 0.5),
+               |  SUM(null),
+               |  percentile_approx(col2, 0.5)
+               |FROM $table
+           """.stripMargin),
+          Row(Period.ofMonths(200).normalized(), null, Duration.ofSeconds(200L)))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
new file mode 100644
index 0000000..f39f0c1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * End-to-end tests for percentile aggregate function.
+ */
+class PercentileQuerySuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  private val table = "percentile_test"
+
+  test("SPARK-37138: Support Ansi Interval type in Percentile") {
+    withTempView(table) {
+      Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+        (Period.ofMonths(200), Duration.ofSeconds(200L)),
+        (Period.ofMonths(300), Duration.ofSeconds(300L)))
+        .toDF("col1", "col2").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  CAST(percentile(col1, 0.5) AS STRING),
+             |  SUM(null),
+             |  CAST(percentile(col2, 0.5) AS STRING)
+             |FROM $table
+           """.stripMargin),
+        Row("200.0", null, "2.0E8"))
+    }
+  }
+}
