You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2021/10/15 14:17:19 UTC

[spark] branch master updated: [SPARK-35926][SQL] Add support YearMonthIntervalType for width_bucket

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d061e3  [SPARK-35926][SQL] Add support YearMonthIntervalType for width_bucket
9d061e3 is described below

commit 9d061e3939a021c602c070fc13cef951a8f94c82
Author: PengLei <pe...@gmail.com>
AuthorDate: Fri Oct 15 17:15:50 2021 +0300

    [SPARK-35926][SQL] Add support YearMonthIntervalType for width_bucket
    
    ### What changes were proposed in this pull request?
    Support width_bucket(YearMonthIntervalType, YearMonthIntervalType, YearMonthIntervalType, Long), which returns a Long result.
     e.g.:
    ```
    width_bucket(input_value, min_value, max_value, bucket_nums)
    width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)
    It divides the range between min_value and max_value into 10 buckets:
    [ INTERVAL '0' YEAR,  INTERVAL '1' YEAR),  [ INTERVAL '1' YEAR,  INTERVAL '2' YEAR)......  [INTERVAL '9' YEAR,  INTERVAL '10' YEAR)
    Then it calculates which bucket the given input_value falls into.
    ```
    
    The function `width_bucket` was introduced in [SPARK-21117](https://issues.apache.org/jira/browse/SPARK-21117)
    ### Why are the changes needed?
    [35926](https://issues.apache.org/jira/browse/SPARK-35926)
    1. The `WIDTH_BUCKET` function assigns values to buckets (individual segments) in an equiwidth histogram. The ANSI SQL standard syntax is as follows: `WIDTH_BUCKET( expression, min, max, buckets)`. [Reference](https://www.oreilly.com/library/view/sql-in-a/9780596155322/re91.html).
    2. `WIDTH_BUCKET` currently supports only `Double`. Of course, we can cast `Int` to `Double` to use it, but we cannot cast `YearMonthIntervalType` to `Double`.
    3. I think there is a real use case for it, e.g. a histogram of employee years of service, where `years of service` is a column of `YearMonthIntervalType` data type.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The user can use `width_bucket` with YearMonthIntervalType.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #33132 from Peng-Lei/SPARK-35926.
    
    Authored-by: PengLei <pe...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../sql/catalyst/expressions/mathExpressions.scala | 33 ++++++++++++++++++----
 .../expressions/MathExpressionsSuite.scala         | 15 ++++++++++
 .../test/resources/sql-tests/inputs/interval.sql   |  2 ++
 .../sql-tests/results/ansi/interval.sql.out        | 18 +++++++++++-
 .../resources/sql-tests/results/interval.sql.out   | 18 +++++++++++-
 .../org/apache/spark/sql/MathFunctionsSuite.scala  | 17 +++++++++++
 6 files changed, 96 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index c14fa72..6c34ed6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.NumberConverter
+import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -1613,6 +1613,10 @@ object WidthBucket {
        5
       > SELECT _FUNC_(-0.9, 5.2, 0.5, 2);
        3
+      > SELECT _FUNC_(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
+       1
+      > SELECT _FUNC_(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
+       2
   """,
   since = "3.1.0",
   group = "math_funcs")
@@ -1623,16 +1627,35 @@ case class WidthBucket(
     numBucket: Expression)
   extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(
+    TypeCollection(DoubleType, YearMonthIntervalType),
+    TypeCollection(DoubleType, YearMonthIntervalType),
+    TypeCollection(DoubleType, YearMonthIntervalType),
+    LongType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case TypeCheckSuccess =>
+        (value.dataType, minValue.dataType, maxValue.dataType) match {
+          case (_: YearMonthIntervalType, _: YearMonthIntervalType, _: YearMonthIntervalType) =>
+            TypeCheckSuccess
+          case _ =>
+            val types = Seq(value.dataType, minValue.dataType, maxValue.dataType)
+            TypeUtils.checkForSameTypeInputExpr(types, s"function $prettyName")
+        }
+      case f => f
+    }
+  }
+
   override def dataType: DataType = LongType
   override def nullable: Boolean = true
   override def prettyName: String = "width_bucket"
 
   override protected def nullSafeEval(input: Any, min: Any, max: Any, numBucket: Any): Any = {
     WidthBucket.computeBucketNumber(
-      input.asInstanceOf[Double],
-      min.asInstanceOf[Double],
-      max.asInstanceOf[Double],
+      input.asInstanceOf[Number].doubleValue(),
+      min.asInstanceOf[Number].doubleValue(),
+      max.asInstanceOf[Number].doubleValue(),
       numBucket.asInstanceOf[Long])
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index aced787..bfb9614 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -725,4 +725,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Signum(Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS))), 1.0)
     checkEvaluation(Signum(Literal(Duration.of(Long.MinValue, ChronoUnit.MICROS))), -1.0)
   }
+
+  test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") {
+    Seq(
+      (Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10L) -> 0L,
+      (Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10L) -> 1L,
+      (Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10L) -> 2L,
+      (Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10L) -> 2L,
+      (Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10L) -> 11L,
+      (Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10L) -> 11L,
+      (Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10L) -> 6L,
+      (Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10L) -> 5L
+    ).foreach { case ((v, s, e, n), expected) =>
+      checkEvaluation(WidthBucket(Literal(v), Literal(s), Literal(e), Literal(n)), expected)
+    }
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql
index 7dd7e4e..2d1d8c4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql
@@ -382,3 +382,5 @@ SELECT signum(INTERVAL '0-0' YEAR TO MONTH);
 SELECT signum(INTERVAL '-10' DAY);
 SELECT signum(INTERVAL '10' HOUR);
 SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND);
+SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
+SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index cff294e..294e5c9 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 282
+-- Number of queries: 284
 
 
 -- !query
@@ -2657,3 +2657,19 @@ SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND)
 struct<SIGNUM(INTERVAL '0 00:00:00' DAY TO SECOND):double>
 -- !query output
 0.0
+
+
+-- !query
+SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)
+-- !query schema
+struct<width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10):bigint>
+-- !query output
+1
+
+
+-- !query
+SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10)
+-- !query schema
+struct<width_bucket(INTERVAL '-1' YEAR, INTERVAL '-1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10):bigint>
+-- !query output
+1
diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
index 688cde5..5d2edba 100644
--- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 282
+-- Number of queries: 284
 
 
 -- !query
@@ -2646,3 +2646,19 @@ SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND)
 struct<SIGNUM(INTERVAL '0 00:00:00' DAY TO SECOND):double>
 -- !query output
 0.0
+
+
+-- !query
+SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)
+-- !query schema
+struct<width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10):bigint>
+-- !query output
+1
+
+
+-- !query
+SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10)
+-- !query schema
+struct<width_bucket(INTERVAL '-1' YEAR, INTERVAL '-1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10):bigint>
+-- !query output
+1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 3512e5c..ce25a88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.nio.charset.StandardCharsets
+import java.time.Period
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{log => logarithm}
@@ -520,4 +521,20 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df.selectExpr("positive(a)"), Row(1))
     checkAnswer(df.selectExpr("positive(b)"), Row(-1))
   }
+
+  test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") {
+    Seq(
+      (Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10) -> 0,
+      (Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10) -> 1,
+      (Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10) -> 2,
+      (Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10) -> 2,
+      (Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10) -> 11,
+      (Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10) -> 11,
+      (Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 6,
+      (Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 5
+    ).foreach { case ((value, start, end, num), expected) =>
+      val df = Seq((value, start, end, num)).toDF("v", "s", "e", "n")
+      checkAnswer(df.selectExpr("width_bucket(v, s, e, n)"), Row(expected))
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org