You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2017/02/07 13:05:29 UTC

spark git commit: [SPARK-19118][SQL] Percentile support for frequency distribution table

Repository: spark
Updated Branches:
  refs/heads/master 3d314d08c -> e99e34d0f


[SPARK-19118][SQL] Percentile support for frequency distribution table

## What changes were proposed in this pull request?

I have a frequency distribution table with the following entries:
Age,    No of person
21, 10
22, 15
23, 18
..
..
30, 14
Moreover, it is common to have data in frequency distribution format from which to further calculate the percentile and median. With the current implementation,
it would be very difficult and complex to find the percentile.
Therefore, I am proposing an enhancement to the current Percentile and Approx Percentile implementations to take a frequency distribution column into consideration.

## How was this patch tested?
1) Enhanced /sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala to cover the additional functionality
2) Ran some performance benchmark tests with 20 million rows in a local environment and did not see any performance degradation

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: gagan taneja <ta...@gagans-MacBook-Pro.local>

Closes #16497 from tanejagagan/branch-18940.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e99e34d0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e99e34d0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e99e34d0

Branch: refs/heads/master
Commit: e99e34d0f370211a7c7b96d144cc932b2fc71d10
Parents: 3d314d0
Author: gagan taneja <ta...@gagans-MacBook-Pro.local>
Authored: Tue Feb 7 14:05:22 2017 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Tue Feb 7 14:05:22 2017 +0100

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      |  47 ++++--
 .../expressions/aggregate/PercentileSuite.scala | 149 +++++++++++++------
 2 files changed, 141 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e99e34d0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 5b4ce47..6b7cf79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.SparkException
 
 /**
  * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
@@ -44,22 +45,30 @@ import org.apache.spark.util.collection.OpenHashMap
 @ExpressionDescription(
   usage =
     """
-      _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
-      given percentage. The value of percentage must be between 0.0 and 1.0.
+      _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
+       `col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
+       value of frequency should be positive integral
 
-      _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
-      of numeric column `col` at the given percentage(s). Each value of the percentage array must
-      be between 0.0 and 1.0.
-    """)
+      _FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
+      percentile value array of numeric column `col` at the given percentage(s). Each value
+      of the percentage array must be between 0.0 and 1.0. The value of frequency should be
+      positive integral
+
+      """)
 case class Percentile(
     child: Expression,
     percentageExpression: Expression,
+    frequencyExpression : Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
-    this(child, percentageExpression, 0, 0)
+    this(child, percentageExpression, Literal(1L), 0, 0)
+  }
+
+  def this(child: Expression, percentageExpression: Expression, frequency: Expression) = {
+    this(child, percentageExpression, frequency, 0, 0)
   }
 
   override def prettyName: String = "percentile"
@@ -80,7 +89,9 @@ case class Percentile(
       case arrayData: ArrayData => arrayData.toDoubleArray().toSeq
   }
 
-  override def children: Seq[Expression] = child :: percentageExpression :: Nil
+  override def children: Seq[Expression] = {
+    child  :: percentageExpression ::frequencyExpression :: Nil
+  }
 
   // Returns null for empty inputs
   override def nullable: Boolean = true
@@ -90,9 +101,12 @@ case class Percentile(
     case _ => DoubleType
   }
 
-  override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
-    case _: ArrayType => Seq(NumericType, ArrayType(DoubleType))
-    case _ => Seq(NumericType, DoubleType)
+  override def inputTypes: Seq[AbstractDataType] = {
+    val percentageExpType = percentageExpression.dataType match {
+      case _: ArrayType => ArrayType(DoubleType)
+      case _ => DoubleType
+    }
+    Seq(NumericType, percentageExpType, IntegralType)
   }
 
   // Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -125,10 +139,17 @@ case class Percentile(
       buffer: OpenHashMap[Number, Long],
       input: InternalRow): OpenHashMap[Number, Long] = {
     val key = child.eval(input).asInstanceOf[Number]
+    val frqValue = frequencyExpression.eval(input)
 
     // Null values are ignored in counts map.
-    if (key != null) {
-      buffer.changeValue(key, 1L, _ + 1L)
+    if (key != null && frqValue != null) {
+      val frqLong = frqValue.asInstanceOf[Number].longValue()
+      // add only when frequency is positive
+      if (frqLong > 0) {
+        buffer.changeValue(key, frqLong, _ + frqLong)
+      } else if (frqLong < 0) {
+        throw new SparkException(s"Negative values found in ${frequencyExpression.sql}")
+      }
     }
     buffer
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e99e34d0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index f060ecc..1533fe5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
@@ -50,25 +51,50 @@ class PercentileSuite extends SparkFunSuite {
 
   test("class Percentile, high level interface, update, merge, eval...") {
     val count = 10000
-    val data = (1 to count)
     val percentages = Seq(0, 0.25, 0.5, 0.75, 1)
     val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000)
     val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
     val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
     val agg = new Percentile(childExpression, percentageExpression)
 
+    // Test with rows without frequency
+    val rows = (1 to count).map( x => Seq(x))
+    runTest( agg, rows, expectedPercentiles)
+
+    // Test with row with frequency. Second and third columns are frequency in Int and Long
+    val countForFrequencyTest = 1000
+    val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+    val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
+
+    val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
+    val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
+    runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+
+    val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
+    val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
+    runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+
+    // Run test with Flatten data
+    val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
+      (1 to current).map( y => current )).map( Seq(_))
+    runTest(agg, flattenRows, expectedPercentilesWithFrquency)
+  }
+
+  private def runTest(agg: Percentile,
+        rows : Seq[Seq[Any]],
+        expectedPercentiles : Seq[Double]) {
     assert(agg.nullable)
-    val group1 = (0 until data.length / 2)
+    val group1 = (0 until rows.length / 2)
     val group1Buffer = agg.createAggregationBuffer()
     group1.foreach { index =>
-      val input = InternalRow(data(index))
+      val input = InternalRow(rows(index): _*)
       agg.update(group1Buffer, input)
     }
 
-    val group2 = (data.length / 2 until data.length)
+    val group2 = (rows.length / 2 until rows.length)
     val group2Buffer = agg.createAggregationBuffer()
     group2.foreach { index =>
-      val input = InternalRow(data(index))
+      val input = InternalRow(rows(index): _*)
       agg.update(group2Buffer, input)
     }
 
@@ -116,40 +142,6 @@ class PercentileSuite extends SparkFunSuite {
     assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile)
   }
 
-  test("call from sql query") {
-    // sql, single percentile
-    assertEqual(
-      s"percentile(`a`, 0.5D)",
-      new Percentile("a".attr, Literal(0.5)).sql: String)
-
-    // sql, array of percentile
-    assertEqual(
-      s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
-      new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String)
-
-    // sql(isDistinct = false), single percentile
-    assertEqual(
-      s"percentile(`a`, 0.5D)",
-      new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false))
-
-    // sql(isDistinct = false), array of percentile
-    assertEqual(
-      s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
-      new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
-        .sql(isDistinct = false))
-
-    // sql(isDistinct = true), single percentile
-    assertEqual(
-      s"percentile(DISTINCT `a`, 0.5D)",
-      new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true))
-
-    // sql(isDistinct = true), array of percentile
-    assertEqual(
-      s"percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))",
-      new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
-        .sql(isDistinct = true))
-  }
-
   test("fail analysis if childExpression is invalid") {
     val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
     val percentage = Literal(0.5)
@@ -160,6 +152,15 @@ class PercentileSuite extends SparkFunSuite {
       assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
     }
 
+    val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
+    for ( dataType <- validDataTypes;
+      frequencyType <- validFrequencyTypes)  {
+      val child = AttributeReference("a", dataType)()
+      val frq = AttributeReference("frq", frequencyType)()
+      val percentile = new Percentile(child, percentage, frq)
+      assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
+    }
+
     val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType,
       CalendarIntervalType, NullType)
 
@@ -170,6 +171,30 @@ class PercentileSuite extends SparkFunSuite {
         TypeCheckFailure(s"argument 1 requires numeric type, however, " +
             s"'`a`' is of ${dataType.simpleString} type."))
     }
+
+    val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
+        StringType, DateType, TimestampType,
+      CalendarIntervalType, NullType)
+
+    for( dataType <- invalidDataTypes;
+        frequencyType <- validFrequencyTypes) {
+      val child = AttributeReference("a", dataType)()
+      val frq = AttributeReference("frq", frequencyType)()
+      val percentile = new Percentile(child, percentage, frq)
+      assertEqual(percentile.checkInputDataTypes(),
+        TypeCheckFailure(s"argument 1 requires numeric type, however, " +
+            s"'`a`' is of ${dataType.simpleString} type."))
+    }
+
+    for( dataType <- validDataTypes;
+        frequencyType <- invalidFrequencyDataTypes) {
+      val child = AttributeReference("a", dataType)()
+      val frq = AttributeReference("frq", frequencyType)()
+      val percentile = new Percentile(child, percentage, frq)
+      assertEqual(percentile.checkInputDataTypes(),
+        TypeCheckFailure(s"argument 3 requires integral type, however, " +
+            s"'`frq`' is of ${frequencyType.simpleString} type."))
+    }
   }
 
   test("fails analysis if percentage(s) are invalid") {
@@ -217,19 +242,59 @@ class PercentileSuite extends SparkFunSuite {
   }
 
   test("null handling") {
+
+    // Percentile without frequency column
     val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
     val agg = new Percentile(childExpression, Literal(0.5))
     val buffer = new GenericInternalRow(new Array[Any](1))
     agg.initialize(buffer)
+
     // Empty aggregation buffer
     assert(agg.eval(buffer) == null)
+
     // Empty input row
     agg.update(buffer, InternalRow(null))
     assert(agg.eval(buffer) == null)
 
-    // Add some non-empty row
-    agg.update(buffer, InternalRow(0))
-    assert(agg.eval(buffer) != null)
+    // Percentile with Frequency column
+    val frequencyExpression = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
+    val aggWithFrequency = new Percentile(childExpression, Literal(0.5), frequencyExpression)
+    val bufferWithFrequency = new GenericInternalRow(new Array[Any](2))
+    aggWithFrequency.initialize(bufferWithFrequency)
+
+    // Empty aggregation buffer
+    assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+    // Empty input row
+    aggWithFrequency.update(bufferWithFrequency, InternalRow(null, null))
+    assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+
+    // Add some non-empty row with empty frequency column
+    aggWithFrequency.update(bufferWithFrequency, InternalRow(0, null))
+    assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+
+    // Add some non-empty row with zero frequency
+    aggWithFrequency.update(bufferWithFrequency, InternalRow(1, 0))
+    assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+
+    // Add some non-empty row with positive frequency
+    aggWithFrequency.update(bufferWithFrequency, InternalRow(0, 1))
+    assert(aggWithFrequency.eval(bufferWithFrequency) != null)
+  }
+
+  test("negatives frequency column handling") {
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+    val freqExpression = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
+    val agg = new Percentile(childExpression, Literal(0.5), freqExpression)
+    val buffer = new GenericInternalRow(new Array[Any](2))
+    agg.initialize(buffer)
+
+    val caught =
+      intercept[SparkException]{
+        // Add some non-empty row with negative frequency
+        agg.update(buffer, InternalRow(1, -5))
+        agg.eval(buffer)
+      }
+    assert( caught.getMessage.startsWith("Negative values found in "))
   }
 
   private def compareEquals(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org