You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/10/27 11:48:05 UTC
[spark] branch master updated: [SPARK-16280][SPARK-37082][SQL]
Implements histogram_numeric aggregation function which supports partial
aggregation
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new de0d7fb [SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation
de0d7fb is described below
commit de0d7fbb4f010bec8e457d0dc00b5618e7a43750
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Wed Oct 27 19:47:17 2021 +0800
[SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation
### What changes were proposed in this pull request?
This PR implements the aggregation function `histogram_numeric`, which returns an approximate histogram of a numerical column using a user-specified number of bins — for example, the histogram of column `col` split into 3 bins.
Syntax:
#### Returns an approximate histogram of a numerical column using a user-specified number of bins.
histogram_numeric(col, nBins)
#### Returns an approximate histogram of column `col` with 3 bins.
SELECT histogram_numeric(col, 3) FROM table
#### Returns an approximate histogram of column `col` with 5 bins.
SELECT histogram_numeric(col, 5) FROM table
### Why are the changes needed?
### Does this PR introduce _any_ user-facing change?
No change from user side
### How was this patch tested?
Added UT
Closes #34380 from AngersZhuuuu/SPARK-37082.
Authored-by: Angerszhuuuu <an...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../apache/spark/sql/util/NumericHistogram.java | 286 +++++++++++++++++++++
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../sql/catalyst/catalog/SessionCatalog.scala | 2 +-
.../expressions/aggregate/HistogramNumeric.scala | 207 +++++++++++++++
.../aggregate/HistogramNumericSuite.scala | 166 ++++++++++++
.../sql-functions/sql-expression-schema.md | 3 +-
.../test/resources/sql-tests/inputs/group-by.sql | 12 +
.../resources/sql-tests/results/group-by.sql.out | 20 +-
.../apache/spark/sql/hive/HiveSessionCatalog.scala | 4 +-
9 files changed, 695 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java
new file mode 100644
index 0000000..987c18e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Random;
+
+
+/**
+ * A generic, re-usable histogram class that supports partial aggregations.
+ * The algorithm is a heuristic adapted from the following paper:
+ * Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm",
+ * J. Machine Learning Research 11 (2010), pp. 849--872. Although there are no approximation
+ * guarantees, it appears to work well with adequate data and a large (e.g., 20-80) number
+ * of histogram bins.
+ *
+ * Each bin is an (x, y) pair: x is the bin center and y its height, i.e. the number of
+ * values accumulated into it. Bins are kept sorted by x, and add() and merge() trim the
+ * histogram back down to nbins bins whenever it grows beyond that (see trim()).
+ * This class performs no synchronization.
+ *
+ * Adapted from Hive's NumericHistogram. See
+ * https://github.com/apache/hive/blob/master/ql/src/
+ * java/org/apache/hadoop/hive/ql/udf/generic/NumericHistogram.java
+ *
+ * Differences:
+ * 1. {@link Coord} and its fields are declared public for
+ * easy access in the HistogramNumeric class.
+ * 2. Added method {@link #getNumBins()}, used to serialize a NumericHistogram
+ * in NumericHistogramSerializer.
+ * 3. Added method {@link #addBin(double, double, int)}, used to deserialize a
+ * NumericHistogram in NumericHistogramSerializer.
+ * 4. In Hive's code, merge() takes a serialized histogram; in Spark it takes a
+ * deserialized NumericHistogram, so the code that merges bins was changed accordingly.
+ */
+public class NumericHistogram {
+ /**
+ * The Coord class defines a histogram bin, which is just an (x,y) pair.
+ * Instances are ordered by x only.
+ *
+ * NOTE(review): this implements the raw Comparable type. Comparable&lt;Coord&gt; would be
+ * type-safe, but would replace the public compareTo(Object) signature.
+ */
+ public static class Coord implements Comparable {
+ public double x;
+ public double y;
+
+ public int compareTo(Object other) {
+ return Double.compare(x, ((Coord) other).x);
+ }
+ }
+
+ // Maximum number of bins to keep; 0 means "not initialized" (see isReady()).
+ private int nbins;
+ // Number of bins currently held in 'bins'.
+ private int nusedbins;
+ // The bins, sorted by x; null until allocate() or merge() has been called.
+ private ArrayList<Coord> bins;
+ // Seeded RNG, used only to break ties between equally close bin pairs in trim().
+ private Random prng;
+
+ /**
+ * Creates a new histogram object. Note that the allocate() or merge()
+ * method must be called before the histogram can be used.
+ */
+ public NumericHistogram() {
+ nbins = 0;
+ nusedbins = 0;
+ bins = null;
+
+ // init the RNG for breaking ties in histogram merging. A fixed seed is specified here
+ // to aid testing, but can be eliminated to use a time-based seed (which would
+ // make the algorithm non-deterministic).
+ prng = new Random(31183);
+ }
+
+ /**
+ * Resets a histogram object to its initial state. allocate() or merge() must be
+ * called again before use. The tie-breaking RNG is not re-seeded.
+ */
+ public void reset() {
+ bins = null;
+ nbins = nusedbins = 0;
+ }
+
+ /**
+ * Returns the maximum number of bins, as set by allocate() or copied by merge().
+ */
+ public int getNumBins() {
+ return nbins;
+ }
+
+ /**
+ * Returns the number of bins currently being used by the histogram.
+ */
+ public int getUsedBins() {
+ return nusedbins;
+ }
+
+ /**
+ * Sets the number of bins currently being used by the histogram.
+ * No validation is done against the bins actually stored: the caller (e.g. deserialization)
+ * must keep this consistent with the bins inserted through addBin().
+ */
+ public void setUsedBins(int nusedBins) {
+ this.nusedbins = nusedBins;
+ }
+
+ /**
+ * Returns true if this histogram object has been initialized by calling merge()
+ * or allocate().
+ */
+ public boolean isReady() {
+ return nbins != 0;
+ }
+
+ /**
+ * Returns the bin at index b (expected to be in [0, getUsedBins())).
+ * The returned Coord is the object stored in this histogram, not a copy, so modifying it
+ * modifies the histogram.
+ */
+ public Coord getBin(int b) {
+ return bins.get(b);
+ }
+
+ /**
+ * Inserts a new bin (x, y) at index b, shifting any bins at or after b one position right.
+ * Neither the used-bin count nor the sort order is updated here; callers must do that
+ * themselves (see setUsedBins()). allocate() or merge() must have been called first.
+ */
+ public void addBin(double x, double y, int b) {
+ Coord coord = new Coord();
+ coord.x = x;
+ coord.y = y;
+ bins.add(b, coord);
+ }
+
+ /**
+ * Sets the number of histogram bins to use for approximating data.
+ * Any bins already present are discarded.
+ *
+ * @param num_bins Number of non-uniform-width histogram bins to use
+ */
+ public void allocate(int num_bins) {
+ nbins = num_bins;
+ bins = new ArrayList<Coord>();
+ nusedbins = 0;
+ }
+
+ /**
+ * Takes a histogram and merges it with the current histogram object.
+ * A null 'other' is ignored. If this histogram is uninitialized or empty, it becomes a deep
+ * copy of 'other' (including other's bin limit). Otherwise the bins of both histograms are
+ * combined, sorted by x, and trimmed back down to this histogram's bin limit.
+ * 'other' itself is never modified.
+ */
+ public void merge(NumericHistogram other) {
+ if (other == null) {
+ return;
+ }
+
+ if (nbins == 0 || nusedbins == 0) {
+ // Our aggregation buffer has nothing in it, so just copy over 'other' bin by bin.
+ // Fresh Coord objects are created so later updates here never alias 'other'.
+ nbins = other.nbins;
+ nusedbins = other.nusedbins;
+ bins = new ArrayList<Coord>(nusedbins);
+ for (int i = 0; i < other.nusedbins; i += 1) {
+ Coord bin = new Coord();
+ bin.x = other.getBin(i).x;
+ bin.y = other.getBin(i).y;
+ bins.add(bin);
+ }
+ } else {
+ // The aggregation buffer already contains a partial histogram. Therefore, we need
+ // to merge histograms using Algorithm #2 from the Ben-Haim and Tom-Tov paper.
+
+ ArrayList<Coord> tmp_bins = new ArrayList<Coord>(nusedbins + other.nusedbins);
+ // Copy all the histogram bins from us and 'other' into an overstuffed histogram
+ for (int i = 0; i < nusedbins; i++) {
+ Coord bin = new Coord();
+ bin.x = bins.get(i).x;
+ bin.y = bins.get(i).y;
+ tmp_bins.add(bin);
+ }
+ for (int j = 0; j < other.nusedbins; j += 1) {
+ Coord bin = new Coord();
+ bin.x = other.getBin(j).x;
+ bin.y = other.getBin(j).y;
+ tmp_bins.add(bin);
+ }
+ Collections.sort(tmp_bins);
+
+ // Now trim the overstuffed histogram down to the correct number of bins
+ bins = tmp_bins;
+ nusedbins += other.nusedbins;
+ trim();
+ }
+ }
+
+
+ /**
+ * Adds a new data point to the histogram approximation. Make sure you have
+ * called either allocate() or merge() first. This method implements Algorithm #1
+ * from Ben-Haim and Tom-Tov, "A Streaming Parallel Decision Tree Algorithm", JMLR 2010.
+ *
+ * @param v The data point to add to the histogram approximation.
+ */
+ public void add(double v) {
+ // Binary search to find the closest bucket that v should go into.
+ // 'bin' should be interpreted as the bin to shift right in order to accommodate
+ // v. As a result, bin is in the range [0,N], where N means that the value v is
+ // greater than all the N bins currently in the histogram. It is also possible that
+ // a bucket centered at 'v' already exists, so this must be checked in the next step.
+ int bin = 0;
+ for (int l = 0, r = nusedbins; l < r; ) {
+ bin = (l + r) / 2;
+ if (bins.get(bin).x > v) {
+ r = bin;
+ } else {
+ if (bins.get(bin).x < v) {
+ l = ++bin;
+ } else {
+ break; // break loop on equal comparator
+ }
+ }
+ }
+
+ // If we found an exact bin match for value v, then just increment that bin's count.
+ // Otherwise, we need to insert a new bin and trim the resulting histogram back to size.
+ // A possible optimization here might be to set some threshold under which 'v' is just
+ // assumed to be equal to the closest bin -- if fabs(v-bins[bin].x) < THRESHOLD, then
+ // just increment 'bin'. This is not done now because we don't want to make any
+ // assumptions about the range of numeric data being analyzed.
+ if (bin < nusedbins && bins.get(bin).x == v) {
+ bins.get(bin).y++;
+ } else {
+ Coord newBin = new Coord();
+ newBin.x = v;
+ newBin.y = 1;
+ bins.add(bin, newBin);
+
+ // Trim the bins down to the correct number of bins.
+ if (++nusedbins > nbins) {
+ trim();
+ }
+ }
+
+ }
+
+ /**
+ * Trims a histogram down to 'nbins' bins by iteratively merging the closest bins.
+ * If two pairs of bins are equally close to each other, decide uniformly at random which
+ * pair to merge, based on a PRNG.
+ * Assumes nbins >= 1: with nbins == 0 and a single bin, bins.get(1) below is out of range.
+ */
+ private void trim() {
+ while (nusedbins > nbins) {
+ // Find the closest pair of bins in terms of x coordinates. Break ties randomly:
+ // the k-th pair tied at the current minimum replaces the chosen pair with
+ // probability ~1/k, so each tied pair ends up chosen with roughly equal probability.
+ double smallestdiff = bins.get(1).x - bins.get(0).x;
+ int smallestdiffloc = 0, smallestdiffcount = 1;
+ for (int i = 1; i < nusedbins - 1; i++) {
+ double diff = bins.get(i + 1).x - bins.get(i).x;
+ if (diff < smallestdiff) {
+ smallestdiff = diff;
+ smallestdiffloc = i;
+ smallestdiffcount = 1;
+ } else {
+ if (diff == smallestdiff && prng.nextDouble() <= (1.0 / ++smallestdiffcount)) {
+ smallestdiffloc = i;
+ }
+ }
+ }
+
+ // Merge the two closest bins into their average x location, weighted by their heights.
+ // The height of the new bin is the sum of the heights of the old bins.
+ // The new x is computed in two floating-point steps (x1 * y1 / d, then + x2 / d * y2);
+ // changing this operation order can change the low-order digits of the results.
+
+ double d = bins.get(smallestdiffloc).y + bins.get(smallestdiffloc + 1).y;
+ Coord smallestdiffbin = bins.get(smallestdiffloc);
+ smallestdiffbin.x *= smallestdiffbin.y / d;
+ smallestdiffbin.x += bins.get(smallestdiffloc + 1).x / d * bins.get(smallestdiffloc + 1).y;
+ smallestdiffbin.y = d;
+ // Shift the remaining bins left one position
+ bins.remove(smallestdiffloc + 1);
+ nusedbins--;
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f53c829..4d316ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -433,6 +433,7 @@ object FunctionRegistry {
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
expression[ApproximatePercentile]("approx_percentile", true),
+ expression[HistogramNumeric]("histogram_numeric"),
expression[StddevSamp]("std", true),
expression[StddevSamp]("stddev", true),
expression[StddevPop]("stddev_pop"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index c3cc78e..141de75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1508,7 +1508,7 @@ class SessionCatalog(
*/
def isTemporaryFunction(name: FunctionIdentifier): Boolean = {
// copied from HiveSessionCatalog
- val hiveFunctions = Seq("histogram_numeric")
+ val hiveFunctions = Seq()
// A temporary function is a function that has been registered in functionRegistry
// without a database name, and is neither a built-in function nor a Hive function
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
new file mode 100644
index 0000000..09408e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.nio.ByteBuffer
+
+import com.google.common.primitives.{Doubles, Ints}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DateType, DayTimeIntervalType, DoubleType, IntegerType, NumericType, StructField, StructType, TimestampNTZType, TimestampType, TypeCollection, YearMonthIntervalType}
+import org.apache.spark.sql.util.NumericHistogram
+
+/**
+ * Computes an approximate histogram of a numerical column using a user-specified number of bins.
+ *
+ * The output is an array of (x,y) pairs as struct objects that represents the histogram's
+ * bin centers and heights.
+ *
+ * The input may be of any NumericType (including DecimalType), DateType, TimestampType,
+ * TimestampNTZType, YearMonthIntervalType or DayTimeIntervalType; every non-null value is
+ * converted to a double before it is added to the histogram. `nBins` must be a foldable,
+ * non-null integer of at least 2.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr, nb) - Computes a histogram on numeric 'expr' using nb bins.
+ The return value is an array of (x,y) pairs representing the centers of the
+ histogram's bins. As the value of 'nb' is increased, the histogram approximation
+ gets finer-grained, but may yield artifacts around outliers. In practice, 20-40
+ histogram bins appear to work well, with more bins being required for skewed or
+ smaller datasets. Note that this function creates a histogram with non-uniform
+ bin widths. It offers no guarantees in terms of the mean-squared-error of the
+ histogram, but in practice is comparable to the histograms produced by the R/S-Plus
+ statistical computing packages.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col);
+ [{"x":0.0,"y":1.0},{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":10.0,"y":1.0}]
+ """,
+ group = "agg_funcs",
+ since = "3.3.0")
+case class HistogramNumeric(
+ child: Expression,
+ nBins: Expression,
+ override val mutableAggBufferOffset: Int,
+ override val inputAggBufferOffset: Int)
+ extends TypedImperativeAggregate[NumericHistogram] with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
+ def this(child: Expression, nBins: Expression) = {
+ this(child, nBins, 0, 0)
+ }
+
+ // The evaluated nBins value (null if nBins evaluates to null). Lazy, so it is first
+ // evaluated in checkInputDataTypes() only after nBins has been checked to be foldable.
+ private lazy val nb = nBins.eval() match {
+ case null => null
+ case n: Int => n
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ // Support NumericType, DateType, TimestampType and TimestampNTZType, YearMonthIntervalType,
+ // DayTimeIntervalType since their internal types are all numeric,
+ // and can be easily cast to double for processing.
+ Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType,
+ YearMonthIntervalType, DayTimeIntervalType), IntegerType)
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!nBins.foldable) {
+ TypeCheckFailure(s"${this.prettyName} needs the nBins provided must be a constant literal.")
+ } else if (nb == null) {
+ TypeCheckFailure(s"${this.prettyName} needs nBins value must not be null.")
+ } else if (nb.asInstanceOf[Int] < 2) {
+ TypeCheckFailure(s"${this.prettyName} needs nBins to be at least 2, but you supplied $nb.")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def createAggregationBuffer(): NumericHistogram = {
+ val buffer = new NumericHistogram()
+ buffer.allocate(nb.asInstanceOf[Int])
+ buffer
+ }
+
+ override def update(buffer: NumericHistogram, inputRow: InternalRow): NumericHistogram = {
+ val value = child.eval(inputRow)
+ // Ignore empty rows, for example: histogram_numeric(null)
+ if (value != null) {
+ // Convert the value to a double value. DecimalType values are evaluated to Spark's
+ // Decimal, which is not a java.lang.Number, so they need their own conversion; the
+ // other accepted types have numeric internal representations (see inputTypes).
+ val doubleValue = value match {
+ case d: org.apache.spark.sql.types.Decimal => d.toDouble
+ case n: Number => n.doubleValue
+ }
+ buffer.add(doubleValue)
+ }
+ buffer
+ }
+
+ // Merges the partial histogram `other` into `buffer` in place; `other` is not modified.
+ override def merge(
+ buffer: NumericHistogram,
+ other: NumericHistogram): NumericHistogram = {
+ buffer.merge(other)
+ buffer
+ }
+
+ // Returns null when no non-null input was aggregated; otherwise one (x, y) struct per used
+ // bin, in the buffer's bin order.
+ override def eval(buffer: NumericHistogram): Any = {
+ if (buffer.getUsedBins < 1) {
+ null
+ } else {
+ val result = (0 until buffer.getUsedBins).map { index =>
+ val coord = buffer.getBin(index)
+ InternalRow.apply(coord.x, coord.y)
+ }
+ new GenericArrayData(result)
+ }
+ }
+
+ override def serialize(obj: NumericHistogram): Array[Byte] = {
+ NumericHistogramSerializer.serialize(obj)
+ }
+
+ override def deserialize(bytes: Array[Byte]): NumericHistogram = {
+ NumericHistogramSerializer.deserialize(bytes)
+ }
+
+ override def left: Expression = child
+
+ override def right: Expression = nBins
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): HistogramNumeric = {
+ copy(child = newLeft, nBins = newRight)
+ }
+
+ override def withNewMutableAggBufferOffset(newOffset: Int): HistogramNumeric =
+ copy(mutableAggBufferOffset = newOffset)
+
+ override def withNewInputAggBufferOffset(newOffset: Int): HistogramNumeric =
+ copy(inputAggBufferOffset = newOffset)
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType =
+ ArrayType(new StructType(Array(
+ StructField("x", DoubleType, true),
+ StructField("y", DoubleType, true))), true)
+
+ override def prettyName: String = "histogram_numeric"
+}
+
+object NumericHistogramSerializer {
+ /**
+ * Size in bytes of the encoded form of `histogram`: a header of two ints
+ * (nBins, nUsedBins) followed by one (x, y) pair of doubles per used bin.
+ */
+ private final def length(histogram: NumericHistogram): Int =
+ 2 * Ints.BYTES + histogram.getUsedBins * 2 * Doubles.BYTES
+
+ /**
+ * Encodes `histogram` as [nBins: Int][nUsedBins: Int] followed by (x: Double, y: Double)
+ * for every used bin, in bin order, using ByteBuffer's default byte order.
+ */
+ def serialize(histogram: NumericHistogram): Array[Byte] = {
+ val usedBins = histogram.getUsedBins
+ val out = ByteBuffer.wrap(new Array[Byte](length(histogram)))
+ out.putInt(histogram.getNumBins).putInt(usedBins)
+ (0 until usedBins).foreach { idx =>
+ val bin = histogram.getBin(idx)
+ out.putDouble(bin.x).putDouble(bin.y)
+ }
+ out.array()
+ }
+
+ /** Decodes bytes produced by [[serialize]] back into a new NumericHistogram. */
+ def deserialize(bytes: Array[Byte]): NumericHistogram = {
+ val in = ByteBuffer.wrap(bytes)
+ val numBins = in.getInt()
+ val usedBins = in.getInt()
+ val histogram = new NumericHistogram()
+ histogram.allocate(numBins)
+ histogram.setUsedBins(usedBins)
+ for (idx <- 0 until usedBins) {
+ val x = in.getDouble()
+ val y = in.getDouble()
+ histogram.addBin(x, y, idx)
+ }
+ histogram
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala
new file mode 100644
index 0000000..60b53c6
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions.{DslString, DslSymbol}
+import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types.{DoubleType, IntegerType}
+import org.apache.spark.sql.util.NumericHistogram
+
+class HistogramNumericSuite extends SparkFunSuite {
+
+ // Fixed seed so that the generated test data, and therefore every test run, is reproducible.
+ private val random = new java.util.Random(1234)
+
+ private val data = (0 until 10000).map { _ =>
+ random.nextInt(10000)
+ }
+
+ test("serialize and de-serialize") {
+
+ // Check empty serialize and de-serialize
+ val emptyBuffer = new NumericHistogram()
+ emptyBuffer.allocate(5)
+ assert(compareEquals(emptyBuffer,
+ NumericHistogramSerializer.deserialize(NumericHistogramSerializer.serialize(emptyBuffer))))
+
+ // Check a buffer holding many bins.
+ val buffer = new NumericHistogram()
+ buffer.allocate(data.size / 3)
+ data.foreach { value =>
+ buffer.add(value)
+ }
+ assert(compareEquals(buffer,
+ NumericHistogramSerializer.deserialize(NumericHistogramSerializer.serialize(buffer))))
+
+ // The aggregate function must round-trip through the same serializer.
+ val agg = new HistogramNumeric(BoundReference(0, DoubleType, true), Literal(5))
+ assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+ }
+
+ test("class NumericHistogram, basic operations") {
+ val valueCount = 5
+ Seq(3, 5).foreach { nBins: Int =>
+ val buffer = new NumericHistogram()
+ buffer.allocate(nBins)
+ // Build partial histograms and merge them, as partial aggregation would.
+ (1 to valueCount).grouped(nBins).foreach { group =>
+ val partialBuffer = new NumericHistogram()
+ partialBuffer.allocate(nBins)
+ group.foreach(x => partialBuffer.add(x))
+ buffer.merge(partialBuffer)
+ }
+ assert(buffer.getUsedBins <= nBins)
+ val bins = (0 until buffer.getUsedBins).map(i => buffer.getBin(i))
+ // Merging bins preserves the total height (number of values) and the height-weighted
+ // sum of the bin centers (sum of the values), up to floating-point rounding.
+ assert(bins.map(_.y).sum == valueCount)
+ assert(math.abs(bins.map(c => c.x * c.y).sum - (1 to valueCount).sum) < 1e-9)
+ }
+ }
+
+ test("class HistogramNumeric, sql string") {
+ assertEqual("histogram_numeric(a, 3)",
+ new HistogramNumeric("a".attr, Literal(3)).sql: String)
+
+ // sql(isDistinct = true)
+ assertEqual("histogram_numeric(DISTINCT a, 3)",
+ new HistogramNumeric("a".attr, Literal(3)).sql(isDistinct = true))
+ }
+
+ test("class HistogramNumeric, fails analysis if nBins is not a constant") {
+ val attribute = AttributeReference("a", IntegerType)()
+ val wrongNB = new HistogramNumeric(attribute, nBins = AttributeReference("b", IntegerType)())
+
+ assertEqual(
+ wrongNB.checkInputDataTypes(),
+ TypeCheckFailure("histogram_numeric needs the nBins provided must be a constant literal.")
+ )
+ }
+
+ test("class HistogramNumeric, fails analysis if nBins is invalid") {
+ val attribute = AttributeReference("a", IntegerType)()
+ val wrongNB = new HistogramNumeric(attribute, nBins = Literal(1))
+
+ assertEqual(
+ wrongNB.checkInputDataTypes(),
+ TypeCheckFailure("histogram_numeric needs nBins to be at least 2, but you supplied 1.")
+ )
+ }
+
+ test("class HistogramNumeric, automatically add type casting for parameters") {
+ val testRelation = LocalRelation('a.int)
+
+ // nBins of any integral type is implicitly cast to IntegerType by the analyzer.
+ val nBinsExpressions = Seq(
+ Literal(2.toByte),
+ Literal(100.toShort),
+ Literal(100),
+ Literal(1000L))
+
+ nBinsExpressions.foreach { nBins =>
+ val agg = new HistogramNumeric(UnresolvedAttribute("a"), nBins)
+ val analyzed = testRelation.select(agg).analyze.expressions.head
+ analyzed match {
+ case Alias(agg: HistogramNumeric, _) =>
+ assert(agg.resolved)
+ assert(agg.child.dataType == IntegerType)
+ assert(agg.nBins.dataType == IntegerType)
+ case _ => fail()
+ }
+ }
+ }
+
+ test("HistogramNumeric: nulls in nBins expression") {
+ assert(new HistogramNumeric(
+ AttributeReference("a", DoubleType)(),
+ Literal(null, IntegerType)).checkInputDataTypes() ===
+ TypeCheckFailure("histogram_numeric needs nBins value must not be null."))
+ }
+
+ test("class HistogramNumeric, null handling") {
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+ val agg = new HistogramNumeric(childExpression, Literal(5))
+ val buffer = new GenericInternalRow(new Array[Any](1))
+ agg.initialize(buffer)
+ // Empty aggregation buffer
+ assert(agg.eval(buffer) == null)
+ // Empty input row
+ agg.update(buffer, InternalRow(null))
+ assert(agg.eval(buffer) == null)
+
+ // Add some non-empty row
+ agg.update(buffer, InternalRow(0))
+ assert(agg.eval(buffer) != null)
+ }
+
+ // True when both histograms have the same bin limit, used-bin count and identical bins.
+ private def compareEquals(left: NumericHistogram, right: NumericHistogram): Boolean = {
+ left.getNumBins == right.getNumBins && left.getUsedBins == right.getUsedBins &&
+ (0 until left.getUsedBins).forall { i =>
+ val leftCoord = left.getBin(i)
+ val rightCoord = right.getBin(i)
+ leftCoord.x == rightCoord.x && leftCoord.y == rightCoord.y
+ }
+ }
+
+ private def assertEqual[T](left: T, right: T): Unit = {
+ assert(left == right)
+ }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 958b961..9192ac4 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- - Number of queries: 366
+ - Number of queries: 367
- Number of expressions that missing example: 12
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
## Schema of Built-in Functions
@@ -345,6 +345,7 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct<covar_samp(c1, c2):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first(col):int> |
| org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first_value(col):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<histogram_numeric(col, 5):array<struct<x:double,y:double>>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct<approx_count_distinct(col1):bigint> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct<kurtosis(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last(col):int> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 039373b..4e6d2d2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -192,3 +192,15 @@ SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
FROM testData
GROUP BY a IS NULL;
+
+-- histogram_numeric with several bin counts over the same input column.
+-- NOTE(review): the VALUES list contains 3 twice and no 36, which looks like a typo for 36.
+-- The golden output in results/group-by.sql.out echoes this exact list, so regenerate it
+-- if the list is changed.
+SELECT
+ histogram_numeric(col, 2) as histogram_2,
+ histogram_numeric(col, 3) as histogram_3,
+ histogram_numeric(col, 5) as histogram_5,
+ histogram_numeric(col, 10) as histogram_10
+FROM VALUES
+ (1), (2), (3), (4), (5), (6), (7), (8), (9), (10),
+ (11), (12), (13), (14), (15), (16), (17), (18), (19), (20),
+ (21), (22), (23), (24), (25), (26), (27), (28), (29), (30),
+ (31), (32), (33), (34), (35), (3), (37), (38), (39), (40),
+ (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col);
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index f598f49..5cd5a37 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 65
+-- Number of queries: 66
-- !query
@@ -673,3 +673,21 @@ struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
-- !query output
0.7604953758285915 7
1.0 2
+
+
+-- !query
+SELECT
+ histogram_numeric(col, 2) as histogram_2,
+ histogram_numeric(col, 3) as histogram_3,
+ histogram_numeric(col, 5) as histogram_5,
+ histogram_numeric(col, 10) as histogram_10
+FROM VALUES
+ (1), (2), (3), (4), (5), (6), (7), (8), (9), (10),
+ (11), (12), (13), (14), (15), (16), (17), (18), (19), (20),
+ (21), (22), (23), (24), (25), (26), (27), (28), (29), (30),
+ (31), (32), (33), (34), (35), (3), (37), (38), (39), (40),
+ (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col)
+-- !query schema
+struct<histogram_2:array<struct<x:double,y:double>>,histogram_3:array<struct<x:double,y:double>>,histogram_5:array<struct<x:double,y:double>>,histogram_10:array<struct<x:double,y:double>>>
+-- !query output
+[{"x":12.615384615384613,"y":26.0},{"x":38.083333333333336,"y":24.0}] [{"x":9.649999999999999,"y":20.0},{"x":25.0,"y":11.0},{"x":40.736842105263165,"y":19.0}] [{"x":5.272727272727273,"y":11.0},{"x":14.5,"y":8.0},{"x":22.0,"y":7.0},{"x":30.499999999999996,"y":10.0},{"x":43.5,"y":14.0}] [{"x":3.0,"y":6.0},{"x":8.5,"y":6.0},{"x":13.5,"y":4.0},{"x":17.0,"y":3.0},{"x":20.5,"y":4.0},{"x":25.5,"y":6.0},{"x":31.999999999999996,"y":7.0},{"x":39.0,"y":5.0},{"x":43.5,"y":4.0},{"x":48.0,"y":5.0}]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 488890a..b11774b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -128,7 +128,5 @@ private[sql] class HiveSessionCatalog(
// in_file, index, matchpath, ngrams, noop, noopstreaming, noopwithmap,
// noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction.
// Note: don't forget to update SessionCatalog.isTemporaryFunction
- private val hiveFunctions = Seq(
- "histogram_numeric"
- )
+ private val hiveFunctions = Seq()
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org