You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2016/11/28 19:06:06 UTC

spark git commit: [SPARK-16282][SQL] Implement percentile SQL function.

Repository: spark
Updated Branches:
  refs/heads/master 185642846 -> 0f5f52a3d


[SPARK-16282][SQL] Implement percentile SQL function.

## What changes were proposed in this pull request?

Implement the percentile SQL function. It computes the exact percentile(s) of the numeric expression `expr` at the percentage(s) `pc`, each of which must be in the range [0, 1].

## How was this patch tested?

Added a new test suite `PercentileSuite` to test percentile directly.
Updated related test cases in `ExpressionToSQLSuite`.

Author: jiangxingbo <ji...@gmail.com>
Author: 蒋星博 <ji...@meituan.com>
Author: jiangxingbo <ji...@meituan.com>

Closes #14136 from jiangxb1987/percentile.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f5f52a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f5f52a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f5f52a3

Branch: refs/heads/master
Commit: 0f5f52a3d1e5dcf5b970c49e324e322b9deb00f3
Parents: 1856428
Author: jiangxingbo <ji...@gmail.com>
Authored: Mon Nov 28 11:05:58 2016 -0800
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Mon Nov 28 11:05:58 2016 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/aggregate/Percentile.scala      | 269 +++++++++++++++++++
 .../expressions/aggregate/PercentileSuite.scala | 245 +++++++++++++++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |   3 +-
 .../sql/catalyst/ExpressionToSQLSuite.scala     |   2 +
 5 files changed, 518 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f5f52a3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 007cdc1..2636afe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -249,6 +249,7 @@ object FunctionRegistry {
     expression[Max]("max"),
     expression[Average]("mean"),
     expression[Min]("min"),
+    expression[Percentile]("percentile"),
     expression[Skewness]("skewness"),
     expression[ApproximatePercentile]("percentile_approx"),
     expression[StddevSamp]("std"),

http://git-wip-us.apache.org/repos/asf/spark/blob/0f5f52a3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
new file mode 100644
index 0000000..356e088
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
+ * the given percentage(s) with value range in [0.0, 1.0].
+ *
+ * The operator is bound to the slower sort based aggregation path because the number of elements
+ * and their partial order cannot be determined in advance. Therefore we have to store all the
+ * elements in memory, and too many elements can cause GC pauses and eventually OutOfMemory
+ * errors.
+ *
+ * The aggregation buffer maps every distinct non-null input value to its number of occurrences.
+ * Decimal inputs are stored as `java.math.BigDecimal` keys, because Spark's `Decimal` is not a
+ * `java.lang.Number`.
+ *
+ * @param child child expression that produces the numeric column value with
+ *              `child.eval(inputRow)`
+ * @param percentageExpression Expression that represents a single percentage value or an array of
+ *                             percentage values. It must be foldable, and each percentage value
+ *                             must be non-null and in the range [0.0, 1.0].
+ * @param mutableAggBufferOffset offset of this aggregate's buffer in the mutable aggregation row
+ * @param inputAggBufferOffset offset of this aggregate's buffer in the input aggregation row
+ */
+@ExpressionDescription(
+  usage =
+    """
+      _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
+      given percentage. The value of percentage must be between 0.0 and 1.0.
+
+      _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
+      of numeric column `col` at the given percentage(s). Each value of the percentage array must
+      be between 0.0 and 1.0.
+    """)
+case class Percentile(
+  child: Expression,
+  percentageExpression: Expression,
+  mutableAggBufferOffset: Int = 0,
+  inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+
+  def this(child: Expression, percentageExpression: Expression) = {
+    this(child, percentageExpression, 0, 0)
+  }
+
+  override def prettyName: String = "percentile"
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
+  @transient
+  private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
+
+  // The requested percentage(s) as doubles. Evaluates to null when the percentage, or any element
+  // of the percentage array, is null, so that checkInputDataTypes can reject it with a clear
+  // message instead of converting a null element as if it were a number.
+  @transient
+  private lazy val percentages: Seq[Double] =
+    (percentageExpression.dataType, percentageExpression.eval()) match {
+      case (_, null) => null
+      case (_, num: Double) => Seq(num)
+      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
+        val numericArray = arrayData.toObjectArray(baseType)
+        if (numericArray.contains(null)) {
+          null
+        } else {
+          numericArray.map { x =>
+            baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq
+        }
+      case other =>
+        throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
+  }
+
+  // Ordering of the counts-map keys. Decimal keys are java.math.BigDecimal values (see toKey),
+  // so DecimalType's own ordering, which compares Spark Decimal values, cannot be used for them.
+  @transient
+  private lazy val keyOrdering: Ordering[Number] = child.dataType match {
+    case _: DecimalType =>
+      new Ordering[Number] {
+        override def compare(x: Number, y: Number): Int =
+          x.asInstanceOf[java.math.BigDecimal].compareTo(y.asInstanceOf[java.math.BigDecimal])
+      }
+    case numericType: NumericType => numericType.ordering.asInstanceOf[Ordering[Number]]
+  }
+
+  override def children: Seq[Expression] = child :: percentageExpression :: Nil
+
+  // Returns null for empty inputs
+  override def nullable: Boolean = true
+
+  override lazy val dataType: DataType = percentageExpression.dataType match {
+    case _: ArrayType => ArrayType(DoubleType, false)
+    case _ => DoubleType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
+    case _: ArrayType => Seq(NumericType, ArrayType)
+    case _ => Seq(NumericType, DoubleType)
+  }
+
+  // Check the inputTypes are valid, and the percentageExpression satisfies:
+  // 1. percentageExpression must be foldable;
+  // 2. percentage(s) must not be null;
+  // 3. percentage(s) must be in the range [0.0, 1.0].
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // Validate the inputTypes
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!percentageExpression.foldable) {
+      // percentageExpression must be foldable
+      TypeCheckFailure("The percentage(s) must be a constant literal, " +
+        s"but got $percentageExpression")
+    } else if (percentages == null) {
+      TypeCheckFailure(s"The percentage(s) must not be null, but got $percentageExpression")
+    } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) {
+      // percentage(s) must be in the range [0.0, 1.0]
+      TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " +
+        s"but got $percentageExpression")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+    // Initialize new counts map instance here.
+    new OpenHashMap[Number, Long]()
+  }
+
+  // Converts a non-null child value into a counts-map key. Spark's Decimal is not a
+  // java.lang.Number, so decimal values are stored as java.math.BigDecimal keys.
+  private def toKey(value: Any): Number = value match {
+    case decimal: Decimal => decimal.toJavaBigDecimal
+    case number: Number => number
+  }
+
+  // Inverse of toKey: converts a counts-map key back into a value of child.dataType.
+  private def toCatalystValue(key: Number): Any = key match {
+    case bigDecimal: java.math.BigDecimal => Decimal(bigDecimal)
+    case other => other
+  }
+
+  // Counts one more occurrence of the child's value for `input`; null values are ignored.
+  override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
+    val value = child.eval(input)
+    if (value != null) {
+      buffer.changeValue(toKey(value), 1L, _ + 1L)
+    }
+  }
+
+  // Adds the occurrence counts of `other` into `buffer`.
+  override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+    other.foreach { case (key, count) =>
+      buffer.changeValue(key, count, _ + count)
+    }
+  }
+
+  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+    generateOutput(getPercentiles(buffer))
+  }
+
+  // Computes one percentile per requested percentage; returns an empty Seq for an empty buffer.
+  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+    if (buffer.isEmpty) {
+      Seq.empty
+    } else {
+      val sortedCounts = buffer.toSeq.sortBy(_._1)(keyOrdering)
+      val sortedKeys = sortedCounts.map(_._1).toArray
+      // accumulatedCounts(i) is the number of input values <= sortedKeys(i). It is built once and
+      // shared by every percentage instead of being rebuilt for each of them.
+      val accumulatedCounts = sortedCounts.map(_._2).scanLeft(0L)(_ + _).tail.toArray
+      val maxPosition = accumulatedCounts.last - 1
+
+      percentages.map { percentage =>
+        getPercentile(sortedKeys, accumulatedCounts, maxPosition * percentage)
+      }
+    }
+  }
+
+  // Returns null when there are no results, an array when the percentages were given as an
+  // array, and a single double otherwise.
+  private def generateOutput(results: Seq[Double]): Any = {
+    if (results.isEmpty) {
+      null
+    } else if (returnPercentileArray) {
+      new GenericArrayData(results)
+    } else {
+      results.head
+    }
+  }
+
+  /**
+   * Get the percentile value at the (possibly fractional) 0-based `position` of the sorted input,
+   * linearly interpolating between the two neighbouring values when `position` has a fraction.
+   *
+   * This function has been based upon similar function from HIVE
+   * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
+   */
+  private def getPercentile(
+      sortedKeys: Array[Number], accumulatedCounts: Array[Long], position: Double): Double = {
+    // We may need to do linear interpolation to get the exact percentile
+    val lower = position.floor.toLong
+    val higher = position.ceil.toLong
+
+    // Use binary search to find the keys at the lower and the higher position.
+    val lowerKey = sortedKeys(
+      binarySearchCount(accumulatedCounts, 0, accumulatedCounts.length, lower + 1))
+    val higherKey = sortedKeys(
+      binarySearchCount(accumulatedCounts, 0, accumulatedCounts.length, higher + 1))
+
+    if (higher == lower || higherKey == lowerKey) {
+      // No interpolation needed: the position has no fraction, or both neighbours share a key.
+      lowerKey.doubleValue()
+    } else {
+      // Linear interpolation to get the exact percentile
+      (higher - position) * lowerKey.doubleValue() + (position - lower) * higherKey.doubleValue()
+    }
+  }
+
+  /**
+   * Binary search over the sorted `countsArray` within [start, end): returns the index of `value`
+   * if present, otherwise the index at which `value` would be inserted.
+   */
+  private def binarySearchCount(
+      countsArray: Array[Long], start: Int, end: Int, value: Long): Int = {
+    util.Arrays.binarySearch(countsArray, start, end, value) match {
+      case ix if ix < 0 => -(ix + 1)
+      case ix => ix
+    }
+  }
+
+  // Writes each (key, count) pair as a size-prefixed UnsafeRow; a size of -1 marks the end.
+  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+    val buffer = new Array[Byte](4 << 10)  // 4K
+    val bos = new ByteArrayOutputStream()
+    val out = new DataOutputStream(bos)
+    try {
+      val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType))
+      // Write pairs in counts map to byte buffer.
+      obj.foreach { case (key, count) =>
+        val row = InternalRow.apply(toCatalystValue(key), count)
+        val unsafeRow = projection.apply(row)
+        out.writeInt(unsafeRow.getSizeInBytes)
+        unsafeRow.writeToStream(out, buffer)
+      }
+      out.writeInt(-1)
+      out.flush()
+
+      bos.toByteArray
+    } finally {
+      out.close()
+      bos.close()
+    }
+  }
+
+  // Reads back the format written by serialize.
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(bis)
+    try {
+      val counts = new OpenHashMap[Number, Long]
+      // Read unsafeRow size and content in bytes.
+      var sizeOfNextRow = ins.readInt()
+      while (sizeOfNextRow >= 0) {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(2)
+        row.pointTo(bs, sizeOfNextRow)
+        // Insert the pairs into counts map.
+        val key = toKey(row.get(0, child.dataType))
+        val count = row.get(1, LongType).asInstanceOf[Long]
+        counts.update(key, count)
+        sizeOfNextRow = ins.readInt()
+      }
+
+      counts
+    } finally {
+      ins.close()
+      bis.close()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0f5f52a3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
new file mode 100644
index 0000000..f060ecc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Tests for the [[Percentile]] aggregate: buffer (de)serialization, update/merge/eval through the
+ * typed and the row-based interfaces, SQL generation, input type checks and null handling.
+ */
+class PercentileSuite extends SparkFunSuite {
+
+  // Fixed seed, so that a failure caused by particular random data can be reproduced.
+  private val random = new java.util.Random(1)
+
+  private val data = (0 until 10000).map { _ =>
+    random.nextInt(10000)
+  }
+
+  test("serialize and de-serialize") {
+    val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
+
+    // Check empty serialize and deserialize
+    val buffer = new OpenHashMap[Number, Long]()
+    assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+
+    // Check non-empty buffer serialize and deserialize.
+    data.foreach { key =>
+      buffer.changeValue(key, 1L, _ + 1L)
+    }
+    assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+  }
+
+  test("class Percentile, high level interface, update, merge, eval...") {
+    val count = 10000
+    val data = (1 to count)
+    val percentages = Seq(0, 0.25, 0.5, 0.75, 1)
+    val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000)
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
+    val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
+    val agg = new Percentile(childExpression, percentageExpression)
+
+    assert(agg.nullable)
+    val group1 = (0 until data.length / 2)
+    val group1Buffer = agg.createAggregationBuffer()
+    group1.foreach { index =>
+      val input = InternalRow(data(index))
+      agg.update(group1Buffer, input)
+    }
+
+    val group2 = (data.length / 2 until data.length)
+    val group2Buffer = agg.createAggregationBuffer()
+    group2.foreach { index =>
+      val input = InternalRow(data(index))
+      agg.update(group2Buffer, input)
+    }
+
+    val mergeBuffer = agg.createAggregationBuffer()
+    agg.merge(mergeBuffer, group1Buffer)
+    agg.merge(mergeBuffer, group2Buffer)
+
+    agg.eval(mergeBuffer) match {
+      case arrayData: ArrayData =>
+        // Comparing whole sequences also fails when the result has missing or extra elements,
+        // which a zip-based element-wise comparison would silently ignore.
+        assert(arrayData.toDoubleArray().toSeq == expectedPercentiles)
+      case other =>
+        fail(s"Expected ArrayData, but got $other")
+    }
+  }
+
+  test("class Percentile, low level interface, update, merge, eval...") {
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+    val inputAggregationBufferOffset = 1
+    val mutableAggregationBufferOffset = 2
+    val percentage = 0.5
+
+    // Phase one, partial mode aggregation
+    val agg = new Percentile(childExpression, Literal(percentage))
+      .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+      .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+    val mutableAggBuffer = new GenericInternalRow(
+      new Array[Any](mutableAggregationBufferOffset + 1))
+    agg.initialize(mutableAggBuffer)
+    val dataCount = 10
+    (1 to dataCount).foreach { data =>
+      agg.update(mutableAggBuffer, InternalRow(data))
+    }
+    agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+    // Serialize the aggregation buffer
+    val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+    val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
+
+    // Phase 2: final mode aggregation
+    // Re-initialize the aggregation buffer
+    agg.initialize(mutableAggBuffer)
+    agg.merge(mutableAggBuffer, inputAggBuffer)
+    val expectedPercentile = 5.5
+    assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile)
+  }
+
+  test("call from sql query") {
+    // sql, single percentile
+    assertEqual(
+      "percentile(`a`, 0.5D)",
+      new Percentile("a".attr, Literal(0.5)).sql: String)
+
+    // sql, array of percentile
+    assertEqual(
+      "percentile(`a`, array(0.25D, 0.5D, 0.75D))",
+      new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String)
+
+    // sql(isDistinct = false), single percentile
+    assertEqual(
+      "percentile(`a`, 0.5D)",
+      new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false))
+
+    // sql(isDistinct = false), array of percentile
+    assertEqual(
+      "percentile(`a`, array(0.25D, 0.5D, 0.75D))",
+      new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
+        .sql(isDistinct = false))
+
+    // sql(isDistinct = true), single percentile
+    assertEqual(
+      "percentile(DISTINCT `a`, 0.5D)",
+      new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true))
+
+    // sql(isDistinct = true), array of percentile
+    assertEqual(
+      "percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))",
+      new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
+        .sql(isDistinct = true))
+  }
+
+  test("fail analysis if childExpression is invalid") {
+    val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
+    val percentage = Literal(0.5)
+
+    validDataTypes.foreach { dataType =>
+      val child = AttributeReference("a", dataType)()
+      val percentile = new Percentile(child, percentage)
+      assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
+    }
+
+    val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType,
+      CalendarIntervalType, NullType)
+
+    invalidDataTypes.foreach { dataType =>
+      val child = AttributeReference("a", dataType)()
+      val percentile = new Percentile(child, percentage)
+      assertEqual(percentile.checkInputDataTypes(),
+        TypeCheckFailure(s"argument 1 requires numeric type, however, " +
+            s"'`a`' is of ${dataType.simpleString} type."))
+    }
+  }
+
+  test("fails analysis if percentage(s) are invalid") {
+    val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
+
+    val validPercentages = Seq(Literal(0D), Literal(0.5), Literal(1D),
+      CreateArray(Seq(0, 0.5, 1).map(Literal(_))))
+
+    validPercentages.foreach { percentage =>
+      val percentile1 = new Percentile(child, percentage)
+      assertEqual(percentile1.checkInputDataTypes(), TypeCheckSuccess)
+    }
+
+    val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2D),
+      CreateArray(Seq(-0.5, 0, 2).map(Literal(_))))
+
+    invalidPercentages.foreach { percentage =>
+      val percentile2 = new Percentile(child, percentage)
+      assertEqual(percentile2.checkInputDataTypes(),
+        TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " +
+        s"but got ${percentage.simpleString}"))
+    }
+
+    val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5),
+      CreateArray(Seq(0, 0.5, 1).map(NonFoldableLiteral(_))))
+
+    nonFoldablePercentage.foreach { percentage =>
+      val percentile3 = new Percentile(child, percentage)
+      assertEqual(percentile3.checkInputDataTypes(),
+        TypeCheckFailure(s"The percentage(s) must be a constant literal, " +
+          s"but got ${percentage}"))
+    }
+
+    val invalidDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType,
+      BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)
+
+    invalidDataTypes.foreach { dataType =>
+      val percentage = Literal(0.5, dataType)
+      val percentile4 = new Percentile(child, percentage)
+      assertEqual(percentile4.checkInputDataTypes(),
+        TypeCheckFailure(s"argument 2 requires double type, however, " +
+          s"'0.5' is of ${dataType.simpleString} type."))
+    }
+  }
+
+  test("null handling") {
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+    val agg = new Percentile(childExpression, Literal(0.5))
+    val buffer = new GenericInternalRow(new Array[Any](1))
+    agg.initialize(buffer)
+    // Empty aggregation buffer
+    assert(agg.eval(buffer) == null)
+    // Empty input row
+    agg.update(buffer, InternalRow(null))
+    assert(agg.eval(buffer) == null)
+
+    // Add some non-empty row
+    agg.update(buffer, InternalRow(0))
+    assert(agg.eval(buffer) != null)
+  }
+
+  // True when both counts maps hold the same keys with the same counts.
+  private def compareEquals(
+      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+    left.size == right.size && left.forall { case (key, count) =>
+      right.apply(key) == count
+    }
+  }
+
+  private def assertEqual[T](left: T, right: T): Unit = {
+    assert(left == right)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0f5f52a3/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 4a9b28a..08bf1cd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -234,7 +234,6 @@ private[sql] class HiveSessionCatalog(
   // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction.
   // Note: don't forget to update SessionCatalog.isTemporaryFunction
   private val hiveFunctions = Seq(
-    "histogram_numeric",
-    "percentile"
+    "histogram_numeric"
   )
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0f5f52a3/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index fdd0282..27ea167 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -173,6 +173,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT percentile(value, 0.25) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT percentile(value, array(0.25, 0.75)) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key")
     checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org