Posted to commits@spark.apache.org by me...@apache.org on 2015/10/29 19:58:44 UTC

spark git commit: [SPARK-10641][SQL] Add Skewness and Kurtosis Support

Repository: spark
Updated Branches:
  refs/heads/master 8185f038c -> a01cbf5da


[SPARK-10641][SQL] Add Skewness and Kurtosis Support

Implementing skewness and kurtosis support based on the following algorithm:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
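
A minimal usage sketch of the new aggregates (the DataFrame `df`, table `tbl`,
and column `x` here are illustrative, not part of this commit):

    import org.apache.spark.sql.functions._
    // DataFrame API, string-name overloads
    df.agg(skewness("x"), kurtosis("x"), variance("x"), var_samp("x"), var_pop("x"))
    // SQL
    sqlContext.sql("SELECT skewness(x), kurtosis(x), var_samp(x) FROM tbl")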

Author: sethah <se...@gmail.com>

Closes #9003 from sethah/SPARK-10641.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a01cbf5d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a01cbf5d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a01cbf5d

Branch: refs/heads/master
Commit: a01cbf5daac148f39cd97299780f542abc41d1e9
Parents: 8185f03
Author: sethah <se...@gmail.com>
Authored: Thu Oct 29 11:58:39 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Oct 29 11:58:39 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   5 +
 .../catalyst/analysis/HiveTypeCoercion.scala    |   5 +
 .../apache/spark/sql/catalyst/dsl/package.scala |   5 +
 .../expressions/aggregate/functions.scala       | 329 +++++++++++++++++++
 .../catalyst/expressions/aggregate/utils.scala  |  30 ++
 .../sql/catalyst/expressions/aggregates.scala   |  95 ++++++
 .../org/apache/spark/sql/GroupedData.scala      |  65 ++++
 .../scala/org/apache/spark/sql/functions.scala  | 115 ++++++-
 .../spark/sql/DataFrameAggregateSuite.scala     |  73 ++++
 .../scala/org/apache/spark/sql/QueryTest.scala  |  48 +++
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  63 +++-
 .../hive/execution/HiveCompatibilitySuite.scala |   1 -
 12 files changed, 823 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3dce6c1..ed9fcfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -189,6 +189,11 @@ object FunctionRegistry {
     expression[StddevPop]("stddev_pop"),
     expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
+    expression[Variance]("variance"),
+    expression[VariancePop]("var_pop"),
+    expression[VarianceSamp]("var_samp"),
+    expression[Skewness]("skewness"),
+    expression[Kurtosis]("kurtosis"),
 
     // string functions
     expression[Ascii]("ascii"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 1140150..3c67567 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -300,6 +300,11 @@ object HiveTypeCoercion {
       case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
       case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
       case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+      case Variance(e @ StringType()) => Variance(Cast(e, DoubleType))
+      case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
+      case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
+      case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
+      case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
     }
   }
 

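With the coercion rules above, string-typed inputs to these aggregates are
implicitly cast to double. A hedged illustration (the table and column names
are hypothetical, not from this commit):

    // skewness(str_col) is rewritten to skewness(CAST(str_col AS DOUBLE))
    sqlContext.sql("SELECT skewness(str_col), kurtosis(str_col) FROM some_table")
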
http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 27b3cd8..787f67a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -162,6 +162,11 @@ package object dsl {
     def stddev(e: Expression): Expression = Stddev(e)
     def stddev_pop(e: Expression): Expression = StddevPop(e)
     def stddev_samp(e: Expression): Expression = StddevSamp(e)
+    def variance(e: Expression): Expression = Variance(e)
+    def var_pop(e: Expression): Expression = VariancePop(e)
+    def var_samp(e: Expression): Expression = VarianceSamp(e)
+    def skewness(e: Expression): Expression = Skewness(e)
+    def kurtosis(e: Expression): Expression = Kurtosis(e)
 
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 515246d..281404f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -930,3 +930,332 @@ object HyperLogLogPlusPlus {
   )
   // scalastyle:on
 }
+
+/**
+ * A central moment is the expected value of a specified power of the deviation of a random
+ * variable from the mean. Central moments are often used to characterize the shape of
+ * a distribution.
+ *
+ * This class implements online, one-pass algorithms for computing the central moments of a set of
+ * points.
+ *
+ * Behavior:
+ *  - null values are ignored
+ *  - returns `Double.NaN` when the column contains `Double.NaN` values
+ *
+ * References:
+ *  - Xiangrui Meng.  "Simpler Online Updates for Arbitrary-Order Central Moments."
+ *      2015. http://arxiv.org/abs/1510.04923
+ *
+ * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ *     Algorithms for calculating variance (Wikipedia)]]
+ *
+ * @param child the expression to compute central moments of.
+ */
+abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
+
+  /**
+   * The central moment order to be computed.
+   */
+  protected def momentOrder: Int
+
+  override def children: Seq[Expression] = Seq(child)
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = DoubleType
+
+  // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with
+  // the new version at planning time (after the analysis phase). For now, NullType is included
+  // here so that cases like `select avg(null)` resolve.
+  // Once we remove the old aggregate functions, the analyzer can cast NullType to the default
+  // data type of the NumericType, and NullType will no longer be needed here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  /**
+   * Size of aggregation buffer.
+   */
+  private[this] val bufferSize = 5
+
+  override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
+    AttributeReference(s"M$i", DoubleType)()
+  }
+
+  // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+  // in the superclass because that would lead to initialization-ordering issues.
+  override val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  // buffer offsets
+  private[this] val nOffset = mutableAggBufferOffset
+  private[this] val meanOffset = mutableAggBufferOffset + 1
+  private[this] val secondMomentOffset = mutableAggBufferOffset + 2
+  private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
+  private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
+
+  // frequently used values for online updates
+  private[this] var delta = 0.0
+  private[this] var deltaN = 0.0
+  private[this] var delta2 = 0.0
+  private[this] var deltaN2 = 0.0
+  private[this] var n = 0.0
+  private[this] var mean = 0.0
+  private[this] var m2 = 0.0
+  private[this] var m3 = 0.0
+  private[this] var m4 = 0.0
+
+  /**
+   * Initialize all moments to zero.
+   */
+  override def initialize(buffer: MutableRow): Unit = {
+    for (aggIndex <- 0 until bufferSize) {
+      buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+    }
+  }
+
+  /**
+   * Update the central moments buffer.
+   */
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val v = Cast(child, DoubleType).eval(input)
+    if (v != null) {
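+      // The Cast to DoubleType above guarantees any non-null value is a Double.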
+      val updateValue = v match {
+        case d: Double => d
+      }
+
+      n = buffer.getDouble(nOffset)
+      mean = buffer.getDouble(meanOffset)
+
+      n += 1.0
+      buffer.setDouble(nOffset, n)
+      delta = updateValue - mean
+      deltaN = delta / n
+      mean += deltaN
+      buffer.setDouble(meanOffset, mean)
+
+      if (momentOrder >= 2) {
+        m2 = buffer.getDouble(secondMomentOffset)
+        m2 += delta * (delta - deltaN)
+        buffer.setDouble(secondMomentOffset, m2)
+      }
+
+      if (momentOrder >= 3) {
+        delta2 = delta * delta
+        deltaN2 = deltaN * deltaN
+        m3 = buffer.getDouble(thirdMomentOffset)
+        m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
+        buffer.setDouble(thirdMomentOffset, m3)
+      }
+
+      if (momentOrder >= 4) {
+        m4 = buffer.getDouble(fourthMomentOffset)
+        m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
+          delta * (delta * delta2 - deltaN * deltaN2)
+        buffer.setDouble(fourthMomentOffset, m4)
+      }
+    }
+  }
+
+  /**
+   * Merge two central moment buffers.
+   */
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    val n1 = buffer1.getDouble(nOffset)
+    val n2 = buffer2.getDouble(inputAggBufferOffset)
+    val mean1 = buffer1.getDouble(meanOffset)
+    val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
+
+    var secondMoment1 = 0.0
+    var secondMoment2 = 0.0
+
+    var thirdMoment1 = 0.0
+    var thirdMoment2 = 0.0
+
+    var fourthMoment1 = 0.0
+    var fourthMoment2 = 0.0
+
+    n = n1 + n2
+    buffer1.setDouble(nOffset, n)
+    delta = mean2 - mean1
+    deltaN = if (n == 0.0) 0.0 else delta / n
+    mean = mean1 + deltaN * n2
+    buffer1.setDouble(meanOffset, mean)
+
+    // higher order moments computed according to:
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
+    if (momentOrder >= 2) {
+      secondMoment1 = buffer1.getDouble(secondMomentOffset)
+      secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
+      m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
+      buffer1.setDouble(secondMomentOffset, m2)
+    }
+
+    if (momentOrder >= 3) {
+      thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
+      thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
+      m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
+        (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
+      buffer1.setDouble(thirdMomentOffset, m3)
+    }
+
+    if (momentOrder >= 4) {
+      fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
+      fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
+      m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
+        n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
+        (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
+        4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
+      buffer1.setDouble(fourthMomentOffset, m4)
+    }
+  }
+
+  /**
+   * Compute aggregate statistic from sufficient moments.
+   * @param centralMoments an array of length `momentOrder + 1` containing the un-normalized
+   *                       central moments needed to compute the aggregate statistic.
+   */
+  def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double
+
+  override final def eval(buffer: InternalRow): Any = {
+    val n = buffer.getDouble(nOffset)
+    val mean = buffer.getDouble(meanOffset)
+    val moments = Array.ofDim[Double](momentOrder + 1)
+    moments(0) = 1.0
+    moments(1) = 0.0
+    if (momentOrder >= 2) {
+      moments(2) = buffer.getDouble(secondMomentOffset)
+    }
+    if (momentOrder >= 3) {
+      moments(3) = buffer.getDouble(thirdMomentOffset)
+    }
+    if (momentOrder >= 4) {
+      moments(4) = buffer.getDouble(fourthMomentOffset)
+    }
+
+    getStatistic(n, mean, moments)
+  }
+}
+
+case class Variance(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "variance"
+
+  override protected val momentOrder = 2
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+
+    if (n == 0.0) Double.NaN else moments(2) / n
+  }
+}
+
+case class VarianceSamp(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "variance_samp"
+
+  override protected val momentOrder = 2
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
+
+    if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
+  }
+}
+
+case class VariancePop(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "variance_pop"
+
+  override protected val momentOrder = 2
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
+
+    if (n == 0.0) Double.NaN else moments(2) / n
+  }
+}
+
+case class Skewness(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "skewness"
+
+  override protected val momentOrder = 3
+
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+    val m2 = moments(2)
+    val m3 = moments(3)
+    if (n == 0.0 || m2 == 0.0) {
+      Double.NaN
+    } else {
+      math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
+    }
+  }
+}
+
+case class Kurtosis(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "kurtosis"
+
+  override protected val momentOrder = 4
+
+  // NOTE: this is the formula for excess kurtosis, which is the default in R and SciPy
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
+    val m2 = moments(2)
+    val m4 = moments(4)
+    if (n == 0.0 || m2 == 0.0) {
+      Double.NaN
+    } else {
+      n * m4 / (m2 * m2) - 3.0
+    }
+  }
+}
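
For reference, a self-contained sketch of the one-pass update and merge
recurrences that CentralMomentAgg implements above, written against plain
Doubles instead of Spark's aggregation buffers (the object and names here are
illustrative, not part of the commit):

    object MomentsSketch {
      // (n, mean, mK) where mK is the un-normalized K-th central moment,
      // i.e. the running sum of (x - mean)^K.
      final case class Moments(n: Double, mean: Double, m2: Double, m3: Double, m4: Double)

      val zero = Moments(0.0, 0.0, 0.0, 0.0, 0.0)

      // Add one observation; mirrors CentralMomentAgg.update.
      def update(s: Moments, x: Double): Moments = {
        val n = s.n + 1.0
        val delta = x - s.mean
        val deltaN = delta / n
        val mean = s.mean + deltaN
        val delta2 = delta * delta
        val deltaN2 = deltaN * deltaN
        // Order matters: each higher moment uses the already-updated lower ones.
        val m2 = s.m2 + delta * (delta - deltaN)
        val m3 = s.m3 - 3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
        val m4 = s.m4 - 4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
          delta * (delta * delta2 - deltaN * deltaN2)
        Moments(n, mean, m2, m3, m4)
      }

      // Combine two partial aggregates; mirrors CentralMomentAgg.merge.
      def merge(a: Moments, b: Moments): Moments = {
        val n = a.n + b.n
        val delta = b.mean - a.mean
        val deltaN = if (n == 0.0) 0.0 else delta / n
        val mean = a.mean + deltaN * b.n
        val m2 = a.m2 + b.m2 + delta * deltaN * a.n * b.n
        val m3 = a.m3 + b.m3 +
          deltaN * deltaN * delta * a.n * b.n * (a.n - b.n) +
          3.0 * deltaN * (a.n * b.m2 - b.n * a.m2)
        val m4 = a.m4 + b.m4 +
          deltaN * deltaN * deltaN * delta * a.n * b.n * (a.n * a.n - a.n * b.n + b.n * b.n) +
          6.0 * deltaN * deltaN * (a.n * a.n * b.m2 + b.n * b.n * a.m2) +
          4.0 * deltaN * (a.n * b.m3 - b.n * a.m3)
        Moments(n, mean, m2, m3, m4)
      }

      def skewness(s: Moments): Double =
        if (s.n == 0.0 || s.m2 == 0.0) Double.NaN
        else math.sqrt(s.n) * s.m3 / math.sqrt(s.m2 * s.m2 * s.m2)

      // Excess kurtosis, matching the Kurtosis class above.
      def kurtosis(s: Moments): Double =
        if (s.n == 0.0 || s.m2 == 0.0) Double.NaN
        else s.n * s.m4 / (s.m2 * s.m2) - 3.0
    }

Feeding it the values of column `a` in testData2 (1, 1, 2, 2, 3, 3) gives
skewness 0.0 and kurtosis -1.5, matching the expectations in the test suites
below.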

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 12bdab0..c911ec5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -67,6 +67,12 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = false)
 
+        case expressions.Kurtosis(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Kurtosis(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
         case expressions.Last(child, ignoreNulls) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Last(child, ignoreNulls),
@@ -85,6 +91,12 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = false)
 
+        case expressions.Skewness(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Skewness(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
         case expressions.Stddev(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Stddev(child),
@@ -120,6 +132,24 @@ object Utils {
             aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
             mode = aggregate.Complete,
             isDistinct = false)
+
+        case expressions.Variance(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Variance(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.VariancePop(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.VariancePop(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.VarianceSamp(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.VarianceSamp(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
       }
       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 70819be..c1bab6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -991,3 +991,98 @@ case class StddevFunction(
     }
   }
 }
+
+// placeholder
+case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "kurtosis"
+
+  override def toString: String = s"KURTOSIS($child)"
+}
+
+// placeholder
+case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "skewness"
+
+  override def toString: String = s"SKEWNESS($child)"
+}
+
+// placeholder
+case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "variance"
+
+  override def toString: String = s"VARIANCE($child)"
+}
+
+// placeholder
+case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "variance_pop"
+
+  override def toString: String = s"VAR_POP($child)"
+}
+
+// placeholder
+case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 {
+
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
+      "please set spark.sql.useAggregate2 = true")
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DoubleType.type = DoubleType
+
+  override def foldable: Boolean = false
+
+  override def prettyName: String = "variance_samp"
+
+  override def toString: String = s"VAR_SAMP($child)"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 102b802..dc96384 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -127,7 +127,12 @@ class GroupedData protected[sql](
       case "stddev" => Stddev
       case "stddev_pop" => StddevPop
       case "stddev_samp" => StddevSamp
+      case "variance" => Variance
+      case "var_pop" => VariancePop
+      case "var_samp" => VarianceSamp
       case "sum" => Sum
+      case "skewness" => Skewness
+      case "kurtosis" => Kurtosis
       case "count" | "size" =>
         // Turn count(*) into count(1)
         (inputExpr: Expression) => inputExpr match {
@@ -251,6 +256,30 @@ class GroupedData protected[sql](
   }
 
   /**
+   * Compute the skewness of each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * If specific columns are given, only compute the skewness for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def skewness(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Skewness)
+  }
+
+  /**
+   * Compute the kurtosis of each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * If specific columns are given, only compute the kurtosis for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def kurtosis(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Kurtosis)
+  }
+
+  /**
    * Compute the max value for each numeric columns for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
    * When specified columns are given, only compute the max values for them.
@@ -333,4 +362,40 @@ class GroupedData protected[sql](
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames : _*)(Sum)
   }
+
+  /**
+   * Compute the population variance of each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * If specific columns are given, only compute the variance for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def variance(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Variance)
+  }
+
+  /**
+   * Compute the population variance of each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * If specific columns are given, only compute the variance for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def var_pop(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(VariancePop)
+  }
+
+  /**
+   * Compute the sample variance of each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * If specific columns are given, only compute the variance for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def var_samp(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(VarianceSamp)
+  }
 }
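
A hedged usage sketch of the new GroupedData methods (the DataFrame `df` and
its columns are hypothetical, not from this commit):

    df.groupBy("key").skewness("x", "y")   // skewness of x and y per group
    df.groupBy("key").kurtosis()           // all numeric columns
    df.groupBy("key").var_samp("x")        // sample variance of x per group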

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 15c864a..c1737b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -229,6 +229,22 @@ object functions {
   def first(columnName: String): Column = first(Column(columnName))
 
   /**
+   * Aggregate function: returns the kurtosis of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def kurtosis(e: Column): Column = Kurtosis(e.expr)
+
+  /**
+   * Aggregate function: returns the kurtosis of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
+
+  /**
    * Aggregate function: returns the last value in a group.
    *
    * @group agg_funcs
@@ -295,8 +311,24 @@ object functions {
   def min(columnName: String): Column = min(Column(columnName))
 
   /**
-   * Aggregate function: returns the unbiased sample standard deviation
-   * of the expression in a group.
+   * Aggregate function: returns the skewness of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def skewness(e: Column): Column = Skewness(e.expr)
+
+  /**
+   * Aggregate function: returns the skewness of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def skewness(columnName: String): Column = skewness(Column(columnName))
+
+  /**
+   * Aggregate function: returns the unbiased sample standard deviation of
+   * the expression in a group.
    *
    * @group agg_funcs
    * @since 1.6.0
@@ -304,13 +336,13 @@ object functions {
   def stddev(e: Column): Column = Stddev(e.expr)
 
   /**
-   * Aggregate function: returns the population standard deviation of
+   * Aggregate function: returns the unbiased sample standard deviation of
    * the expression in a group.
    *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+  def stddev(columnName: String): Column = stddev(Column(columnName))
 
   /**
    * Aggregate function: returns the unbiased sample standard deviation of
@@ -322,6 +354,33 @@ object functions {
   def stddev_samp(e: Column): Column = StddevSamp(e.expr)
 
   /**
+   * Aggregate function: returns the unbiased sample standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName))
+
+  /**
+   * Aggregate function: returns the population standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+
+  /**
+   * Aggregate function: returns the population standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName))
+
+  /**
    * Aggregate function: returns the sum of all values in the expression.
    *
    * @group agg_funcs
@@ -353,6 +412,54 @@ object functions {
    */
   def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
 
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def variance(e: Column): Column = Variance(e.expr)
+
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def variance(columnName: String): Column = variance(Column(columnName))
+
+  /**
+   * Aggregate function: returns the unbiased variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_samp(e: Column): Column = VarianceSamp(e.expr)
+
+  /**
+   * Aggregate function: returns the unbiased variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_samp(columnName: String): Column = var_samp(Column(columnName))
+
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_pop(e: Column): Column = VariancePop(e.expr)
+
+  /**
+   * Aggregate function: returns the population variance of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def var_pop(columnName: String): Column = var_pop(Column(columnName))
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Window functions
   //////////////////////////////////////////////////////////////////////////////////////////////
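
The Column-based overloads above compose inside agg; a small sketch (`df` is a
hypothetical DataFrame, not from this commit):

    import org.apache.spark.sql.functions._
    df.agg(variance(col("x")), var_pop(col("x")), var_samp(col("x")),
      skewness(col("x")), kurtosis(col("y")))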

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f5ef9ff..9b23977 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -221,4 +221,77 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       emptyTableData.agg(sumDistinct('a)),
       Row(null))
   }
+
+  test("moments") {
+    val absTol = 1e-8
+
+    val sparkVariance = testData2.agg(variance('a))
+    val expectedVariance = Row(4.0 / 6.0)
+    checkAggregatesWithTol(sparkVariance, expectedVariance, absTol)
+    val sparkVariancePop = testData2.agg(var_pop('a))
+    checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol)
+
+    val sparkVarianceSamp = testData2.agg(var_samp('a))
+    val expectedVarianceSamp = Row(4.0 / 5.0)
+    checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol)
+
+    val sparkSkewness = testData2.agg(skewness('a))
+    val expectedSkewness = Row(0.0)
+    checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol)
+
+    val sparkKurtosis = testData2.agg(kurtosis('a))
+    val expectedKurtosis = Row(-1.5)
+    checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol)
+  }
+
+  test("zero moments") {
+    val emptyTableData = Seq((1, 2)).toDF("a", "b")
+    assert(emptyTableData.count() === 1)
+
+    checkAnswer(
+      emptyTableData.agg(variance('a)),
+      Row(0.0))
+
+    checkAnswer(
+      emptyTableData.agg(var_samp('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(var_pop('a)),
+      Row(0.0))
+
+    checkAnswer(
+      emptyTableData.agg(skewness('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(kurtosis('a)),
+      Row(Double.NaN))
+  }
+
+  test("null moments") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() === 0)
+
+    checkAnswer(
+      emptyTableData.agg(variance('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(var_samp('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(var_pop('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(skewness('a)),
+      Row(Double.NaN))
+
+    checkAnswer(
+      emptyTableData.agg(kurtosis('a)),
+      Row(Double.NaN))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 73e02eb..3c174ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -135,6 +135,32 @@ abstract class QueryTest extends PlanTest {
   }
 
   /**
+   * Runs the plan and makes sure the answer is within absTol of the expected result.
+   * @param dataFrame the [[DataFrame]] to be executed
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param absTol the absolute tolerance between actual and expected answers.
+   */
+  protected def checkAggregatesWithTol(dataFrame: DataFrame,
+      expectedAnswer: Seq[Row],
+      absTol: Double): Unit = {
+    // TODO: catch exceptions in data frame execution
+    val actualAnswer = dataFrame.collect()
+    require(actualAnswer.length == expectedAnswer.length,
+      s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}")
+
+    actualAnswer.zip(expectedAnswer).foreach {
+      case (actualRow, expectedRow) =>
+        QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol)
+    }
+  }
+
+  protected def checkAggregatesWithTol(dataFrame: DataFrame,
+      expectedAnswer: Row,
+      absTol: Double): Unit = {
+    checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol)
+  }
+
+  /**
    * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
    */
   def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
@@ -214,6 +240,28 @@ object QueryTest {
     return None
   }
 
+  /**
+   * Checks that the actual answer is within absTol of the expected result, field by field.
+   * @param actualAnswer the actual result in a [[Row]].
+   * @param expectedAnswer the expected result in a [[Row]].
+   * @param absTol the absolute tolerance between actual and expected answers.
+   */
+  protected def checkAggregatesWithTol(
+      actualAnswer: Row, expectedAnswer: Row, absTol: Double): Unit = {
+    require(actualAnswer.length == expectedAnswer.length,
+      s"actual answer length ${actualAnswer.length} != " +
+        s"expected answer length ${expectedAnswer.length}")
+
+    // TODO: support other numeric types besides Double
+    // TODO: support struct types?
+    actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+      case (actual: Double, expected: Double) =>
+        assert(math.abs(actual - expected) < absTol,
+          s"actual answer $actual not within $absTol of correct answer $expected")
+      case (actual, expected) =>
+        assert(actual == expected, s"$actual did not equal $expected")
+    }
+  }
+
   def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
     checkAnswer(df, expectedAnswer.asScala) match {
       case Some(errorMessage) => errorMessage

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f5ae3ae..5a616fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -523,8 +523,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("aggregates with nulls") {
     checkAnswer(
-      sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
-      Row(1, 3, 2, 1, 6, 3)
+      sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
+        "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
+      Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3)
     )
   }
 
@@ -717,14 +718,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("stddev") {
     checkAnswer(
       sql("SELECT STDDEV(a) FROM testData2"),
-      Row(math.sqrt(4/5.0))
+      Row(math.sqrt(4.0 / 5.0))
     )
   }
 
   test("stddev_pop") {
     checkAnswer(
       sql("SELECT STDDEV_POP(a) FROM testData2"),
-      Row(math.sqrt(4/6.0))
+      Row(math.sqrt(4.0 / 6.0))
     )
   }
 
@@ -735,10 +736,60 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("var_samp") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2")
+    val expectedAnswer = Row(4.0 / 5.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("variance") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
+    val expectedAnswer = Row(4.0 / 6.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("var_pop") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2")
+    val expectedAnswer = Row(4.0 / 6.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("skewness") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT skewness(a) FROM testData2")
+    val expectedAnswer = Row(0.0)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("kurtosis") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2")
+    val expectedAnswer = Row(-1.5)
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
   test("stddev agg") {
     checkAnswer(
-      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
-      (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0))))
+        sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+      (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
+  }
+
+  test("variance agg") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" +
+      "FROM testData2 GROUP BY a")
+    val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0))
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+  }
+
+  test("skewness and kurtosis agg") {
+    val absTol = 1e-8
+    val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a")
+    val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0))
+    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
   }
 
   test("inner join where, one match per row") {

http://git-wip-us.apache.org/repos/asf/spark/blob/a01cbf5d/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index eed9e43..9e357bf 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -467,7 +467,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "escape_orderby1",
     "escape_sortby1",
     "explain_rearrange",
-    "fetch_aggregation",
     "fileformat_mix",
     "fileformat_sequencefile",
     "fileformat_text",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org