Posted to commits@spark.apache.org by we...@apache.org on 2021/10/27 11:48:05 UTC

[spark] branch master updated: [SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new de0d7fb  [SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation
de0d7fb is described below

commit de0d7fbb4f010bec8e457d0dc00b5618e7a43750
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Wed Oct 27 19:47:17 2021 +0800

    [SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation
    
    ### What changes were proposed in this pull request?
    This PR implements the aggregation function `histogram_numeric`, which returns an approximate histogram of a numerical column using a user-specified number of bins, for example the histogram of column `col` split into 3 bins.
    
    Syntax:
    #### Computes an approximate histogram of a numerical column using a user-specified number of bins.
    histogram_numeric(col, nBins)

    #### Returns an approximate histogram of column `col` split into 3 bins.
    SELECT histogram_numeric(col, 3) FROM table

    #### Returns an approximate histogram of column `col` split into 5 bins.
    SELECT histogram_numeric(col, 5) FROM table
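
    As a quick illustration, the same example can also be driven from the Scala API. This is a hedged sketch, not part of this patch; `spark` is assumed to be an existing SparkSession:

        // Hypothetical usage sketch: run the documented example through spark.sql.
        // The result is an array<struct<x:double,y:double>> of histogram bin centers
        // and heights, here (0.0, 1.0), (1.0, 1.0), (2.0, 1.0), (10.0, 1.0).
        val df = spark.sql(
          "SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col)")
        df.show(false)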
    
    ### Why are the changes needed?
    Previously `histogram_numeric` was only available by falling back to Hive's UDAF (see the `hiveFunctions` changes below). A native implementation lets the function support partial aggregation.
    ### Does this PR introduce _any_ user-facing change?
    No change from the user's side.
    
    ### How was this patch tested?
    Added UT
    
    Closes #34380 from AngersZhuuuu/SPARK-37082.
    
    Authored-by: Angerszhuuuu <an...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../apache/spark/sql/util/NumericHistogram.java    | 286 +++++++++++++++++++++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../sql/catalyst/catalog/SessionCatalog.scala      |   2 +-
 .../expressions/aggregate/HistogramNumeric.scala   | 207 +++++++++++++++
 .../aggregate/HistogramNumericSuite.scala          | 166 ++++++++++++
 .../sql-functions/sql-expression-schema.md         |   3 +-
 .../test/resources/sql-tests/inputs/group-by.sql   |  12 +
 .../resources/sql-tests/results/group-by.sql.out   |  20 +-
 .../apache/spark/sql/hive/HiveSessionCatalog.scala |   4 +-
 9 files changed, 695 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java
new file mode 100644
index 0000000..987c18e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Random;
+
+
+/**
+ * A generic, re-usable histogram class that supports partial aggregations.
+ * The algorithm is a heuristic adapted from the following paper:
+ * Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm",
+ * J. Machine Learning Research 11 (2010), pp. 849--872. Although there are no approximation
+ * guarantees, it appears to work well with adequate data and a large (e.g., 20-80) number
+ * of histogram bins.
+ *
+ * Adapted from Hive's NumericHistogram; see
+ * https://github.com/apache/hive/blob/master/ql/src/
+ * java/org/apache/hadoop/hive/ql/udf/generic/NumericHistogram.java
+ *
+ * Differences:
+ *   1. [[Coord]] and its fields are declared public for easy access
+ *      from the HistogramNumeric class.
+ *   2. Adds the method [[getNumBins()]] for serializing [[NumericHistogram]]
+ *      in [[NumericHistogramSerializer]].
+ *   3. Adds the method [[addBin()]] for deserializing [[NumericHistogram]]
+ *      in [[NumericHistogramSerializer]].
+ *   4. In Hive's code, [[merge()]] is passed a serialized histogram; in Spark
+ *      it is passed a deserialized one, so the bin-merging code is adjusted
+ *      accordingly.
+ */
+public class NumericHistogram {
+  /**
+   * The Coord class defines a histogram bin, which is just an (x,y) pair.
+   */
+  public static class Coord implements Comparable {
+    public double x;
+    public double y;
+
+    public int compareTo(Object other) {
+      return Double.compare(x, ((Coord) other).x);
+    }
+  }
+
+  // Class variables
+  private int nbins;
+  private int nusedbins;
+  private ArrayList<Coord> bins;
+  private Random prng;
+
+  /**
+   * Creates a new histogram object. Note that the allocate() or merge()
+   * method must be called before the histogram can be used.
+   */
+  public NumericHistogram() {
+    nbins = 0;
+    nusedbins = 0;
+    bins = null;
+
+    // init the RNG for breaking ties in histogram merging. A fixed seed is specified here
+    // to aid testing, but can be eliminated to use a time-based seed (which would
+    // make the algorithm non-deterministic).
+    prng = new Random(31183);
+  }
+
+  /**
+   * Resets a histogram object to its initial state. allocate() or merge() must be
+   * called again before use.
+   */
+  public void reset() {
+    bins = null;
+    nbins = nusedbins = 0;
+  }
+
+  /**
+   * Returns the allocated number of bins.
+   */
+  public int getNumBins() {
+    return nbins;
+  }
+
+  /**
+   * Returns the number of bins currently being used by the histogram.
+   */
+  public int getUsedBins() {
+    return nusedbins;
+  }
+
+  /**
+   * Set the number of bins currently being used by the histogram.
+   */
+  public void setUsedBins(int nusedBins) {
+    this.nusedbins = nusedBins;
+  }
+
+  /**
+   * Returns true if this histogram object has been initialized by calling merge()
+   * or allocate().
+   */
+  public boolean isReady() {
+    return nbins != 0;
+  }
+
+  /**
+   * Returns a particular histogram bin.
+   */
+  public Coord getBin(int b) {
+    return bins.get(b);
+  }
+
+  /**
+   * Inserts a histogram bin at the given index.
+   */
+  public void addBin(double x, double y, int b) {
+    Coord coord = new Coord();
+    coord.x = x;
+    coord.y = y;
+    bins.add(b, coord);
+  }
+
+  /**
+   * Sets the number of histogram bins to use for approximating data.
+   *
+   * @param num_bins Number of non-uniform-width histogram bins to use
+   */
+  public void allocate(int num_bins) {
+    nbins = num_bins;
+    bins = new ArrayList<Coord>();
+    nusedbins = 0;
+  }
+
+  /**
+   * Takes a histogram and merges it with the current histogram object.
+   */
+  public void merge(NumericHistogram other) {
+    if (other == null) {
+      return;
+    }
+
+    if (nbins == 0 || nusedbins == 0) {
+      // Our aggregation buffer has nothing in it, so just copy over 'other'
+      // by copying its (x,y) bins into fresh Coord objects
+      nbins = other.nbins;
+      nusedbins = other.nusedbins;
+      bins = new ArrayList<Coord>(nusedbins);
+      for (int i = 0; i < other.nusedbins; i += 1) {
+        Coord bin = new Coord();
+        bin.x = other.getBin(i).x;
+        bin.y = other.getBin(i).y;
+        bins.add(bin);
+      }
+    } else {
+      // The aggregation buffer already contains a partial histogram. Therefore, we need
+      // to merge histograms using Algorithm #2 from the Ben-Haim and Tom-Tov paper.
+
+      ArrayList<Coord> tmp_bins = new ArrayList<Coord>(nusedbins + other.nusedbins);
+      // Copy all the histogram bins from us and 'other' into an overstuffed histogram
+      for (int i = 0; i < nusedbins; i++) {
+        Coord bin = new Coord();
+        bin.x = bins.get(i).x;
+        bin.y = bins.get(i).y;
+        tmp_bins.add(bin);
+      }
+      for (int j = 0; j < other.nusedbins; j += 1) {
+        Coord bin = new Coord();
+        bin.x = other.getBin(j).x;
+        bin.y = other.getBin(j).y;
+        tmp_bins.add(bin);
+      }
+      Collections.sort(tmp_bins);
+
+      // Now trim the overstuffed histogram down to the correct number of bins
+      bins = tmp_bins;
+      nusedbins += other.nusedbins;
+      trim();
+    }
+  }
+
+
+  /**
+   * Adds a new data point to the histogram approximation. Make sure you have
+   * called either allocate() or merge() first. This method implements Algorithm #1
+   * from Ben-Haim and Tom-Tov, "A Streaming Parallel Decision Tree Algorithm", JMLR 2010.
+   *
+   * @param v The data point to add to the histogram approximation.
+   */
+  public void add(double v) {
+    // Binary search to find the closest bucket that v should go into.
+    // 'bin' should be interpreted as the bin to shift right in order to accommodate
+    // v. As a result, bin is in the range [0,N], where N means that the value v is
+    // greater than all the N bins currently in the histogram. It is also possible that
+    // a bucket centered at 'v' already exists, so this must be checked in the next step.
+    int bin = 0;
+    for (int l = 0, r = nusedbins; l < r; ) {
+      bin = (l + r) / 2;
+      if (bins.get(bin).x > v) {
+        r = bin;
+      } else {
+        if (bins.get(bin).x < v) {
+          l = ++bin;
+        } else {
+          break; // break loop on equal comparator
+        }
+      }
+    }
+
+    // If we found an exact bin match for value v, then just increment that bin's count.
+    // Otherwise, we need to insert a new bin and trim the resulting histogram back to size.
+    // A possible optimization here might be to set some threshold under which 'v' is just
+    // assumed to be equal to the closest bin -- if Math.abs(v - bins.get(bin).x) < THRESHOLD, then
+    // just increment 'bin'. This is not done now because we don't want to make any
+    // assumptions about the range of numeric data being analyzed.
+    if (bin < nusedbins && bins.get(bin).x == v) {
+      bins.get(bin).y++;
+    } else {
+      Coord newBin = new Coord();
+      newBin.x = v;
+      newBin.y = 1;
+      bins.add(bin, newBin);
+
+      // Trim the bins down to the correct number of bins.
+      if (++nusedbins > nbins) {
+        trim();
+      }
+    }
+
+  }
+
+  /**
+   * Trims a histogram down to 'nbins' bins by iteratively merging the closest bins.
+   * If two pairs of bins are equally close to each other, decide uniformly at random which
+   * pair to merge, based on a PRNG.
+   */
+  private void trim() {
+    while (nusedbins > nbins) {
+      // Find the closest pair of bins in terms of x coordinates. Break ties randomly.
+      double smallestdiff = bins.get(1).x - bins.get(0).x;
+      int smallestdiffloc = 0, smallestdiffcount = 1;
+      for (int i = 1; i < nusedbins - 1; i++) {
+        double diff = bins.get(i + 1).x - bins.get(i).x;
+        if (diff < smallestdiff) {
+          smallestdiff = diff;
+          smallestdiffloc = i;
+          smallestdiffcount = 1;
+        } else {
+          if (diff == smallestdiff && prng.nextDouble() <= (1.0 / ++smallestdiffcount)) {
+            smallestdiffloc = i;
+          }
+        }
+      }
+
+      // Merge the two closest bins into their average x location, weighted by their heights.
+      // The height of the new bin is the sum of the heights of the old bins.
+      // double d = bins[smallestdiffloc].y + bins[smallestdiffloc+1].y;
+      // bins[smallestdiffloc].x *= bins[smallestdiffloc].y / d;
+      // bins[smallestdiffloc].x += bins[smallestdiffloc+1].x / d *
+      //   bins[smallestdiffloc+1].y;
+      // bins[smallestdiffloc].y = d;
+
+      double d = bins.get(smallestdiffloc).y + bins.get(smallestdiffloc + 1).y;
+      Coord smallestdiffbin = bins.get(smallestdiffloc);
+      smallestdiffbin.x *= smallestdiffbin.y / d;
+      smallestdiffbin.x += bins.get(smallestdiffloc + 1).x / d * bins.get(smallestdiffloc + 1).y;
+      smallestdiffbin.y = d;
+      // Shift the remaining bins left one position
+      bins.remove(smallestdiffloc + 1);
+      nusedbins--;
+    }
+  }
+}
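
To make the algorithm concrete, here is a minimal Scala sketch (not part of this commit) that drives the new NumericHistogram API directly, including the weighted-average bin merge performed by trim(); it assumes the class above is on the classpath:

    // Minimal sketch: exercise NumericHistogram's allocate/add/merge/getBin API.
    import org.apache.spark.sql.util.NumericHistogram

    val left = new NumericHistogram()
    left.allocate(2)                        // keep at most 2 bins
    Seq(1.0, 2.0, 10.0).foreach(left.add)
    // The third value triggers trim(): the closest pair (1.0, 2.0) merges into a
    // weighted-average bin x = (1.0*1 + 2.0*1) / 2 = 1.5 with height y = 2,
    // leaving the bins (1.5, 2) and (10.0, 1).

    val right = new NumericHistogram()
    right.allocate(2)
    right.add(11.0)

    // Algorithm #2: concatenate both bin lists, sort by x, then trim back to 2 bins.
    left.merge(right)
    (0 until left.getUsedBins).foreach { i =>
      val bin = left.getBin(i)
      println(s"x=${bin.x}, y=${bin.y}")    // prints (1.5, 2.0) and (10.5, 2.0)
    }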
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f53c829..4d316ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -433,6 +433,7 @@ object FunctionRegistry {
     expression[Skewness]("skewness"),
     expression[ApproximatePercentile]("percentile_approx"),
     expression[ApproximatePercentile]("approx_percentile", true),
+    expression[HistogramNumeric]("histogram_numeric"),
     expression[StddevSamp]("std", true),
     expression[StddevSamp]("stddev", true),
     expression[StddevPop]("stddev_pop"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index c3cc78e..141de75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1508,7 +1508,7 @@ class SessionCatalog(
    */
   def isTemporaryFunction(name: FunctionIdentifier): Boolean = {
     // copied from HiveSessionCatalog
-    val hiveFunctions = Seq("histogram_numeric")
+    val hiveFunctions = Seq()
 
     // A temporary function is a function that has been registered in functionRegistry
     // without a database name, and is neither a built-in function nor a Hive function
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
new file mode 100644
index 0000000..09408e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.nio.ByteBuffer
+
+import com.google.common.primitives.{Doubles, Ints}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DateType, DayTimeIntervalType, DoubleType, IntegerType, NumericType, StructField, StructType, TimestampNTZType, TimestampType, TypeCollection, YearMonthIntervalType}
+import org.apache.spark.sql.util.NumericHistogram
+
+/**
+ * Computes an approximate histogram of a numerical column using a user-specified number of bins.
+ *
+ * The output is an array of (x, y) pairs, as struct objects, representing the histogram's
+ * bin centers and heights.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, nb) - Computes a histogram on numeric 'expr' using nb bins.
+      The return value is an array of (x,y) pairs representing the centers of the
+      histogram's bins. As the value of 'nb' is increased, the histogram approximation
+      gets finer-grained, but may yield artifacts around outliers. In practice, 20-40
+      histogram bins appear to work well, with more bins being required for skewed or
+      smaller datasets. Note that this function creates a histogram with non-uniform
+      bin widths. It offers no guarantees in terms of the mean-squared-error of the
+      histogram, but in practice is comparable to the histograms produced by the R/S-Plus
+      statistical computing packages.
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col);
+       [{"x":0.0,"y":1.0},{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":10.0,"y":1.0}]
+  """,
+  group = "agg_funcs",
+  since = "3.3.0")
+case class HistogramNumeric(
+    child: Expression,
+    nBins: Expression,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int)
+  extends TypedImperativeAggregate[NumericHistogram] with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
+
+  def this(child: Expression, nBins: Expression) = {
+    this(child, nBins, 0, 0)
+  }
+
+  private lazy val nb = nBins.eval() match {
+    case null => null
+    case n: Int => n
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    // Support NumericType, DateType, TimestampType, TimestampNTZType, YearMonthIntervalType
+    // and DayTimeIntervalType, since their internal representations are all numeric
+    // and can be easily cast to double for processing.
+    Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType,
+      YearMonthIntervalType, DayTimeIntervalType), IntegerType)
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!nBins.foldable) {
+      TypeCheckFailure(s"${this.prettyName} requires nBins to be a constant literal.")
+    } else if (nb == null) {
+      TypeCheckFailure(s"${this.prettyName} requires nBins not to be null.")
+    } else if (nb.asInstanceOf[Int] < 2) {
+      TypeCheckFailure(s"${this.prettyName} needs nBins to be at least 2, but you supplied $nb.")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def createAggregationBuffer(): NumericHistogram = {
+    val buffer = new NumericHistogram()
+    buffer.allocate(nb.asInstanceOf[Int])
+    buffer
+  }
+
+  override def update(buffer: NumericHistogram, inputRow: InternalRow): NumericHistogram = {
+    val value = child.eval(inputRow)
+    // Ignore null input values, e.g. histogram_numeric(null)
+    if (value != null) {
+      // Convert the value to a double value
+      val doubleValue = value.asInstanceOf[Number].doubleValue
+      buffer.add(doubleValue)
+    }
+    buffer
+  }
+
+  override def merge(
+      buffer: NumericHistogram,
+      other: NumericHistogram): NumericHistogram = {
+    buffer.merge(other)
+    buffer
+  }
+
+  override def eval(buffer: NumericHistogram): Any = {
+    if (buffer.getUsedBins < 1) {
+      null
+    } else {
+      val result = (0 until buffer.getUsedBins).map { index =>
+        val coord = buffer.getBin(index)
+        InternalRow.apply(coord.x, coord.y)
+      }
+      new GenericArrayData(result)
+    }
+  }
+
+  override def serialize(obj: NumericHistogram): Array[Byte] = {
+    NumericHistogramSerializer.serialize(obj)
+  }
+
+  override def deserialize(bytes: Array[Byte]): NumericHistogram = {
+    NumericHistogramSerializer.deserialize(bytes)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = nBins
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): HistogramNumeric = {
+    copy(child = newLeft, nBins = newRight)
+  }
+
+  override def withNewMutableAggBufferOffset(newOffset: Int): HistogramNumeric =
+    copy(mutableAggBufferOffset = newOffset)
+
+  override def withNewInputAggBufferOffset(newOffset: Int): HistogramNumeric =
+    copy(inputAggBufferOffset = newOffset)
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType =
+    ArrayType(new StructType(Array(
+      StructField("x", DoubleType, true),
+      StructField("y", DoubleType, true))), true)
+
+  override def prettyName: String = "histogram_numeric"
+}
+
+object NumericHistogramSerializer {
+    private final def length(histogram: NumericHistogram): Int = {
+      // histogram.nBins, histogram.nUsedBins
+      Ints.BYTES + Ints.BYTES +
+        //  histogram.bins, Array[Coord(x: Double, y: Double)]
+        histogram.getUsedBins * (Doubles.BYTES + Doubles.BYTES)
+    }
+
+    def serialize(histogram: NumericHistogram): Array[Byte] = {
+      val buffer = ByteBuffer.wrap(new Array(length(histogram)))
+      buffer.putInt(histogram.getNumBins)
+      buffer.putInt(histogram.getUsedBins)
+
+      var i = 0
+      while (i < histogram.getUsedBins) {
+        val coord = histogram.getBin(i)
+        buffer.putDouble(coord.x)
+        buffer.putDouble(coord.y)
+        i += 1
+      }
+      buffer.array()
+    }
+
+    def deserialize(bytes: Array[Byte]): NumericHistogram = {
+      val buffer = ByteBuffer.wrap(bytes)
+      val nBins = buffer.getInt()
+      val nUsedBins = buffer.getInt()
+      val histogram = new NumericHistogram()
+      histogram.allocate(nBins)
+      histogram.setUsedBins(nUsedBins)
+      var i: Int = 0
+      while (i < nUsedBins) {
+        val x = buffer.getDouble()
+        val y = buffer.getDouble()
+        histogram.addBin(x, y, i)
+        i += 1
+      }
+      histogram
+    }
+}
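
The serializer writes a fixed-width layout: two Ints (nBins, nUsedBins) followed by one (x: Double, y: Double) pair per used bin. A short round-trip sketch, again not part of this commit and assuming both classes are on the classpath:

    // Sketch: serialize and restore a histogram via NumericHistogramSerializer.
    import org.apache.spark.sql.catalyst.expressions.aggregate.NumericHistogramSerializer
    import org.apache.spark.sql.util.NumericHistogram

    val h = new NumericHistogram()
    h.allocate(5)
    Seq(1.0, 5.0, 9.0).foreach(h.add)       // 3 used bins out of the 5 allocated

    val bytes = NumericHistogramSerializer.serialize(h)
    // 4 bytes (nBins) + 4 bytes (nUsedBins) + 3 bins * 16 bytes per (x, y) = 56
    assert(bytes.length == 4 + 4 + h.getUsedBins * 16)

    val restored = NumericHistogramSerializer.deserialize(bytes)
    assert(restored.getNumBins == h.getNumBins)
    assert(restored.getUsedBins == h.getUsedBins)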
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala
new file mode 100644
index 0000000..60b53c6
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions.{DslString, DslSymbol}
+import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types.{DoubleType, IntegerType}
+import org.apache.spark.sql.util.NumericHistogram
+
+class HistogramNumericSuite extends SparkFunSuite {
+
+  private val random = new java.util.Random()
+
+  private val data = (0 until 10000).map { _ =>
+    random.nextInt(10000)
+  }
+
+  test("serialize and de-serialize") {
+
+    // Check empty serialize and de-serialize
+    val emptyBuffer = new NumericHistogram()
+    emptyBuffer.allocate(5)
+    assert(compareEquals(emptyBuffer,
+      NumericHistogramSerializer.deserialize(NumericHistogramSerializer.serialize(emptyBuffer))))
+
+    val buffer = new NumericHistogram()
+    buffer.allocate(data.size / 3)
+    data.foreach { value =>
+      buffer.add(value)
+    }
+    assert(compareEquals(buffer,
+      NumericHistogramSerializer.deserialize(NumericHistogramSerializer.serialize(buffer))))
+
+    val agg = new HistogramNumeric(BoundReference(0, DoubleType, true), Literal(5))
+    assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+  }
+
+  test("class NumericHistogram, basic operations") {
+    val valueCount = 5
+    Seq(3, 5).foreach { nBins: Int =>
+      val buffer = new NumericHistogram()
+      buffer.allocate(nBins)
+      (1 to valueCount).grouped(nBins).foreach { group =>
+        val partialBuffer = new NumericHistogram()
+        partialBuffer.allocate(nBins)
+        group.foreach(x => partialBuffer.add(x))
+        buffer.merge(partialBuffer)
+      }
+      val sum = (0 until buffer.getUsedBins).map { i =>
+        val coord = buffer.getBin(i)
+        coord.x * coord.y
+      }.sum
+      assert(sum <= (1 to valueCount).sum)
+    }
+  }
+
+  test("class HistogramNumeric, sql string") {
+    assertEqual("histogram_numeric(a, 3)",
+      new HistogramNumeric("a".attr, Literal(3)).sql: String)
+
+    // sql(isDistinct = true)
+    assertEqual("histogram_numeric(DISTINCT a, 3)",
+      new HistogramNumeric("a".attr, Literal(3)).sql(isDistinct = true))
+  }
+
+  test("class HistogramNumeric, fails analysis if nBins is not a constant") {
+    val attribute = AttributeReference("a", IntegerType)()
+    val wrongNB = new HistogramNumeric(attribute, nBins = AttributeReference("b", IntegerType)())
+
+    assertEqual(
+      wrongNB.checkInputDataTypes(),
+      TypeCheckFailure("histogram_numeric needs the nBins provided must be a constant literal.")
+    )
+  }
+
+  test("class HistogramNumeric, fails analysis if nBins is invalid") {
+    val attribute = AttributeReference("a", IntegerType)()
+    val wrongNB = new HistogramNumeric(attribute, nBins = Literal(1))
+
+    assertEqual(
+      wrongNB.checkInputDataTypes(),
+      TypeCheckFailure("histogram_numeric needs nBins to be at least 2, but you supplied 1.")
+    )
+  }
+
+  test("class HistogramNumeric, automatically add type casting for parameters") {
+    val testRelation = LocalRelation('a.int)
+
+    // nBins accepts any integral type; non-int values are implicitly cast to int
+    val nBinsExpressions = Seq(
+      Literal(2.toByte),
+      Literal(100.toShort),
+      Literal(100),
+      Literal(1000L))
+
+    nBinsExpressions.foreach { nBins =>
+      val agg = new HistogramNumeric(UnresolvedAttribute("a"), nBins)
+      val analyzed = testRelation.select(agg).analyze.expressions.head
+      analyzed match {
+        case Alias(agg: HistogramNumeric, _) =>
+          assert(agg.resolved)
+          assert(agg.child.dataType == IntegerType)
+          assert(agg.nBins.dataType == IntegerType)
+        case _ => fail()
+      }
+    }
+  }
+
+  test("HistogramNumeric: nulls in nBins expression") {
+    assert(new HistogramNumeric(
+      AttributeReference("a", DoubleType)(),
+      Literal(null, IntegerType)).checkInputDataTypes() ===
+      TypeCheckFailure("histogram_numeric needs nBins value must not be null."))
+  }
+
+  test("class HistogramNumeric, null handling") {
+    val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+    val agg = new HistogramNumeric(childExpression, Literal(5))
+    val buffer = new GenericInternalRow(new Array[Any](1))
+    agg.initialize(buffer)
+    // Empty aggregation buffer
+    assert(agg.eval(buffer) == null)
+    // A row with a null value is ignored
+    agg.update(buffer, InternalRow(null))
+    assert(agg.eval(buffer) == null)
+
+    // Add a row with a non-null value
+    agg.update(buffer, InternalRow(0))
+    assert(agg.eval(buffer) != null)
+  }
+
+  private def compareEquals(left: NumericHistogram, right: NumericHistogram): Boolean = {
+    left.getNumBins == right.getNumBins && left.getUsedBins == right.getUsedBins &&
+      (0 until left.getUsedBins).forall { i =>
+        val leftCoord = left.getBin(i)
+        val rightCoord = right.getBin(i)
+        leftCoord.x == rightCoord.x && leftCoord.y == rightCoord.y
+      }
+  }
+
+  private def assertEqual[T](left: T, right: T): Unit = {
+    assert(left == right)
+  }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 958b961..9192ac4 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
 <!-- Automatically generated by ExpressionsSchemaSuite -->
 ## Summary
-  - Number of queries: 366
+  - Number of queries: 367
   - Number of expressions that missing example: 12
   - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
 ## Schema of Built-in Functions
@@ -345,6 +345,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct<covar_samp(c1, c2):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first(col):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first_value(col):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<histogram_numeric(col, 5):array<struct<x:double,y:double>>> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct<approx_count_distinct(col1):bigint> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct<kurtosis(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last(col):int> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 039373b..4e6d2d2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -192,3 +192,15 @@ SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
 FROM testData
 GROUP BY a IS NULL;
 
+
+SELECT
+  histogram_numeric(col, 2) as histogram_2,
+  histogram_numeric(col, 3) as histogram_3,
+  histogram_numeric(col, 5) as histogram_5,
+  histogram_numeric(col, 10) as histogram_10
+FROM VALUES
+ (1), (2), (3), (4), (5), (6), (7), (8), (9), (10),
+ (11), (12), (13), (14), (15), (16), (17), (18), (19), (20),
+ (21), (22), (23), (24), (25), (26), (27), (28), (29), (30),
+ (31), (32), (33), (34), (35), (3), (37), (38), (39), (40),
+ (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col);
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index f598f49..5cd5a37 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 65
+-- Number of queries: 66
 
 
 -- !query
@@ -673,3 +673,21 @@ struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
 -- !query output
 0.7604953758285915	7
 1.0	2
+
+
+-- !query
+SELECT
+  histogram_numeric(col, 2) as histogram_2,
+  histogram_numeric(col, 3) as histogram_3,
+  histogram_numeric(col, 5) as histogram_5,
+  histogram_numeric(col, 10) as histogram_10
+FROM VALUES
+ (1), (2), (3), (4), (5), (6), (7), (8), (9), (10),
+ (11), (12), (13), (14), (15), (16), (17), (18), (19), (20),
+ (21), (22), (23), (24), (25), (26), (27), (28), (29), (30),
+ (31), (32), (33), (34), (35), (3), (37), (38), (39), (40),
+ (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col)
+-- !query schema
+struct<histogram_2:array<struct<x:double,y:double>>,histogram_3:array<struct<x:double,y:double>>,histogram_5:array<struct<x:double,y:double>>,histogram_10:array<struct<x:double,y:double>>>
+-- !query output
+[{"x":12.615384615384613,"y":26.0},{"x":38.083333333333336,"y":24.0}]	[{"x":9.649999999999999,"y":20.0},{"x":25.0,"y":11.0},{"x":40.736842105263165,"y":19.0}]	[{"x":5.272727272727273,"y":11.0},{"x":14.5,"y":8.0},{"x":22.0,"y":7.0},{"x":30.499999999999996,"y":10.0},{"x":43.5,"y":14.0}]	[{"x":3.0,"y":6.0},{"x":8.5,"y":6.0},{"x":13.5,"y":4.0},{"x":17.0,"y":3.0},{"x":20.5,"y":4.0},{"x":25.5,"y":6.0},{"x":31.999999999999996,"y":7.0},{"x":39.0,"y":5.0},{"x":43.5,"y":4.0},{"x":48.0,"y":5.0}]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 488890a..b11774b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -128,7 +128,5 @@ private[sql] class HiveSessionCatalog(
   // in_file, index, matchpath, ngrams, noop, noopstreaming, noopwithmap,
   // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction.
   // Note: don't forget to update SessionCatalog.isTemporaryFunction
-  private val hiveFunctions = Seq(
-    "histogram_numeric"
-  )
+  private val hiveFunctions = Seq()
 }
