You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2017/02/24 09:54:14 UTC

spark git commit: [SPARK-19691][SQL][BRANCH-2.1] Fix ClassCastException when calculating percentile of decimal column

Repository: spark
Updated Branches:
  refs/heads/branch-2.1 43084b3cc -> 66a7ca28a


[SPARK-19691][SQL][BRANCH-2.1] Fix ClassCastException when calculating percentile of decimal column

## What changes were proposed in this pull request?
This is a backport of the two following commits: https://github.com/apache/spark/commit/93aa4271596a30752dc5234d869c3ae2f6e8e723

This PR fixes the class-cast exception shown below:
```
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
 java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number
	at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
	at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
	at
```
This fix simply converts Catalyst values (i.e., `Decimal`) into Scala ones by using `CatalystTypeConverters`.

## How was this patch tested?
Added a test in `DataFrameSuite`.

Author: Takeshi Yamamuro <ya...@apache.org>

Closes #17046 from maropu/SPARK-19691-BACKPORT2.1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/66a7ca28
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/66a7ca28
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/66a7ca28

Branch: refs/heads/branch-2.1
Commit: 66a7ca28a9de92e67ce24896a851a0c96c92aec6
Parents: 43084b3
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Feb 24 10:54:00 2017 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Fri Feb 24 10:54:00 2017 +0100

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      | 42 +++++++++++---------
 .../expressions/aggregate/PercentileSuite.scala |  6 +--
 .../org/apache/spark/sql/DataFrameSuite.scala   |  5 +++
 3 files changed, 31 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 356e088..8dd4f2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -57,7 +57,7 @@ case class Percentile(
   child: Expression,
   percentageExpression: Expression,
   mutableAggBufferOffset: Int = 0,
-  inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+  inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, 0, 0)
@@ -123,13 +123,18 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
-  override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
-    val key = child.eval(input).asInstanceOf[Number]
+  override def update(buffer: OpenHashMap[AnyRef, Long], input: InternalRow): Unit = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
 
     // Null values are ignored in counts map.
     if (key != null) {
@@ -137,30 +142,30 @@ case class Percentile(
     }
   }
 
-  override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+  override def merge(buffer: OpenHashMap[AnyRef, Long], other: OpenHashMap[AnyRef, Long]): Unit = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
@@ -180,7 +185,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -193,18 +198,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
@@ -218,7 +222,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -241,11 +245,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
       var sizeOfNextRow = ins.readInt()
       while (sizeOfNextRow >= 0) {
@@ -254,7 +258,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()

http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index f060ecc..d7c2527 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -38,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
@@ -233,7 +233,7 @@ class PercentileSuite extends SparkFunSuite {
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 312cd17..22dfc46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1734,4 +1734,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
     assert(df.filter($"array1" === $"array2").count() == 1)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org