Posted to commits@spark.apache.org by we...@apache.org on 2016/09/01 08:31:14 UTC

spark git commit: [SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation.

Repository: spark
Updated Branches:
  refs/heads/master 536fa911c -> a18c169fd


[SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation.

## What changes were proposed in this pull request?

This PR implements aggregation function `percentile_approx`. Function `percentile_approx` returns the approximate percentile(s) of a column at the given percentage(s). A percentile is a watermark value below which a given percentage of the column values fall. For example, the percentile of column `col` at percentage 50% is the median value of column `col`.

### Syntax:
```
# Returns the percentile at the given percentage. The approximation error can be reduced by increasing the `accuracy` parameter, at the cost of memory.
percentile_approx(col, percentage [, accuracy])

# Returns an array of percentile values at the given array of percentages
percentile_approx(col, array(percentage1 [, percentage2]...) [, accuracy])
```

### Features:
1. This function supports partial aggregation.
2. The memory consumption is bounded. The larger the `accuracy` parameter we choose, the smaller the error we get. The default accuracy value is 10000, matching Hive's default setting. Choose a smaller value for a smaller memory footprint (a worked error-bound sketch follows this list).
3. This function supports window function aggregation (see the window usage sketch after the example usages below).
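
To make the accuracy/error trade-off in item 2 concrete, here is a small illustrative sketch. The helper below is for illustration only and is not part of this PR:
```
// Illustrative only: with relativeError = 1.0 / accuracy, the rank of the
// reported percentile is off by at most about relativeError * N rows.
def maxRankError(accuracy: Int, rowCount: Long): Double =
  (1.0 / accuracy) * rowCount

// maxRankError(10000, 1000000L) == 100.0   -> default accuracy: rank off by at most ~100 of 1M rows
// maxRankError(100,   1000000L) == 10000.0 -> coarser accuracy, larger error, smaller memory footprint
```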

### Example usages:
```
## Returns the 25th percentile value, with default accuracy
SELECT percentile_approx(col, 0.25) FROM table

## Returns an array of percentile values (25th, 50th, 75th), with default accuracy
SELECT percentile_approx(col, array(0.25, 0.5, 0.75)) FROM table

## Returns the 25th percentile value, with custom accuracy 100; a larger accuracy parameter yields a smaller approximation error
SELECT percentile_approx(col, 0.25, 100) FROM table

## Returns the 25th and 50th percentile values, with custom accuracy value 100
SELECT percentile_approx(col, array(0.25, 0.5), 100) FROM table
```
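
Feature 3 above (window aggregation) is exercised in this PR's query suite; a usage sketch along those lines, assuming a `SparkSession` named `spark` and a DataFrame `df` with columns `key` and `value`:
```
// Sketch of window-function usage, mirroring this PR's query-suite test.
df.createOrReplaceTempView("t")
spark.sql(
  """SELECT percentile_approx(value, 0.5)
    |OVER (PARTITION BY key ORDER BY value
    |      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_median
    |FROM t""".stripMargin)
```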

### NOTE:
1. The `percentile_approx` implementation differs from Hive's, so results for the same query may differ slightly from Hive's. This implementation uses `QuantileSummaries` as the underlying probabilistic data structure, and mainly follows the paper "Space-efficient Online Computation of Quantile Summaries" by Michael Greenwald and Sanjeev Khanna (http://dx.doi.org/10.1145/375663.375670).
2. The current implementation of `QuantileSummaries` doesn't support automatic compression. This PR adds a rule that compresses automatically at the caller side, but it may not be optimal (see the condensed sketch below).
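
In essence, the caller-side rule in note 2 condenses to the following, restated from the `PercentileDigest` code in this diff (where `summaries` is a mutable field and `relativeError` its configured error bound):
```
// Condensed from PercentileDigest: compress once the uncompressed buffer grows
// past twice the post-compression bound of 2 * (1 / relativeError) samples.
val compressThresHoldBufferLength = (1 / relativeError).toInt * 2 * 2
if (summaries.sampled.length >= compressThresHoldBufferLength) {
  summaries = summaries.compress()
}
```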

## How was this patch tested?

Unit tests and SQL query tests.

## Acknowledgement
1. This PR's work is based on lw-lin's PR https://github.com/apache/spark/pull/14298, with improvements such as support for partial aggregation and a fix for an out-of-memory issue.

Author: Sean Zhong <se...@databricks.com>

Closes #14868 from clockfly/appro_percentile_try_2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a18c169f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a18c169f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a18c169f

Branch: refs/heads/master
Commit: a18c169fd050e71fdb07b153ae0fa5c410d8de27
Parents: 536fa91
Author: Sean Zhong <se...@databricks.com>
Authored: Thu Sep 1 16:31:13 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Sep 1 16:31:13 2016 +0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../aggregate/ApproximatePercentile.scala       | 321 ++++++++++++++++++
 .../aggregate/ApproximatePercentileSuite.scala  | 339 +++++++++++++++++++
 .../sql/ApproximatePercentileQuerySuite.scala   | 226 +++++++++++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |   3 +-
 .../sql/catalyst/ExpressionToSQLSuite.scala     |   5 +
 6 files changed, 893 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a18c169f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 35fd800..b05f4f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -250,6 +250,7 @@ object FunctionRegistry {
     expression[Average]("mean"),
     expression[Min]("min"),
     expression[Skewness]("skewness"),
+    expression[ApproximatePercentile]("percentile_approx"),
     expression[StddevSamp]("std"),
     expression[StddevSamp]("stddev"),
     expression[StddevPop]("stddev_pop"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a18c169f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
new file mode 100644
index 0000000..f91ff87
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.nio.ByteBuffer
+
+import com.google.common.primitives.{Doubles, Ints, Longs}
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
+import org.apache.spark.sql.types._
+
+/**
+ * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
+ * percentage(s). A percentile is a watermark value below which a given percentage of the column
+ * values fall. For example, the percentile of column `col` at percentage 50% is the median of
+ * column `col`.
+ *
+ * This function supports partial aggregation.
+ *
+ * @param child child expression that can produce column value with `child.eval(inputRow)`
+ * @param percentageExpression Expression that represents a single percentage value or
+ *                             an array of percentage values. Each percentage value must be between
+ *                             0.0 and 1.0.
+ * @param accuracyExpression Integer literal expression of approximation accuracy. Higher value
+ *                           yields better accuracy, the default value is
+ *                           DEFAULT_PERCENTILE_ACCURACY.
+ */
+@ExpressionDescription(
+  usage =
+    """
+      _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric
+      column `col` at the given percentage. The value of percentage must be between 0.0
+      and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which
+      controls approximation accuracy at the cost of memory. Higher values of `accuracy` yield
+      better accuracy; `1.0/accuracy` is the relative error of the approximation.
+
+      _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate
+      percentile array of column `col` at the given percentage array. Each value of the
+      percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is
+      a positive integer literal which controls approximation accuracy at the cost of memory.
+      Higher values of `accuracy` yield better accuracy; `1.0/accuracy` is the relative error of
+      the approximation.
+    """)
+case class ApproximatePercentile(
+    child: Expression,
+    percentageExpression: Expression,
+    accuracyExpression: Expression,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
+
+  def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
+    this(child, percentageExpression, accuracyExpression, 0, 0)
+  }
+
+  def this(child: Expression, percentageExpression: Expression) = {
+    this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+  }
+
+  // Mark as lazy so that accuracyExpression is not evaluated during tree transformation.
+  private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
+  }
+
+  // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
+  private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = {
+    (percentageExpression.dataType, percentageExpression.eval()) match {
+      // Rule ImplicitTypeCasts can cast other numeric types to double
+      case (_, num: Double) => (false, Array(num))
+      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
+        val numericArray = arrayData.toObjectArray(baseType)
+        (true, numericArray.map { x =>
+          baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
+        })
+      case other =>
+        throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!percentageExpression.foldable || !accuracyExpression.foldable) {
+      TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal")
+    } else if (accuracy <= 0) {
+      TypeCheckFailure(
+        s"The accuracy provided must be a positive integer literal (current value = $accuracy)")
+    } else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) {
+      TypeCheckFailure(
+        s"All percentage values must be between 0.0 and 1.0 " +
+          s"(current = ${percentages.mkString(", ")})")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def createAggregationBuffer(): PercentileDigest = {
+    val relativeError = 1.0D / accuracy
+    new PercentileDigest(relativeError)
+  }
+
+  override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = {
+    val value = child.eval(inputRow)
+    // Ignore empty rows, for example: percentile_approx(null)
+    if (value != null) {
+      buffer.add(value.asInstanceOf[Double])
+    }
+  }
+
+  override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = {
+    buffer.merge(other)
+  }
+
+  override def eval(buffer: PercentileDigest): Any = {
+    val result = buffer.getPercentiles(percentages)
+    if (result.length == 0) {
+      null
+    } else if (returnPercentileArray) {
+      new GenericArrayData(result)
+    } else {
+      result(0)
+    }
+  }
+
+  override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile =
+    copy(mutableAggBufferOffset = newOffset)
+
+  override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile =
+    copy(inputAggBufferOffset = newOffset)
+
+  override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression)
+
+  // Returns null for empty inputs
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = {
+    if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
+  }
+
+  override def prettyName: String = "percentile_approx"
+
+  override def serialize(obj: PercentileDigest): Array[Byte] = {
+    ApproximatePercentile.serializer.serialize(obj)
+  }
+
+  override def deserialize(bytes: Array[Byte]): PercentileDigest = {
+    ApproximatePercentile.serializer.deserialize(bytes)
+  }
+}
+
+object ApproximatePercentile {
+
+  // Default accuracy of Percentile approximation. Larger value means better accuracy.
+  // The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY
+  val DEFAULT_PERCENTILE_ACCURACY: Int = 10000
+
+  /**
+   * PercentileDigest is a probabilistic data structure used for approximating percentiles
+   * with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
+   *
+   * @param summaries underlying probabilistic data structure [[QuantileSummaries]].
+   * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the
+   *                   underlying quantileSummaries is compressed.
+   */
+  class PercentileDigest(
+      private var summaries: QuantileSummaries,
+      private var isCompressed: Boolean) {
+
+    // Trigger compression if the QuantileSummaries's buffer length exceeds
+    // compressThresHoldBufferLength. The buffer length can be obtained via
+    // quantileSummaries.sampled.length.
+    private[this] final val compressThresHoldBufferLength: Int = {
+      // Max buffer length after compression.
+      val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2
+      // A safe upper bound for buffer length before compression
+      maxBufferLengthAfterCompression * 2
+    }
+
+    def this(relativeError: Double) = {
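+      // A freshly created summary holds no samples, so it is trivially compressed.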
+      this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true)
+    }
+
+    /** Returns compressed object of [[QuantileSummaries]] */
+    def quantileSummaries: QuantileSummaries = {
+      if (!isCompressed) compress()
+      summaries
+    }
+
+    /** Insert an observation value into the PercentileDigest data structure. */
+    def add(value: Double): Unit = {
+      summaries = summaries.insert(value)
+      // The result of QuantileSummaries.insert is un-compressed
+      isCompressed = false
+
+      // Currently, QuantileSummaries ignores the construction parameter compressThreshold,
+      // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here
+      // to make sure QuantileSummaries doesn't occupy infinite memory.
+      // TODO: Figure out why QuantileSummaries ignores construction parameter compressThreshold
+      if (summaries.sampled.length >= compressThresHoldBufferLength) compress()
+    }
+
+    /** In-place merges in another PercentileDigest. */
+    def merge(other: PercentileDigest): Unit = {
+      if (!isCompressed) compress()
+      summaries = summaries.merge(other.quantileSummaries)
+    }
+
+    /**
+     * Returns the approximate percentiles of all observation values at the given percentages.
+     * A percentile is a watermark value below which a given percentage of observation values fall.
+     * For example, the following code returns the 25th, median, and 75th percentiles of
+     * all observation values:
+     *
+     * {{{
+     *   val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
+     * }}}
+     */
+    def getPercentiles(percentages: Array[Double]): Array[Double] = {
+      if (!isCompressed) compress()
+      if (summaries.count == 0 || percentages.length == 0) {
+        Array.empty[Double]
+      } else {
+        val result = new Array[Double](percentages.length)
+        var i = 0
+        while (i < percentages.length) {
+          result(i) = summaries.query(percentages(i))
+          i += 1
+        }
+        result
+      }
+    }
+
+    private final def compress(): Unit = {
+      summaries = summaries.compress()
+      isCompressed = true
+    }
+  }
+
+  /**
+   * Serializer for class [[PercentileDigest]]
+   *
+   * This class is thread safe.
+   */
+  class PercentileDigestSerializer {
+
+    private final def length(summaries: QuantileSummaries): Int = {
+      // summaries.compressThreshold, summaries.relativeError, summaries.count
+      Ints.BYTES + Doubles.BYTES + Longs.BYTES +
+      // length of summaries.sampled
+      Ints.BYTES +
+      // summaries.sampled, Array[Stats(value: Double, g: Int, delta: Int)]
+      summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES)
+    }
+
+    final def serialize(obj: PercentileDigest): Array[Byte] = {
+      val summary = obj.quantileSummaries
+      val buffer = ByteBuffer.wrap(new Array(length(summary)))
+      buffer.putInt(summary.compressThreshold)
+      buffer.putDouble(summary.relativeError)
+      buffer.putLong(summary.count)
+      buffer.putInt(summary.sampled.length)
+
+      var i = 0
+      while (i < summary.sampled.length) {
+        val stat = summary.sampled(i)
+        buffer.putDouble(stat.value)
+        buffer.putInt(stat.g)
+        buffer.putInt(stat.delta)
+        i += 1
+      }
+      buffer.array()
+    }
+
+    final def deserialize(bytes: Array[Byte]): PercentileDigest = {
+      val buffer = ByteBuffer.wrap(bytes)
+      val compressThreshold = buffer.getInt()
+      val relativeError = buffer.getDouble()
+      val count = buffer.getLong()
+      val sampledLength = buffer.getInt()
+      val sampled = new Array[Stats](sampledLength)
+
+      var i = 0
+      while (i < sampledLength) {
+        val value = buffer.getDouble()
+        val g = buffer.getInt()
+        val delta = buffer.getInt()
+        sampled(i) = Stats(value, g, delta)
+        i += 1
+      }
+      val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count)
+      new PercentileDigest(summary, isCompressed = true)
+    }
+  }
+
+  val serializer: PercentileDigestSerializer = new PercentileDigestSerializer
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a18c169f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
new file mode 100644
index 0000000..61298a1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericMutableRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
+import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType}
+import org.apache.spark.util.SizeEstimator
+
+class ApproximatePercentileSuite extends SparkFunSuite {
+
+  private val random = new java.util.Random()
+
+  private val data = (0 until 10000).map { _ =>
+    random.nextInt(10000)
+  }
+
+  test("serialize and de-serialize") {
+    val serializer = new PercentileDigestSerializer
+
+    // Check empty serialize and de-serialize
+    val emptyBuffer = new PercentileDigest(relativeError = 0.01)
+    assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer))))
+
+    val buffer = new PercentileDigest(relativeError = 0.01)
+    data.foreach { value =>
+      buffer.add(value)
+    }
+    assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer))))
+
+    val agg = new ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5))
+    assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+  }
+
+  test("class PercentileDigest, basic operations") {
+    val valueCount = 10000
+    val percentages = Array(0.25, 0.5, 0.75)
+    Seq(0.0001, 0.001, 0.01, 0.1).foreach { relativeError =>
+      val buffer = new PercentileDigest(relativeError)
+      (1 to valueCount).grouped(10).foreach { group =>
+        val partialBuffer = new PercentileDigest(relativeError)
+        group.foreach(x => partialBuffer.add(x))
+        buffer.merge(partialBuffer)
+      }
+      val expectedPercentiles = percentages.map(_ * valueCount)
+      val approxPercentiles = buffer.getPercentiles(Array(0.25, 0.5, 0.75))
+      expectedPercentiles.zip(approxPercentiles).foreach { pair =>
+        val (expected, estimate) = pair
+        assert((estimate - expected) / valueCount <= relativeError)
+      }
+    }
+  }
+
+  test("class PercentileDigest, makes sure the memory foot print is bounded") {
+    val relativeError = 0.01
+    val memoryFootPrintUpperBound = {
+      val headBufferSize =
+        SizeEstimator.estimate(new Array[Double](QuantileSummaries.defaultHeadSize))
+      val bufferSize = SizeEstimator.estimate(new Stats(0, 0, 0)) * (1 / relativeError) * 2
+      // A safe upper bound
+      (headBufferSize + bufferSize) * 2
+    }
+
+    Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count =>
+      val buffer = new PercentileDigest(relativeError)
+      // Worst case, data is linear sorted
+      (0 until count).foreach(buffer.add(_))
+      assert(SizeEstimator.estimate(buffer) < memoryFootPrintUpperBound)
+    }
+  }
+
+  test("class ApproximatePercentile, high level interface, update, merge, eval...") {
+    val count = 10000
+    val data = (1 until 10000).toSeq
+    val percentages = Array(0.25D, 0.5D, 0.75D)
+    val accuracy = 10000
+    val expectedPercentiles = percentages.map(count * _)
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
+    val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
+    val accuracyExpression = Literal(10000)
+    val agg = new ApproximatePercentile(childExpression, percentageExpression, accuracyExpression)
+
+    assert(agg.nullable)
+    val group1 = (0 until data.length / 2)
+    val group1Buffer = agg.createAggregationBuffer()
+    group1.foreach { index =>
+      val input = InternalRow(data(index))
+      agg.update(group1Buffer, input)
+    }
+
+    val group2 = (data.length / 2 until data.length)
+    val group2Buffer = agg.createAggregationBuffer()
+    group2.foreach { index =>
+      val input = InternalRow(data(index))
+      agg.update(group2Buffer, input)
+    }
+
+    val mergeBuffer = agg.createAggregationBuffer()
+    agg.merge(mergeBuffer, group1Buffer)
+    agg.merge(mergeBuffer, group2Buffer)
+
+    agg.eval(mergeBuffer) match {
+      case arrayData: ArrayData =>
+        val error = count / accuracy
+        val percentiles = arrayData.toDoubleArray()
+        assert(percentiles.zip(expectedPercentiles)
+          .forall(pair => Math.abs(pair._1 - pair._2) < error))
+    }
+  }
+
+  test("class ApproximatePercentile, low level interface, update, merge, eval...") {
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+    val inputAggregationBufferOffset = 1
+    val mutableAggregationBufferOffset = 2
+    val percentage = 0.5D
+
+    // Phase one, partial mode aggregation
+    val agg = new ApproximatePercentile(childExpression, Literal(percentage))
+      .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+      .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+    val mutableAggBuffer = new GenericMutableRow(new Array[Any](mutableAggregationBufferOffset + 1))
+    agg.initialize(mutableAggBuffer)
+    val dataCount = 10
+    (1 to dataCount).foreach { data =>
+      agg.update(mutableAggBuffer, InternalRow(data))
+    }
+    agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+    // Serialize the aggregation buffer
+    val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+    val inputAggBuffer = new GenericMutableRow(Array[Any](null, serialized))
+
+    // Phase 2: final mode aggregation
+    // Re-initialize the aggregation buffer
+    agg.initialize(mutableAggBuffer)
+    agg.merge(mutableAggBuffer, inputAggBuffer)
+    val expectedPercentile = dataCount * percentage
+    assert(Math.abs(agg.eval(mutableAggBuffer).asInstanceOf[Double] - expectedPercentile) < 0.1)
+  }
+
+  test("class ApproximatePercentile, sql string") {
+    val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
+    // sql, single percentile
+    assertEqual(
+      s"percentile_approx(`a`, 0.5D, $defaultAccuracy)",
+      new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String)
+
+    // sql, array of percentile
+    assertEqual(
+      s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)",
+      new ApproximatePercentile(
+        "a".attr,
+        percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_)))
+      ).sql: String)
+
+    // sql(isDistinct = false), single percentile
+    assertEqual(
+      s"percentile_approx(`a`, 0.5D, $defaultAccuracy)",
+      new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D))
+        .sql(isDistinct = false))
+
+    // sql(isDistinct = false), array of percentile
+    assertEqual(
+      s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)",
+      new ApproximatePercentile(
+        "a".attr,
+        percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_)))
+      ).sql(isDistinct = false))
+
+    // sql(isDistinct = true), single percentile
+    assertEqual(
+      s"percentile_approx(DISTINCT `a`, 0.5D, $defaultAccuracy)",
+      new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D))
+        .sql(isDistinct = true))
+
+    // sql(isDistinct = true), array of percentile
+    assertEqual(
+      s"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)",
+      new ApproximatePercentile(
+        "a".attr,
+        percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_)))
+      ).sql(isDistinct = true))
+  }
+
+  test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") {
+    val attribute = AttributeReference("a", DoubleType)()
+    val wrongAccuracy = new ApproximatePercentile(
+      attribute,
+      percentageExpression = Literal(0.5D),
+      accuracyExpression = AttributeReference("b", IntegerType)())
+
+    assertEqual(
+      wrongAccuracy.checkInputDataTypes(),
+      TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
+    )
+
+    val wrongPercentage = new ApproximatePercentile(
+      attribute,
+      percentageExpression = attribute,
+      accuracyExpression = Literal(10000))
+
+    assertEqual(
+      wrongPercentage.checkInputDataTypes(),
+      TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
+    )
+  }
+
+  test("class ApproximatePercentile, fails analysis if parameters are invalid") {
+    val wrongAccuracy = new ApproximatePercentile(
+      AttributeReference("a", DoubleType)(),
+      percentageExpression = Literal(0.5D),
+      accuracyExpression = Literal(-1))
+    assertEqual(
+      wrongAccuracy.checkInputDataTypes(),
+      TypeCheckFailure(
+        "The accuracy provided must be a positive integer literal (current value = -1)"))
+
+    val correctPercentageExpressions = Seq(
+      Literal(0D),
+      Literal(1D),
+      Literal(0.5D),
+      CreateArray(Seq(0D, 1D, 0.5D).map(Literal(_)))
+    )
+    correctPercentageExpressions.foreach { percentageExpression =>
+      val correctPercentage = new ApproximatePercentile(
+        AttributeReference("a", DoubleType)(),
+        percentageExpression = percentageExpression,
+        accuracyExpression = Literal(100))
+
+      // no exception should be thrown
+      correctPercentage.checkInputDataTypes()
+    }
+
+    val wrongPercentageExpressions = Seq(
+      Literal(1.1D),
+      Literal(-0.5D),
+      CreateArray(Seq(0D, 0.5D, 1.1D).map(Literal(_)))
+    )
+
+    wrongPercentageExpressions.foreach { percentageExpression =>
+      val wrongPercentage = new ApproximatePercentile(
+        AttributeReference("a", DoubleType)(),
+        percentageExpression = percentageExpression,
+        accuracyExpression = Literal(100))
+
+      val result = wrongPercentage.checkInputDataTypes()
+      assert(
+        result match {
+          case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
+          case _ => false
+        })
+    }
+  }
+
+  test("class ApproximatePercentile, automatically add type casting for parameters") {
+    val testRelation = LocalRelation('a.int)
+    val analyzer = SimpleAnalyzer
+
+    // Compatible accuracy types: Long type and decimal type
+    val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
+    // Compatible percentage types: float, decimal
+    val percentageExpressions = Seq(Literal(0.3f), DecimalLiteral(0.5),
+      CreateArray(Seq(Literal(0.3f), Literal(0.5D), DecimalLiteral(0.7))))
+
+    accuracyExpressions.foreach { accuracyExpression =>
+      percentageExpressions.foreach { percentageExpression =>
+        val agg = new ApproximatePercentile(
+          UnresolvedAttribute("a"),
+          percentageExpression,
+          accuracyExpression)
+        val analyzed = testRelation.select(agg).analyze.expressions.head
+        analyzed match {
+          case Alias(agg: ApproximatePercentile, _) =>
+            assert(agg.resolved)
+            assert(agg.child.dataType == DoubleType)
+            assert(agg.percentageExpression.dataType == DoubleType ||
+              agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
+            assert(agg.accuracyExpression.dataType == IntegerType)
+          case _ => fail()
+        }
+      }
+    }
+  }
+
+  test("class ApproximatePercentile, null handling") {
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+    val agg = new ApproximatePercentile(childExpression, Literal(0.5D))
+    val buffer = new GenericMutableRow(new Array[Any](1))
+    agg.initialize(buffer)
+    // Empty aggregation buffer
+    assert(agg.eval(buffer) == null)
+    // Empty input row
+    agg.update(buffer, InternalRow(null))
+    assert(agg.eval(buffer) == null)
+
+    // Add some non-empty row
+    agg.update(buffer, InternalRow(0))
+    assert(agg.eval(buffer) != null)
+  }
+
+  private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = {
+    val leftSummary = left.quantileSummaries
+    val rightSummary = right.quantileSummaries
+    leftSummary.compressThreshold == rightSummary.compressThreshold &&
+      leftSummary.relativeError == rightSummary.relativeError &&
+      leftSummary.count == rightSummary.count &&
+      leftSummary.sampled.sameElements(rightSummary.sampled)
+  }
+
+  private def assertEqual[T](left: T, right: T): Unit = {
+    assert(left == right)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a18c169f/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
new file mode 100644
index 0000000..37d7c44
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  private val table = "percentile_test"
+
+  test("percentile_approx, single percentile value") {
+    withTempView(table) {
+      (1 to 1000).toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""
+             |SELECT
+             |  percentile_approx(col, 0.25),
+             |  percentile_approx(col, 0.5),
+             |  percentile_approx(col, 0.75d),
+             |  percentile_approx(col, 0.0),
+             |  percentile_approx(col, 1.0),
+             |  percentile_approx(col, 0),
+             |  percentile_approx(col, 1)
+             |FROM $table
+           """.stripMargin),
+        Row(250D, 500D, 750D, 1D, 1000D, 1D, 1000D)
+      )
+    }
+  }
+
+  test("percentile_approx, array of percentile value") {
+    withTempView(table) {
+      (1 to 1000).toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(col, array(0.25, 0.5, 0.75D)),
+             |  count(col),
+             |  percentile_approx(col, array(0.0, 1.0)),
+             |  sum(col)
+             |FROM $table
+           """.stripMargin),
+        Row(Seq(250D, 500D, 750D), 1000, Seq(1D, 1000D), 500500)
+      )
+    }
+  }
+
+  test("percentile_approx, with different accuracies") {
+    withTempView(table) {
+      (1 to 1000).toDF("col").createOrReplaceTempView(table)
+
+      // With different accuracies
+      val expectedPercentile = 250D
+      val accuracies = Array(1, 10, 100, 1000, 10000)
+      val errors = accuracies.map { accuracy =>
+        val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
+        val approximatePercentile = df.collect().head.getDouble(0)
+        val error = Math.abs(approximatePercentile - expectedPercentile)
+        error
+      }
+
+      // The larger accuracy value we use, the smaller error we get
+      assert(errors.sorted.sameElements(errors.reverse))
+    }
+  }
+
+  test("percentile_approx, supports constant folding for parameter accuracy and percentages") {
+    withTempView(table) {
+      (1 to 1000).toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(s"SELECT percentile_approx(col, array(0.25 + 0.25D), 200 + 800D) FROM $table"),
+        Row(Seq(500D))
+      )
+    }
+  }
+
+  test("percentile_approx(), aggregation on empty input table, no group by") {
+    withTempView(table) {
+      Seq.empty[Int].toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(s"SELECT sum(col), percentile_approx(col, 0.5) FROM $table"),
+        Row(null, null)
+      )
+    }
+  }
+
+  test("percentile_approx(), aggregation on empty input table, with group by") {
+    withTempView(table) {
+      Seq.empty[Int].toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(s"SELECT sum(col), percentile_approx(col, 0.5) FROM $table GROUP BY col"),
+        Seq.empty[Row]
+      )
+    }
+  }
+
+  test("percentile_approx(null), aggregation with group by") {
+    withTempView(table) {
+      (1 to 1000).map(x => (x % 3, x)).toDF("key", "value").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  key,
+             |  percentile_approx(null, 0.5)
+             |FROM $table
+             |GROUP BY key
+           """.stripMargin),
+        Seq(
+          Row(0, null),
+          Row(1, null),
+          Row(2, null))
+      )
+    }
+  }
+
+  test("percentile_approx(null), aggregation without group by") {
+    withTempView(table) {
+      (1 to 1000).map(x => (x % 3, x)).toDF("key", "value").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+              |  percentile_approx(null, 0.5),
+              |  sum(null),
+              |  percentile_approx(null, 0.5)
+              |FROM $table
+           """.stripMargin),
+         Row(null, null, null)
+      )
+    }
+  }
+
+  test("percentile_approx(col, ...), input rows contains null, with out group by") {
+    withTempView(table) {
+      (1 to 1000).map(new Integer(_)).flatMap(Seq(null: Integer, _)).toDF("col")
+        .createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+              |  percentile_approx(col, 0.5),
+              |  sum(null),
+              |  percentile_approx(col, 0.5)
+              |FROM $table
+           """.stripMargin),
+        Row(500D, null, 500D))
+    }
+  }
+
+  test("percentile_approx(col, ...), input rows contains null, with group by") {
+    withTempView(table) {
+      val rand = new java.util.Random()
+      (1 to 1000)
+        .map(new Integer(_))
+        .map(v => (new Integer(v % 2), v))
+        // Add some nulls
+        .flatMap(Seq(_, (null: Integer, null: Integer)))
+        .toDF("key", "value").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+              |  percentile_approx(value, 0.5),
+              |  sum(value),
+              |  percentile_approx(value, 0.5)
+              |FROM $table
+              |GROUP BY key
+           """.stripMargin),
+        Seq(
+          Row(499.0D, 250000, 499.0D),
+          Row(500.0D, 250500, 500.0D),
+          Row(null, null, null))
+      )
+    }
+  }
+
+  test("percentile_approx(col, ...) works in window function") {
+    withTempView(table) {
+      val data = (1 to 10).map(v => (v % 2, v))
+      data.toDF("key", "value").createOrReplaceTempView(table)
+
+      val query = spark.sql(
+        s"""
+           |SELECT percentile_approx(value, 0.5)
+           |OVER
+           |  (PARTITION BY key ORDER BY value ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+           |    AS percentile
+           |FROM $table
+           """.stripMargin)
+
+      val expected = data.groupBy(_._1).toSeq.flatMap { group =>
+        val (key, values) = group
+        val sortedValues = values.map(_._2).sorted
+
+        var outputRows = Seq.empty[Row]
+        var i = 0
+
+        val percentile = new PercentileDigest(1.0 / DEFAULT_PERCENTILE_ACCURACY)
+        sortedValues.foreach { value =>
+          percentile.add(value)
+          outputRows :+= Row(percentile.getPercentiles(Array(0.5D)).head)
+        }
+        outputRows
+      }
+
+      checkAnswer(query, expected)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a18c169f/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index bfa5899..85c5098 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -227,7 +227,6 @@ private[sql] class HiveSessionCatalog(
   private val hiveFunctions = Seq(
     "hash",
     "histogram_numeric",
-    "percentile",
-    "percentile_approx"
+    "percentile"
   )
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a18c169f/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index b4eb50e..fdd0282 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -155,6 +155,11 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
 
   test("aggregate functions") {
     checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT percentile_approx(value, 0.25) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT percentile_approx(value, array(0.25, 0.75)) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT percentile_approx(value, 0.25, 100) FROM t1 GROUP BY key")
+    checkSqlGeneration(
+      "SELECT percentile_approx(value, array(0.25, 0.75), 100) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org