Posted to commits@spark.apache.org by ma...@apache.org on 2021/10/30 17:04:35 UTC
[spark] branch master updated: [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 08123a3 [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile
08123a3 is described below
commit 08123a3795683238352e5bf55452de381349fdd9
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Sat Oct 30 20:03:20 2021 +0300
[SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile
### What changes were proposed in this pull request?
Support ANSI interval types in the following aggregate expressions (a usage sketch follows the list):
- ApproxCountDistinctForIntervals
- ApproximatePercentile
- Percentile
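As a usage sketch (spark-shell; the interval data mirrors the new test suites, and the rendered output follows Spark's ANSI interval literal format):

import java.time.Period
import spark.implicits._

Seq(Period.ofMonths(100), Period.ofMonths(200), Period.ofMonths(300))
  .toDF("ym").createOrReplaceTempView("tbl")
// percentile_approx now returns an INTERVAL YEAR TO MONTH value (200 months, i.e. 16-8),
// while percentile returns a DOUBLE in the internal unit, months (200.0).
spark.sql("SELECT percentile_approx(ym, 0.5), percentile(ym, 0.5) FROM tbl").show()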
### Why are the changes needed?
To improve user experience with Spark SQL.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added new unit tests.
Closes #34412 from AngersZhuuuu/SPARK-37138.
Authored-by: Angerszhuuuu <an...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../ApproxCountDistinctForIntervals.scala | 13 +++---
.../aggregate/ApproximatePercentile.scala | 32 ++++++++------
.../expressions/aggregate/Percentile.scala | 26 +++++++++---
.../ApproxCountDistinctForIntervalsSuite.scala | 6 ++-
.../expressions/aggregate/PercentileSuite.scala | 8 ++--
...ApproxCountDistinctForIntervalsQuerySuite.scala | 28 +++++++++++++
.../sql/ApproximatePercentileQuerySuite.scala | 22 +++++++++-
.../apache/spark/sql/PercentileQuerySuite.scala | 49 ++++++++++++++++++++++
8 files changed, 153 insertions(+), 31 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index a7e9a22..f3bf251 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -61,7 +61,8 @@ case class ApproxCountDistinctForIntervals(
}
override def inputTypes: Seq[AbstractDataType] = {
- Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType), ArrayType)
+ Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType,
+ YearMonthIntervalType, DayTimeIntervalType), ArrayType)
}
// Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
@@ -79,14 +80,16 @@ case class ApproxCountDistinctForIntervals(
TypeCheckFailure("The endpoints provided must be constant literals")
} else {
endpointsExpression.dataType match {
- case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType, _) =>
+ case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
+ _: AnsiIntervalType, _) =>
if (endpoints.length < 2) {
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
} else {
TypeCheckSuccess
}
case _ =>
- TypeCheckFailure("Endpoints require (numeric or timestamp or date) type")
+ TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
+ "interval year to month or interval day to second) type")
}
}
}
@@ -120,9 +123,9 @@ case class ApproxCountDistinctForIntervals(
val doubleValue = child.dataType match {
case n: NumericType =>
n.numeric.toDouble(value.asInstanceOf[n.InternalType])
- case _: DateType =>
+ case _: DateType | _: YearMonthIntervalType =>
value.asInstanceOf[Int].toDouble
- case TimestampType | TimestampNTZType =>
+ case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
value.asInstanceOf[Long].toDouble
}
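The conversions above lean on the internal encodings of the ANSI interval types: a year-month interval is stored as an Int count of months and a day-time interval as a Long count of microseconds, which is why they slot into the existing DateType (Int) and TimestampType (Long) branches. A minimal sketch of that encoding, using hypothetical helpers that are not part of this patch:

import java.time.{Duration, Period}

// Year-month interval: Int number of months (mirrors Spark's internal form).
def yearMonthToInternal(p: Period): Int = Math.toIntExact(p.toTotalMonths)
// Day-time interval: Long number of microseconds.
def dayTimeToInternal(d: Duration): Long =
  Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), d.getNano / 1000L)

assert(yearMonthToInternal(Period.ofMonths(200)) == 200)
assert(dayTimeToInternal(Duration.ofSeconds(200L)) == 200000000L)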
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 8cce79c..0dcb906 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -49,15 +49,16 @@ import org.apache.spark.sql.types._
* yields better accuracy, the default value is
* DEFAULT_PERCENTILE_ACCURACY.
*/
+// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric
- column `col` which is the smallest value in the ordered `col` values (sorted from least to
- greatest) such that no more than `percentage` of `col` values is less than the value
- or equal to that value. The value of percentage must be between 0.0 and 1.0. The `accuracy`
- parameter (default: 10000) is a positive numeric literal which controls approximation accuracy
- at the cost of memory. Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is
- the relative error of the approximation.
+ _FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric or
+ ansi interval column `col` which is the smallest value in the ordered `col` values (sorted
+ from least to greatest) such that no more than `percentage` of `col` values is less than
+ the value or equal to that value. The value of percentage must be between 0.0 and 1.0.
+ The `accuracy` parameter (default: 10000) is a positive numeric literal which controls
+ approximation accuracy at the cost of memory. Higher value of `accuracy` yields better
+ accuracy, `1.0/accuracy` is the relative error of the approximation.
When `percentage` is an array, each value of the percentage array must be between 0.0 and 1.0.
In this case, returns the approximate percentile array of column `col` at the given
percentage array.
@@ -68,9 +69,14 @@ import org.apache.spark.sql.types._
[1,1,0]
> SELECT _FUNC_(col, 0.5, 100) FROM VALUES (0), (6), (7), (9), (10) AS tab(col);
7
+ > SELECT _FUNC_(col, 0.5, 100) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '1' MONTH), (INTERVAL '2' MONTH), (INTERVAL '10' MONTH) AS tab(col);
+ 0-1
+ > SELECT _FUNC_(col, array(0.5, 0.7), 100) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '1' SECOND), (INTERVAL '2' SECOND), (INTERVAL '10' SECOND) AS tab(col);
+ [0 00:00:01.000000000,0 00:00:02.000000000]
""",
group = "agg_funcs",
since = "2.1.0")
+// scalastyle:on line.size.limit
case class ApproximatePercentile(
child: Expression,
percentageExpression: Expression,
@@ -94,7 +100,8 @@ case class ApproximatePercentile(
override def inputTypes: Seq[AbstractDataType] = {
// Support NumericType, DateType, TimestampType and TimestampNTZType since their internal types
// are all numeric, and can be easily cast to double for processing.
- Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType),
+ Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType,
+ YearMonthIntervalType, DayTimeIntervalType),
TypeCollection(DoubleType, ArrayType(DoubleType, containsNull = false)), IntegralType)
}
@@ -138,8 +145,9 @@ case class ApproximatePercentile(
if (value != null) {
// Convert the value to a double value
val doubleValue = child.dataType match {
- case DateType => value.asInstanceOf[Int].toDouble
- case TimestampType | TimestampNTZType => value.asInstanceOf[Long].toDouble
+ case DateType | _: YearMonthIntervalType => value.asInstanceOf[Int].toDouble
+ case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
+ value.asInstanceOf[Long].toDouble
case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
case other: DataType =>
throw QueryExecutionErrors.dataTypeUnexpectedError(other)
@@ -157,8 +165,8 @@ case class ApproximatePercentile(
override def eval(buffer: PercentileDigest): Any = {
val doubleResult = buffer.getPercentiles(percentages)
val result = child.dataType match {
- case DateType => doubleResult.map(_.toInt)
- case TimestampType | TimestampNTZType => doubleResult.map(_.toLong)
+ case DateType | _: YearMonthIntervalType => doubleResult.map(_.toInt)
+ case TimestampType | TimestampNTZType | _: DayTimeIntervalType => doubleResult.map(_.toLong)
case ByteType => doubleResult.map(_.toByte)
case ShortType => doubleResult.map(_.toShort)
case IntegerType => doubleResult.map(_.toInt)
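On the way out, eval truncates the approximated doubles back to the interval's integral internal value, which Spark then surfaces as java.time values. A rough sketch of that round trip, assuming the encodings noted above:

import java.time.{Duration, Period}

// DayTimeIntervalType branch: double -> Long microseconds -> Duration.
val medianMicros = 2.0e8.toLong
val asDuration =
  Duration.ofSeconds(medianMicros / 1000000L, (medianMicros % 1000000L) * 1000L)
assert(asDuration == Duration.ofSeconds(200L))
// YearMonthIntervalType branch: double -> Int months -> Period.
val asPeriod = Period.ofMonths(200.0.toInt).normalized()  // 16 years 8 months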
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 5bce4d3..7d3dd0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -43,12 +43,13 @@ import org.apache.spark.util.collection.OpenHashMap
* percentage values. Each percentage value must be in the range
* [0.0, 1.0].
*/
+// scalastyle:off line.size.limit
@ExpressionDescription(
usage =
"""
- _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
- `col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
- value of frequency should be positive integral
+ _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric
+ or ansi interval column `col` at the given percentage. The value of percentage must be
+ between 0.0 and 1.0. The value of frequency should be positive integral
_FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
percentile value array of numeric column `col` at the given percentage(s). Each value
@@ -62,9 +63,14 @@ import org.apache.spark.util.collection.OpenHashMap
3.0
> SELECT _FUNC_(col, array(0.25, 0.75)) FROM VALUES (0), (10) AS tab(col);
[2.5,7.5]
+ > SELECT _FUNC_(col, 0.5) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '10' MONTH) AS tab(col);
+ 5.0
+ > SELECT _FUNC_(col, array(0.2, 0.5)) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '10' SECOND) AS tab(col);
+ [2000000.0,5000000.0]
""",
group = "agg_funcs",
since = "2.1.0")
+// scalastyle:on line.size.limit
case class Percentile(
child: Expression,
percentageExpression: Expression,
@@ -118,7 +124,8 @@ case class Percentile(
case _: ArrayType => ArrayType(DoubleType, false)
case _ => DoubleType
}
- Seq(NumericType, percentageExpType, IntegralType)
+ Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType),
+ percentageExpType, IntegralType)
}
// Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -191,8 +198,15 @@ case class Percentile(
return Seq.empty
}
- val sortedCounts = buffer.toSeq.sortBy(_._1)(
- child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
+ val ordering =
+ if (child.dataType.isInstanceOf[NumericType]) {
+ child.dataType.asInstanceOf[NumericType].ordering
+ } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
+ child.dataType.asInstanceOf[YearMonthIntervalType].ordering
+ } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
+ child.dataType.asInstanceOf[DayTimeIntervalType].ordering
+ }
+ val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
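A side note on the ordering selection above: the if/else-if chain has no final else, so its static type widens to Any (the missing branch is Unit) and only the trailing asInstanceOf recovers an Ordering; it is safe solely because inputTypes already restricts child.dataType to these three cases. An equivalent, more tightly typed formulation would be a match (a sketch, not the committed code):

val ordering: Ordering[_] = child.dataType match {
  case n: NumericType => n.ordering
  case ym: YearMonthIntervalType => ym.ordering
  case dt: DayTimeIntervalType => dt.ordering
}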
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index 9d53673..a017e5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -39,7 +39,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
assert(
wrongColumn.checkInputDataTypes() match {
case TypeCheckFailure(msg)
- if msg.contains("requires (numeric or timestamp or date or timestamp_ntz) type") => true
+ if msg.contains("requires (numeric or timestamp or date or timestamp_ntz or " +
+ "interval year to month or interval day to second) type") => true
case _ => false
})
}
@@ -69,7 +70,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
- TypeCheckFailure("Endpoints require (numeric or timestamp or date) type"))
+ TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
+ "interval year to month or interval day to second) type"))
}
/** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index fa87407..b5882b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -170,8 +170,8 @@ class PercentileSuite extends SparkFunSuite {
val child = AttributeReference("a", dataType)()
val percentile = new Percentile(child, percentage)
assertEqual(percentile.checkInputDataTypes(),
- TypeCheckFailure(s"argument 1 requires numeric type, however, " +
- s"'a' is of ${dataType.simpleString} type."))
+ TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
+ s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
}
val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
@@ -184,8 +184,8 @@ class PercentileSuite extends SparkFunSuite {
val frq = AttributeReference("frq", frequencyType)()
val percentile = new Percentile(child, percentage, frq)
assertEqual(percentile.checkInputDataTypes(),
- TypeCheckFailure(s"argument 1 requires numeric type, however, " +
- s"'a' is of ${dataType.simpleString} type."))
+ TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
+ s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
}
for(dataType <- validDataTypes;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
index 171e93c..53662c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.time.{Duration, Period}
+
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -58,4 +60,30 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa
}
}
}
+
+ test("SPARK-37138: Support Ansi Interval type in ApproxCountDistinctForIntervals") {
+ val table = "approx_count_distinct_for_ansi_intervals_tbl"
+ withTempView(table) {
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+ (Period.ofMonths(200), Duration.ofSeconds(200L)),
+ (Period.ofMonths(300), Duration.ofSeconds(300L)))
+ .toDF("col1", "col2").createOrReplaceTempView(table)
+ val endpoints = (0 to 5).map(_ / 10)
+
+ val relation = spark.table(table).logicalPlan
+ val ymAttr = relation.output.find(_.name == "col1").get
+ val ymAggFunc =
+ ApproxCountDistinctForIntervals(ymAttr, CreateArray(endpoints.map(Literal(_))))
+ val ymAggExpr = ymAggFunc.toAggregateExpression()
+ val ymNamedExpr = Alias(ymAggExpr, ymAggExpr.toString)()
+
+ val dtAttr = relation.output.find(_.name == "col2").get
+ val dtAggFunc =
+ ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_))))
+ val dtAggExpr = dtAggFunc.toAggregateExpression()
+ val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)()
+ val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation))
+ checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1)))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 5ff15c9..9237c9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql
import java.sql.{Date, Timestamp}
-import java.time.LocalDateTime
+import java.time.{Duration, LocalDateTime, Period}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
@@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession {
import testImplicits._
- private val table = "percentile_test"
+ private val table = "percentile_approx"
test("percentile_approx, single percentile value") {
withTempView(table) {
@@ -319,4 +319,22 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
Row(18, 17, 17, 17))
}
}
+
+ test("SPARK-37138: Support Ansi Interval type in ApproximatePercentile") {
+ withTempView(table) {
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+ (Period.ofMonths(200), Duration.ofSeconds(200L)),
+ (Period.ofMonths(300), Duration.ofSeconds(300L)))
+ .toDF("col1", "col2").createOrReplaceTempView(table)
+ checkAnswer(
+ spark.sql(
+ s"""SELECT
+ | percentile_approx(col1, 0.5),
+ | SUM(null),
+ | percentile_approx(col2, 0.5)
+ |FROM $table
+ """.stripMargin),
+ Row(Period.ofMonths(200).normalized(), null, Duration.ofSeconds(200L)))
+ }
+ }
}
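Note the contrast with Percentile in the next file: percentile_approx returns a value of the input interval type itself, Period or Duration externally, because eval maps the double back to the Int/Long internal form. The expected medians above, spelled out (a sketch):

import java.time.{Duration, Period}

// Median of 100, 200, 300 months, as a normalized year-month interval:
assert(Period.ofMonths(200).normalized() == Period.of(16, 8, 0))
// Median of 100s, 200s, 300s, as a day-time interval:
assert(Duration.ofSeconds(200L) == Duration.ofMinutes(3L).plusSeconds(20L))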
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
new file mode 100644
index 0000000..f39f0c1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * End-to-end tests for percentile aggregate function.
+ */
+class PercentileQuerySuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ private val table = "percentile_test"
+
+ test("SPARK-37138: Support Ansi Interval type in Percentile") {
+ withTempView(table) {
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+ (Period.ofMonths(200), Duration.ofSeconds(200L)),
+ (Period.ofMonths(300), Duration.ofSeconds(300L)))
+ .toDF("col1", "col2").createOrReplaceTempView(table)
+ checkAnswer(
+ spark.sql(
+ s"""SELECT
+ | CAST(percentile(col1, 0.5) AS STRING),
+ | SUM(null),
+ | CAST(percentile(col2, 0.5) AS STRING)
+ |FROM $table
+ """.stripMargin),
+ Row("200.0", null, "2.0E8"))
+ }
+ }
+}
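As the expected row shows, percentile on an ANSI interval column yields a DOUBLE in the interval's internal unit: 200.0 months for col1, and 2.0E8 microseconds (200 seconds) for col2. A quick sanity check of that arithmetic (a sketch):

import java.time.Duration

// 2.0E8 microseconds == 200 seconds, the median of 100s, 200s, 300s.
assert(Duration.ofSeconds(200L).toNanos / 1000L == 200000000L)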