Posted to commits@spark.apache.org by da...@apache.org on 2015/09/12 19:17:21 UTC

spark git commit: [SPARK-6548] Adding stddev to DataFrame functions

Repository: spark
Updated Branches:
  refs/heads/master 22730ad54 -> f4a22808e


[SPARK-6548] Adding stddev to DataFrame functions

Adding STDDEV support for DataFrame using a one-pass online/parallel algorithm to compute variance. Please review the code change.

Author: JihongMa <li...@gmail.com>
Author: Jihong MA <li...@gmail.com>
Author: Jihong MA <ji...@jihongs-mbp.usca.ibm.com>
Author: Jihong MA <ji...@Jihongs-MacBook-Pro.local>

Closes #6297 from JihongMA/SPARK-SQL.
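
For reference, the recurrences implemented below, written in my own notation rather than the patch's: the one-pass update per non-null value x is

    count' = count + 1
    avg'   = avg + (x - avg) / count'
    Mk'    = Mk  + (x - avg) * (x - avg')

and merging two partial states A and B is

    count = countA + countB
    avg   = (avgA * countA + avgB * countB) / count
    Mk    = MkA + MkB + (avgB - avgA)^2 * countA * countB / count

where Mk is the running sum of squared differences from the mean, with
stddev_samp = sqrt(Mk / (count - 1)) and stddev_pop = sqrt(Mk / count).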


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f4a22808
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f4a22808
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f4a22808

Branch: refs/heads/master
Commit: f4a22808e03fa12bfe1bfc82cf713cfda7e063a9
Parents: 22730ad
Author: JihongMa <li...@gmail.com>
Authored: Sat Sep 12 10:17:15 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Sat Sep 12 10:17:15 2015 -0700

----------------------------------------------------------------------
 R/pkg/inst/tests/test_sparkSQL.R                |   2 +-
 python/pyspark/sql/dataframe.py                 |  36 +--
 .../catalyst/analysis/FunctionRegistry.scala    |   3 +
 .../catalyst/analysis/HiveTypeCoercion.scala    |   3 +
 .../apache/spark/sql/catalyst/dsl/package.scala |   3 +
 .../expressions/aggregate/functions.scala       | 143 +++++++++++
 .../catalyst/expressions/aggregate/utils.scala  |  18 ++
 .../sql/catalyst/expressions/aggregates.scala   | 245 +++++++++++++++++++
 .../scala/org/apache/spark/sql/DataFrame.scala  |   6 +-
 .../org/apache/spark/sql/GroupedData.scala      |  39 +++
 .../scala/org/apache/spark/sql/functions.scala  |  27 ++
 .../apache/spark/sql/JavaDataFrameSuite.java    |   1 +
 .../spark/sql/DataFrameAggregateSuite.scala     |  33 +++
 .../org/apache/spark/sql/DataFrameSuite.scala   |   2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  42 +++-
 .../hive/execution/AggregationQuerySuite.scala  |  35 ---
 16 files changed, 574 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1ccfde5..98d4402 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1147,7 +1147,7 @@ test_that("describe() and summarize() on a DataFrame", {
   stats <- describe(df, "age")
   expect_equal(collect(stats)[1, "summary"], "count")
   expect_equal(collect(stats)[2, "age"], "24.5")
-  expect_equal(collect(stats)[3, "age"], "5.5")
+  expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
   stats <- describe(df)
   expect_equal(collect(stats)[4, "name"], "Andy")
   expect_equal(collect(stats)[5, "age"], "30")

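(The expected value changes because describe() now reports the sample standard deviation instead of the population one: the mean of the two non-null ages is 24.5 and their population stddev is 5.5; with n = 2 the sample value is larger by a factor of sqrt(2), and 5.5 * sqrt(2) = 7.7781745930520225.)
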
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c5bf557..fb995fa 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -653,25 +653,25 @@ class DataFrame(object):
         guarantee about the backward compatibility of the schema of the resulting DataFrame.
 
         >>> df.describe().show()
-        +-------+---+
-        |summary|age|
-        +-------+---+
-        |  count|  2|
-        |   mean|3.5|
-        | stddev|1.5|
-        |    min|  2|
-        |    max|  5|
-        +-------+---+
+        +-------+------------------+
+        |summary|               age|
+        +-------+------------------+
+        |  count|                 2|
+        |   mean|               3.5|
+        | stddev|2.1213203435596424|
+        |    min|                 2|
+        |    max|                 5|
+        +-------+------------------+
         >>> df.describe(['age', 'name']).show()
-        +-------+---+-----+
-        |summary|age| name|
-        +-------+---+-----+
-        |  count|  2|    2|
-        |   mean|3.5| null|
-        | stddev|1.5| null|
-        |    min|  2|Alice|
-        |    max|  5|  Bob|
-        +-------+---+-----+
+        +-------+------------------+-----+
+        |summary|               age| name|
+        +-------+------------------+-----+
+        |  count|                 2|    2|
+        |   mean|               3.5| null|
+        | stddev|2.1213203435596424| null|
+        |    min|                 2|Alice|
+        |    max|                 5|  Bob|
+        +-------+------------------+-----+
         """
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]

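(Same population-to-sample switch as in the R test above: for the two ages 2 and 5, the population stddev is 1.5 and the sample stddev is 1.5 * sqrt(2) = 2.1213203435596424.)
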
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index cd5a90d..11b4866 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,6 +168,9 @@ object FunctionRegistry {
     expression[Last]("last"),
     expression[Max]("max"),
     expression[Min]("min"),
+    expression[Stddev]("stddev"),
+    expression[StddevPop]("stddev_pop"),
+    expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
 
     // string functions

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 87c11ab..87a3845 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,6 +297,9 @@ object HiveTypeCoercion {
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+      case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
+      case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
+      case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index a7e3a49..699c4cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -159,6 +159,9 @@ package object dsl {
     def lower(e: Expression): Expression = Lower(e)
     def sqrt(e: Expression): Expression = Sqrt(e)
     def abs(e: Expression): Expression = Abs(e)
+    def stddev(e: Expression): Expression = Stddev(e)
+    def stddev_pop(e: Expression): Expression = StddevPop(e)
+    def stddev_samp(e: Expression): Expression = StddevSamp(e)
 
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index a73024d..02cd0ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate {
   override val evaluateExpression = min
 }
 
+// Compute the sample standard deviation of a column
+case class Stddev(child: Expression) extends StddevAgg(child) {
+
+  override def isSample: Boolean = true
+  override def prettyName: String = "stddev"
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends StddevAgg(child) {
+
+  override def isSample: Boolean = false
+  override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends StddevAgg(child) {
+
+  override def isSample: Boolean = true
+  override def prettyName: String = "stddev_samp"
+}
+
+// Compute the standard deviation using the online algorithm described at:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  def isSample: Boolean
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is included here
+  // so that cases like `select stddev(null)` resolve.
+  // Once the old aggregate functions are removed, the analyzer can cast NullType to the
+  // default NumericType, and NullType will no longer be needed here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  private val resultType = DoubleType
+
+  private val preCount = AttributeReference("preCount", resultType)()
+  private val currentCount = AttributeReference("currentCount", resultType)()
+  private val preAvg = AttributeReference("preAvg", resultType)()
+  private val currentAvg = AttributeReference("currentAvg", resultType)()
+  private val currentMk = AttributeReference("currentMk", resultType)()
+
+  override val bufferAttributes = preCount :: currentCount :: preAvg ::
+                                  currentAvg :: currentMk :: Nil
+
+  override val initialValues = Seq(
+    /* preCount = */ Cast(Literal(0), resultType),
+    /* currentCount = */ Cast(Literal(0), resultType),
+    /* preAvg = */ Cast(Literal(0), resultType),
+    /* currentAvg = */ Cast(Literal(0), resultType),
+    /* currentMk = */ Cast(Literal(0), resultType)
+  )
+
+  override val updateExpressions = {
+
+    // update average
+    // avg = avg + (value - avg)/count
+    def avgAdd: Expression = {
+      currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
+    }
+
+    // update the sum of squared differences from the mean
+    // Mk = Mk + (value - preAvg) * (value - updatedAvg)
+    def mkAdd: Expression = {
+      val delta1 = Cast(child, resultType) - preAvg
+      val delta2 = Cast(child, resultType) - currentAvg
+      currentMk + (delta1 * delta2)
+    }
+
+    Seq(
+      /* preCount = */ If(IsNull(child), preCount, currentCount),
+      /* currentCount = */ If(IsNull(child), currentCount,
+                           Add(currentCount, Cast(Literal(1), resultType))),
+      /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
+      /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
+      /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+    )
+  }
+
+  override val mergeExpressions = {
+
+    // count merge
+    def countMerge: Expression = {
+      currentCount.left + currentCount.right
+    }
+
+    // average merge
+    def avgMerge: Expression = {
+      ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
+      (preCount + currentCount.right)
+    }
+
+    // merge the sums of squared differences
+    def mkMerge: Expression = {
+      val avgDelta = currentAvg.right - preAvg
+      val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
+        (preCount + currentCount.right)
+
+      currentMk.left + currentMk.right + mkDelta
+    }
+
+    Seq(
+      /* preCount = */ If(IsNull(currentCount.left),
+                         Cast(Literal(0), resultType), currentCount.left),
+      /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
+                             If(IsNull(currentCount.right), currentCount.left, countMerge)),
+      /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
+      /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
+                           If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
+      /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
+                          If(IsNull(currentMk.right), currentMk.left, mkMerge))
+    )
+  }
+
+  override val evaluateExpression = {
+    // when currentCount == 0, return null
+    // when currentCount == 1, return 0
+    // when currentCount > 1:
+    //   stddev_samp = sqrt(currentMk / (currentCount - 1))
+    //   stddev_pop  = sqrt(currentMk / currentCount)
+    val varCol = {
+      if (isSample) {
+        currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
+      }
+      else {
+        currentMk / currentCount
+      }
+    }
+
+    If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
+      If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
+        Cast(Sqrt(varCol), resultType)))
+  }
+}
+
 case class Sum(child: Expression) extends AlgebraicAggregate {
 
   override def children: Seq[Expression] = child :: Nil

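To make the update rule concrete, here is a minimal standalone Scala sketch of the same online accumulation and final evaluation (independent of Catalyst; StdState, StdDemo, and the sample data are mine, not part of the patch):

    // One-pass (Welford-style) accumulation of count/mean/M2, mirroring
    // updateExpressions above; null inputs would simply be skipped.
    final case class StdState(count: Long, avg: Double, mk: Double) {
      def update(x: Double): StdState = {
        val newCount = count + 1
        val newAvg   = avg + (x - avg) / newCount     // avg = avg + (value - avg)/count
        val newMk    = mk + (x - avg) * (x - newAvg)  // Mk = Mk + (value - preAvg)(value - curAvg)
        StdState(newCount, newAvg, newMk)
      }
      // count == 0 -> null (None here); count == 1 -> 0; otherwise sqrt(Mk / divisor),
      // matching evaluateExpression above.
      def stddev(isSample: Boolean): Option[Double] =
        if (count == 0) None
        else if (count == 1) Some(0.0)
        else Some(math.sqrt(mk / (if (isSample) count - 1 else count)))
    }

    object StdDemo extends App {
      val s = Seq(1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
        .foldLeft(StdState(0, 0.0, 0.0))((st, x) => st.update(x))
      println(s.stddev(isSample = true))   // ~ 0.8944271909999159 = sqrt(4/5)
      println(s.stddev(isSample = false))  // ~ 0.816496580927726  = sqrt(4/6)
    }
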
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 4a43318..ce3ddda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -85,6 +85,24 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = false)
 
+        case expressions.Stddev(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Stddev(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.StddevPop(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.StddevPop(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.StddevSamp(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.StddevSamp(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
         case expressions.Sum(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Sum(child),

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5e8298a..f1c47f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
     result
   }
 }
+
+// Compute the standard deviation using the online algorithm described at:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
+  override def nullable: Boolean = true
+  override def dataType: DataType = DoubleType
+
+  def isSample: Boolean
+
+  override def asPartial: SplitEvaluation = {
+    val partialStd = Alias(ComputePartialStd(child), "PartialStddev")()
+    SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil)
+  }
+
+  override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
+
+}
+
+// Compute the sample standard deviation of a column
+case class Stddev(child: Expression) extends StddevAgg1(child) {
+
+  override def toString: String = s"STDDEV($child)"
+  override def isSample: Boolean = true
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends StddevAgg1(child) {
+
+  override def toString: String = s"STDDEV_POP($child)"
+  override def isSample: Boolean = false
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends StddevAgg1(child) {
+
+  override def toString: String = s"STDDEV_SAMP($child)"
+  override def isSample: Boolean = true
+}
+
+case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
+    def this() = this(null)
+
+    override def children: Seq[Expression] = child :: Nil
+    override def nullable: Boolean = false
+    override def dataType: DataType = ArrayType(DoubleType)
+    override def toString: String = s"computePartialStddev($child)"
+    override def newInstance(): ComputePartialStdFunction =
+      new ComputePartialStdFunction(child, this)
+}
+
+case class ComputePartialStdFunction (
+    expr: Expression,
+    base: AggregateExpression1
+) extends AggregateFunction1 {
+  def this() = this(null, null)  // Required for serialization
+
+  private val computeType = DoubleType
+  private val zero = Cast(Literal(0), computeType)
+  private var partialCount: Long = 0L
+
+  // the mean of data processed so far
+  private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
+
+  // update average based on this formula:
+  // avg = avg + (value - avg)/count
+  private def avgAddFunction (value: Literal): Expression = {
+    val delta = Subtract(Cast(value, computeType), partialAvg)
+    Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType)))
+  }
+
+  // the sum of squares of difference from mean
+  private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
+
+  // update the sum of squared differences from the mean using the following formula:
+  // Mk = Mk + (value - preAvg) * (value - updatedAvg)
+  private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = {
+    val delta1 = Subtract(Cast(value, computeType), prePartialAvg)
+    val delta2 = Subtract(Cast(value, computeType), partialAvg)
+    Add(partialMk, Multiply(delta1, delta2))
+  }
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      val exprValue = Literal.create(evaluatedExpr, expr.dataType)
+      val prePartialAvg = partialAvg.copy()
+      partialCount += 1
+      partialAvg.update(avgAddFunction(exprValue), input)
+      partialMk.update(mkAddFunction(exprValue, prePartialAvg), input)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null),
+        partialAvg.eval(null),
+        partialMk.eval(null)))
+  }
+}
+
+case class MergePartialStd(
+    child: Expression,
+    isSample: Boolean
+) extends UnaryExpression with AggregateExpression1 {
+  def this() = this(null, false) // required for serialization
+
+  override def children: Seq[Expression] = child:: Nil
+  override def nullable: Boolean = false
+  override def dataType: DataType = DoubleType
+  override def toString: String = s"MergePartialStd($child)"
+  override def newInstance(): MergePartialStdFunction = {
+    new MergePartialStdFunction(child, this, isSample)
+  }
+}
+
+case class MergePartialStdFunction(
+    expr: Expression,
+    base: AggregateExpression1,
+    isSample: Boolean
+) extends AggregateFunction1 {
+  def this() = this (null, null, false) // Required for serialization
+
+  private val computeType = DoubleType
+  private val zero = Cast(Literal(0), computeType)
+  private val combineCount = MutableLiteral(zero.eval(null), computeType)
+  private val combineAvg = MutableLiteral(zero.eval(null), computeType)
+  private val combineMk = MutableLiteral(zero.eval(null), computeType)
+
+  private def avgUpdateFunction(preCount: Expression,
+                                partialCount: Expression,
+                                partialAvg: Expression): Expression = {
+    Divide(Add(Multiply(combineAvg, preCount),
+               Multiply(partialAvg, partialCount)),
+           Add(preCount, partialCount))
+  }
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData]
+
+    if (evaluatedExpr != null) {
+      val exprValue = evaluatedExpr.toArray(computeType)
+      val (partialCount, partialAvg, partialMk) =
+        (Literal.create(exprValue(0), computeType),
+         Literal.create(exprValue(1), computeType),
+         Literal.create(exprValue(2), computeType))
+
+      if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) {
+        val preCount = combineCount.copy()
+        combineCount.update(Add(combineCount, partialCount), input)
+
+        val preAvg = combineAvg.copy()
+        val avgDelta = Subtract(partialAvg, preAvg)
+        val mkDelta = Multiply(Multiply(avgDelta, avgDelta),
+                               Divide(Multiply(preCount, partialCount),
+                                      combineCount))
+
+        // update the average using the following formula:
+        // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount)
+        combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input)
+
+        // update the sum of squared differences from the mean using the following formula:
+        // combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount / combineCount)
+        combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input)
+      }
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long]
+
+    if (count == 0) null
+    else if (count < 2) zero.eval(null)
+    else {
+      // when total count >= 2:
+      //   stddev_samp = sqrt(combineMk / (combineCount - 1))
+      //   stddev_pop  = sqrt(combineMk / combineCount)
+      val varCol = {
+        if (isSample) {
+          Divide(combineMk, Cast(Literal(count - 1), computeType))
+        }
+        else {
+          Divide(combineMk, Cast(Literal(count), computeType))
+        }
+      }
+      Sqrt(varCol).eval(null)
+    }
+  }
+}
+
+case class StddevFunction(
+    expr: Expression,
+    base: AggregateExpression1,
+    isSample: Boolean
+) extends AggregateFunction1 {
+
+  def this() = this(null, null, false) // Required for serialization
+
+  private val computeType = DoubleType
+  private var curCount: Long = 0L
+  private val zero = Cast(Literal(0), computeType)
+  private val curAvg = MutableLiteral(zero.eval(null), computeType)
+  private val curMk = MutableLiteral(zero.eval(null), computeType)
+
+  private def curAvgAddFunction(value: Literal): Expression = {
+    val delta = Subtract(Cast(value, computeType), curAvg)
+    Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType)))
+  }
+  private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = {
+    val delta1 = Subtract(Cast(value, computeType), preAvg)
+    val delta2 = Subtract(Cast(value, computeType), curAvg)
+    Add(curMk, Multiply(delta1, delta2))
+  }
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      val preAvg: MutableLiteral = curAvg.copy()
+      val exprValue = Literal.create(evaluatedExpr, expr.dataType)
+      curCount += 1L
+      curAvg.update(curAvgAddFunction(exprValue), input)
+      curMk.update(curMkAddFunction(exprValue, preAvg), input)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    if (curCount == 0) null
+    else if (curCount < 2) zero.eval(null)
+    else {
+      // when total count >= 2:
+      //   stddev_samp = sqrt(curMk / (curCount - 1))
+      //   stddev_pop  = sqrt(curMk / curCount)
+      val varCol = {
+        if (isSample) {
+          Divide(curMk, Cast(Literal(curCount - 1), computeType))
+        }
+        else {
+          Divide(curMk, Cast(Literal(curCount), computeType))
+        }
+      }
+      Sqrt(varCol).eval(null)
+    }
+  }
+}

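The merge step used by this partial-aggregation path can be sketched the same way: split the data into two "partitions", compute (count, avg, Mk) for each, and the pairwise merge reproduces the single-pass result (again, Partial and MergeDemo are my names, not the patch's):

    object MergeDemo extends App {
      // Partial state exchanged between partitions: row count, mean, and the
      // sum of squared differences from the mean, as in ComputePartialStd.
      final case class Partial(count: Double, avg: Double, mk: Double)

      // Pairwise merge, mirroring MergePartialStdFunction.update above.
      def merge(a: Partial, b: Partial): Partial =
        if (a.count == 0) b
        else if (b.count == 0) a
        else {
          val count = a.count + b.count
          val avg   = (a.avg * a.count + b.avg * b.count) / count
          val delta = b.avg - a.avg
          Partial(count, avg, a.mk + b.mk + delta * delta * a.count * b.count / count)
        }

      // Per-partition partial computed the straightforward two-pass way, so the
      // check is independent of the online update.
      def partialOf(xs: Seq[Double]): Partial = {
        val avg = xs.sum / xs.size
        Partial(xs.size, avg, xs.map(x => (x - avg) * (x - avg)).sum)
      }

      val merged = merge(partialOf(Seq(1.0, 1.0, 2.0)), partialOf(Seq(2.0, 3.0, 3.0)))
      println(math.sqrt(merged.mk / (merged.count - 1)))  // ~ 0.8944 = sqrt(4/5), same as one pass
    }
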
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 791c10c..1a687b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1288,15 +1288,11 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
 
-    // TODO: Add stddev as an expression, and remove it from here.
-    def stddevExpr(expr: Expression): Expression =
-      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
-
     // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
       "count" -> Count,
       "mean" -> Average,
-      "stddev" -> stddevExpr,
+      "stddev" -> Stddev,
       "min" -> Min,
       "max" -> Max)
 

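A side benefit of this change: the removed expression, sqrt(avg(x^2) - avg(x)^2), is numerically unstable. When the mean is large relative to the spread, the two averages nearly cancel, and rounding error can even drive the difference negative. A throwaway snippet of mine illustrating the failure mode:

    val xs = Seq(1e8, 1e8 + 1, 1e8 + 2)
    // naive variance: two huge, nearly equal terms cancel; the result is
    // rounding noise (possibly 0.0 or negative, i.e. NaN under sqrt)
    val naive = xs.map(x => x * x).sum / 3 - math.pow(xs.sum / 3, 2)
    // shifted (two-pass) variance: exact here, 0.6666... = 2/3
    val mean = xs.sum / 3
    val shifted = xs.map(x => (x - mean) * (x - mean)).sum / 3
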
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index ee31d83..102b802 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -124,6 +124,9 @@ class GroupedData protected[sql](
       case "avg" | "average" | "mean" => Average
       case "max" => Max
       case "min" => Min
+      case "stddev" => Stddev
+      case "stddev_pop" => StddevPop
+      case "stddev_samp" => StddevSamp
       case "sum" => Sum
       case "count" | "size" =>
         // Turn count(*) into count(1)
@@ -284,6 +287,42 @@ class GroupedData protected[sql](
   }
 
   /**
+   * Compute the sample standard deviation for each numeric column in each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specific columns are given, only compute the stddev for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def stddev(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(Stddev)
+  }
+
+  /**
+   * Compute the population standard deviation for each numeric column in each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specific columns are given, only compute the stddev for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def stddev_pop(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(StddevPop)
+  }
+
+  /**
+   * Compute the sample standard deviation for each numeric column in each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specific columns are given, only compute the stddev for them.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def stddev_samp(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(StddevSamp)
+  }
+
+  /**
    * Compute the sum for each numeric columns for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
    * When specified columns are given, only compute the sum for them.

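For illustration, the new by-name aggregations might be used like this (sqlContext and the data are hypothetical, not part of the patch):

    import sqlContext.implicits._  // assumes a SQLContext named sqlContext in scope

    val df = Seq(("a", 1.0), ("a", 2.0), ("b", 3.0), ("b", 5.0)).toDF("key", "value")

    df.groupBy("key").stddev("value").show()               // sample stddev per group
    df.groupBy("key").agg("value" -> "stddev_pop").show()  // string name resolved via strToExpr
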
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 435e631..60d9c50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -295,6 +295,33 @@ object functions {
   def min(columnName: String): Column = min(Column(columnName))
 
   /**
+   * Aggregate function: returns the unbiased sample standard deviation
+   * of the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev(e: Column): Column = Stddev(e.expr)
+
+  /**
+   * Aggregate function: returns the population standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+
+  /**
+   * Aggregate function: returns the unbiased sample standard deviation of
+   * the expression in a group.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def stddev_samp(e: Column): Column = StddevSamp(e.expr)
+
+  /**
    * Aggregate function: returns the sum of all values in the expression.
    *
    * @group agg_funcs

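And via the Column-based functions, on the same hypothetical df as in the previous sketch:

    import org.apache.spark.sql.functions.{stddev, stddev_pop, stddev_samp}

    // stddev is an alias of stddev_samp (M2 divided by n - 1);
    // stddev_pop divides by n instead.
    df.agg(stddev($"value"), stddev_pop($"value"), stddev_samp($"value")).show()
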
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index d981ce9..5f9abd4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -90,6 +90,7 @@ public class JavaDataFrameSuite {
     df.groupBy().mean("key");
     df.groupBy().max("key");
     df.groupBy().min("key");
+    df.groupBy().stddev("key");
     df.groupBy().sum("key");
 
     // Varargs in column expressions

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c0950b0..f5ef9ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(0, null))
   }
 
+  test("stddev") {
+    val testData2ADev = math.sqrt(4/5.0)
+
+    checkAnswer(
+      testData2.agg(stddev('a)),
+      Row(testData2ADev))
+
+    checkAnswer(
+      testData2.agg(stddev_pop('a)),
+      Row(math.sqrt(4/6.0)))
+
+    checkAnswer(
+      testData2.agg(stddev_samp('a)),
+      Row(testData2ADev))
+  }
+
+  test("zero stddev") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() == 0)
+
+    checkAnswer(
+      emptyTableData.agg(stddev('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(stddev_pop('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(stddev_samp('a)),
+      Row(null))
+  }
+
   test("zero sum") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(

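(The expected values follow from column a of the shared testData2 fixture, which holds 1, 1, 2, 2, 3, 3: the mean is 2 and the sum of squared deviations is 4, giving sample variance 4/5 and population variance 4/6. The same figures reappear in the SQLQuerySuite expectations further below.)
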
http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index dbed4fc..c167999 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -436,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val describeResult = Seq(
       Row("count", "4", "4"),
       Row("mean", "33.0", "178.0"),
-      Row("stddev", "16.583123951777", "10.0"),
+      Row("stddev", "19.148542155126762", "11.547005383792516"),
       Row("min", "16", "164"),
       Row("max", "60", "192"))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 664b7a1..962b100 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       testCodeGen(
         "SELECT min(key) FROM testData3x",
         Row(1) :: Nil)
+      // STDDEV
+      testCodeGen(
+        "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
+        (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
+      testCodeGen(
+        "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
+        Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
       // Some combinations.
       testCodeGen(
         """
@@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(100, 1, 50.5, 300, 100) :: Nil)
       // Aggregate with Code generation handling all null values
       testCodeGen(
-        "SELECT  sum('a'), avg('a'), count(null) FROM testData",
-        Row(null, null, 0) :: Nil)
+        "SELECT  sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
+        Row(null, null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
@@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("aggregates with nulls") {
     checkAnswer(
-      sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
-      Row(1, 3, 2, 6, 3)
+      sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
+      Row(1, 3, 2, 1, 6, 3)
     )
   }
 
@@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("stddev") {
+    checkAnswer(
+      sql("SELECT STDDEV(a) FROM testData2"),
+      Row(math.sqrt(4/5.0))
+    )
+  }
+
+  test("stddev_pop") {
+    checkAnswer(
+      sql("SELECT STDDEV_POP(a) FROM testData2"),
+      Row(math.sqrt(4/6.0))
+    )
+  }
+
+  test("stddev_samp") {
+    checkAnswer(
+      sql("SELECT STDDEV_SAMP(a) FROM testData2"),
+      Row(math.sqrt(4/5.0))
+    )
+  }
+
+  test("stddev agg") {
+    checkAnswer(
+      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+      (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0))))
+  }
+
   test("inner join where, one match per row") {
     checkAnswer(
       sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),

http://git-wip-us.apache.org/repos/asf/spark/blob/f4a22808/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index b126ec4..a73b1bd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -507,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       }.getMessage
       assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
     }
-
-    // TODO: once we support Hive UDAF in the new interface,
-    // we can remove the following two tests.
-    withSQLConf("spark.sql.useAggregate2" -> "true") {
-      val errorMessage = intercept[AnalysisException] {
-        sqlContext.sql(
-          """
-            |SELECT
-            |  key,
-            |  mydoublesum(value + 1.5 * key),
-            |  stddev_samp(value)
-            |FROM agg1
-            |GROUP BY key
-          """.stripMargin).collect()
-      }.getMessage
-      assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
-
-      // This will fall back to the old aggregate
-      val newAggregateOperators = sqlContext.sql(
-        """
-          |SELECT
-          |  key,
-          |  sum(value + 1.5 * key),
-          |  stddev_samp(value)
-          |FROM agg1
-          |GROUP BY key
-        """.stripMargin).queryExecution.executedPlan.collect {
-        case agg: aggregate.SortBasedAggregate => agg
-        case agg: aggregate.TungstenAggregate => agg
-      }
-      val message =
-        "We should fallback to the old aggregation code path if " +
-          "there is any aggregate function that cannot be converted to the new interface."
-      assert(newAggregateOperators.isEmpty, message)
-    }
   }
 }
 

