Posted to commits@spark.apache.org by hv...@apache.org on 2017/02/24 09:54:14 UTC
spark git commit: [SPARK-19691][SQL][BRANCH-2.1] Fix ClassCastException when calculating percentile of decimal column
Repository: spark
Updated Branches:
refs/heads/branch-2.1 43084b3cc -> 66a7ca28a
[SPARK-19691][SQL][BRANCH-2.1] Fix ClassCastException when calculating percentile of decimal column
## What changes were proposed in this pull request?
This is a backport of the following commit: https://github.com/apache/spark/commit/93aa4271596a30752dc5234d869c3ae2f6e8e723
This PR fixes the `ClassCastException` shown below:
```
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number
at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
at ...
```
The exception occurs because `org.apache.spark.sql.types.Decimal` is Catalyst's internal decimal representation and does not extend `java.lang.Number`, so the `asInstanceOf[Number]` cast in `Percentile.update` fails. This fix converts Catalyst values (i.e., `Decimal`) into Scala ones; in this backport the aggregation buffer keys are widened to `AnyRef` and converted to `Double` only where a numeric value is needed.
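For reference, a minimal, self-contained sketch of that conversion pattern (the `toDoubleValue` helper name matches the patch below; the example values are illustrative):
```scala
import org.apache.spark.sql.types.Decimal

// Keys are held as AnyRef and narrowed to Double only where a numeric
// value is needed; Decimal gets its own case because it does not extend
// java.lang.Number.
def toDoubleValue(d: Any): Double = d match {
  case d: Decimal => d.toDouble
  case n: Number  => n.doubleValue
}

toDoubleValue(Decimal(4.5))                // 4.5
toDoubleValue(java.lang.Long.valueOf(4L))  // 4.0
```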
## How was this patch tested?
Added a test in `DataFrameSuite`.
Author: Takeshi Yamamuro <ya...@apache.org>
Closes #17046 from maropu/SPARK-19691-BACKPORT2.1.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/66a7ca28
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/66a7ca28
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/66a7ca28
Branch: refs/heads/branch-2.1
Commit: 66a7ca28a9de92e67ce24896a851a0c96c92aec6
Parents: 43084b3
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Feb 24 10:54:00 2017 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Fri Feb 24 10:54:00 2017 +0100
----------------------------------------------------------------------
.../expressions/aggregate/Percentile.scala | 42 +++++++++++---------
.../expressions/aggregate/PercentileSuite.scala | 6 +--
.../org/apache/spark/sql/DataFrameSuite.scala | 5 +++
3 files changed, 31 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 356e088..8dd4f2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -57,7 +57,7 @@ case class Percentile(
child: Expression,
percentageExpression: Expression,
mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+ inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {
def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, 0, 0)
@@ -123,13 +123,18 @@ case class Percentile(
}
}
- override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+ private def toDoubleValue(d: Any): Double = d match {
+ case d: Decimal => d.toDouble
+ case n: Number => n.doubleValue
+ }
+
+ override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
// Initialize new counts map instance here.
- new OpenHashMap[Number, Long]()
+ new OpenHashMap[AnyRef, Long]()
}
- override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
- val key = child.eval(input).asInstanceOf[Number]
+ override def update(buffer: OpenHashMap[AnyRef, Long], input: InternalRow): Unit = {
+ val key = child.eval(input).asInstanceOf[AnyRef]
// Null values are ignored in counts map.
if (key != null) {
@@ -137,30 +142,30 @@ case class Percentile(
}
}
- override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+ override def merge(buffer: OpenHashMap[AnyRef, Long], other: OpenHashMap[AnyRef, Long]): Unit = {
other.foreach { case (key, count) =>
buffer.changeValue(key, count, _ + count)
}
}
- override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+ override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
generateOutput(getPercentiles(buffer))
}
- private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+ private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
if (buffer.isEmpty) {
return Seq.empty
}
val sortedCounts = buffer.toSeq.sortBy(_._1)(
- child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+ child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
val maxPosition = accumlatedCounts.last._2 - 1
percentages.map { percentile =>
- getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+ getPercentile(accumlatedCounts, maxPosition * percentile)
}
}
@@ -180,7 +185,7 @@ case class Percentile(
* This function has been based upon similar function from HIVE
* `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
*/
- private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+ private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
// We may need to do linear interpolation to get the exact percentile
val lower = position.floor.toLong
val higher = position.ceil.toLong
@@ -193,18 +198,17 @@ case class Percentile(
val lowerKey = aggreCounts(lowerIndex)._1
if (higher == lower) {
// no interpolation needed because position does not have a fraction
- return lowerKey
+ return toDoubleValue(lowerKey)
}
val higherKey = aggreCounts(higherIndex)._1
if (higherKey == lowerKey) {
// no interpolation needed because lower position and higher position has the same key
- return lowerKey
+ return toDoubleValue(lowerKey)
}
// Linear interpolation to get the exact percentile
- return (higher - position) * lowerKey.doubleValue() +
- (position - lower) * higherKey.doubleValue()
+ (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
}
/**
@@ -218,7 +222,7 @@ case class Percentile(
}
}
- override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+ override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
val buffer = new Array[Byte](4 << 10) // 4K
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
@@ -241,11 +245,11 @@ case class Percentile(
}
}
- override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+ override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(bis)
try {
- val counts = new OpenHashMap[Number, Long]
+ val counts = new OpenHashMap[AnyRef, Long]
// Read unsafeRow size and content in bytes.
var sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
@@ -254,7 +258,7 @@ case class Percentile(
val row = new UnsafeRow(2)
row.pointTo(bs, sizeOfNextRow)
// Insert the pairs into counts map.
- val key = row.get(0, child.dataType).asInstanceOf[Number]
+ val key = row.get(0, child.dataType)
val count = row.get(1, LongType).asInstanceOf[Long]
counts.update(key, count)
sizeOfNextRow = ins.readInt()
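For intuition, here is a self-contained sketch of the linear interpolation performed by `getPercentile` above. It is a simplification: it assumes each value occurs exactly once, whereas the real code walks `(key, cumulative count)` pairs:
```scala
// Linear interpolation between the two values straddling the target
// position, mirroring the arithmetic in getPercentile.
def interpolatedPercentile(sorted: IndexedSeq[Double], p: Double): Double = {
  val position = (sorted.length - 1) * p
  val (lower, higher) = (position.floor.toInt, position.ceil.toInt)
  if (lower == higher) {
    sorted(lower) // the position has no fractional part
  } else {
    (higher - position) * sorted(lower) + (position - lower) * sorted(higher)
  }
}

interpolatedPercentile((0 until 10).map(_.toDouble).toIndexedSeq, 0.5) // 4.5
```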
http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index f060ecc..d7c2527 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -38,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
// Check empty serialize and deserialize
- val buffer = new OpenHashMap[Number, Long]()
+ val buffer = new OpenHashMap[AnyRef, Long]()
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
// Check non-empty buffer serializa and deserialize.
data.foreach { key =>
- buffer.changeValue(key, 1L, _ + 1L)
+ buffer.changeValue(new Integer(key), 1L, _ + 1L)
}
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
}
@@ -233,7 +233,7 @@ class PercentileSuite extends SparkFunSuite {
}
private def compareEquals(
- left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+ left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
left.size == right.size && left.forall { case (key, count) =>
right.apply(key) == count
}
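The test change above also shows a side effect of widening the key type: an `AnyRef`-keyed map cannot take a primitive Scala `Int` directly, so keys must be boxed. A small illustration in plain Scala (a standard `HashMap` stands in for Spark's `private[spark]` `OpenHashMap`):
```scala
import scala.collection.mutable

// With AnyRef keys, box the primitive Int first; Int.box(key) is
// behaviorally equivalent to the `new Integer(key)` used in the diff.
val counts = mutable.HashMap.empty[AnyRef, Long]
Seq(1, 2, 2, 3).foreach { key =>
  val boxed = Int.box(key)
  counts(boxed) = counts.getOrElse(boxed, 0L) + 1L
}
// counts: Map(1 -> 1, 2 -> 2, 3 -> 1)
```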
http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 312cd17..22dfc46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1734,4 +1734,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
assert(df.filter($"array1" === $"array2").count() == 1)
}
+
+ test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+ val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+ checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+ }
}
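With the patch applied, the reproduction from the description should run without the exception; a hypothetical spark-shell session (ids 0 through 9, so the interpolated median is 4.5):
```scala
scala> spark.range(10).selectExpr("CAST(id AS DECIMAL) AS x")
            .selectExpr("percentile(x, 0.5)").collect()
res0: Array[org.apache.spark.sql.Row] = Array([4.5])
```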