Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/06 02:13:13 UTC
[spark] branch branch-3.4 updated: [SPARK-42558][CONNECT] Implement `DataFrameStatFunctions` except `bloomFilter` functions
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 548baea5ee5 [SPARK-42558][CONNECT] Implement `DataFrameStatFunctions` except `bloomFilter` functions
548baea5ee5 is described below
commit 548baea5ee50d5a2b3b48a84c94e23127f5192bf
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Sun Mar 5 22:12:51 2023 -0400
[SPARK-42558][CONNECT] Implement `DataFrameStatFunctions` except `bloomFilter` functions
### What changes were proposed in this pull request?
This PR partially implements `DataFrameStatFunctions`, covering `approxQuantile`, `cov`, `corr`, `crosstab`, `freqItems`, `sampleBy`, and `countMinSketch`. The `bloomFilter` functions are not supported in this PR because the corresponding protobuf message is not yet available.
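A minimal usage sketch of the newly covered surface (illustrative only; assumes a Connect `SparkSession` with a DataFrame `df` that has numeric columns "a" and "b"):

```scala
// Each call below exercises one of the newly implemented methods.
val quantiles = df.stat.approxQuantile("a", Array(0.25, 0.5, 0.75), 0.01)
val covariance = df.stat.cov("a", "b")
val pearson = df.stat.corr("a", "b")
val contingency = df.stat.crosstab("a", "b")
val frequent = df.stat.freqItems(Array("a", "b"), 0.4)
val sampled = df.stat.sampleBy("a", Map(0 -> 0.1, 1 -> 0.2), 7L)
val sketch = df.stat.countMinSketch("a", depth = 10, width = 20, seed = 42)
```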
### Why are the changes needed?
Increases Spark Connect Scala client API coverage.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
- Added new tests.
- Manually checked `connect` and `connect-client-jvm` with Scala 2.13.
Closes #40255 from LuciferYang/SPARK-42558.
Authored-by: yangjie01 <ya...@baidu.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
(cherry picked from commit 9308b0c75d9414ebf75c1dc3c576a545133eaaf0)
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../apache/spark/sql/DataFrameStatFunctions.scala | 592 +++++++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 12 +
.../org/apache/spark/sql/DataFrameStatSuite.scala | 179 +++++++
.../apache/spark/sql/PlanGenerationTestSuite.scala | 12 +
.../CheckConnectJvmClientCompatibility.scala | 6 +-
.../query-tests/explain-results/crosstab.explain | 5 +
.../query-tests/explain-results/freqItems.explain | 2 +
.../query-tests/explain-results/sampleBy.explain | 2 +
.../resources/query-tests/queries/crosstab.json | 17 +
.../query-tests/queries/crosstab.proto.bin | Bin 0 -> 55 bytes
.../resources/query-tests/queries/freqItems.json | 17 +
.../query-tests/queries/freqItems.proto.bin | Bin 0 -> 65 bytes
.../resources/query-tests/queries/sampleBy.json | 32 ++
.../query-tests/queries/sampleBy.proto.bin | Bin 0 -> 89 bytes
14 files changed, 875 insertions(+), 1 deletion(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
new file mode 100644
index 00000000000..0d4372b8738
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -0,0 +1,592 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.{lang => jl, util => ju}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto.{Relation, StatSampleBy}
+import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.util.sketch.CountMinSketch
+
+/**
+ * Statistic functions for `DataFrame`s.
+ *
+ * @since 3.4.0
+ */
+final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) {
+
+ /**
+ * Calculates the approximate quantiles of a numerical column of a DataFrame.
+ *
+ * The result of this algorithm has the following deterministic bound: If the DataFrame has N
+ * elements and if we request the quantile at probability `p` up to error `err`, then the
+ * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is
+ * close to (p * N). More precisely,
+ *
+ * {{{
+ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
+ * }}}
+ *
+ * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+ * optimizations). The algorithm was first presented in <a
+ * href="https://doi.org/10.1145/375663.375670"> Space-efficient Online Computation of Quantile
+ * Summaries</a> by Greenwald and Khanna.
+ *
+ * @param col
+ * the name of the numerical column
+ * @param probabilities
+ *   a list of quantile probabilities. Each number must belong to [0, 1]. For example, 0 is
+ *   the minimum, 0.5 is the median, and 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+ * exact quantiles are computed, which could be very expensive. Note that values greater than
+ * 1 are accepted but give the same result as 1.
+ * @return
+ * the approximate quantiles at the given probabilities
+ *
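+ * {{{
+ *   // A minimal sketch (illustrative): the median of a numeric column "value" of a
+ *   // hypothetical DataFrame `df`; the result's exact rank is within 1% * N of 0.5 * N.
+ *   val Array(median) = df.stat.approxQuantile("value", Array(0.5), 0.01)
+ * }}}
+ *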
+ * @note
+ * null and NaN values will be removed from the numerical column before calculation. If the
+ *   DataFrame is empty or the column only contains null or NaN, an empty array is returned.
+ *
+ * @since 3.4.0
+ */
+ def approxQuantile(
+ col: String,
+ probabilities: Array[Double],
+ relativeError: Double): Array[Double] = {
+ approxQuantile(Array(col), probabilities, relativeError).head
+ }
+
+ /**
+ * Calculates the approximate quantiles of numerical columns of a DataFrame.
+ * @see
+ *   `approxQuantile(col: String, probabilities: Array[Double], relativeError: Double)` for a
+ *   detailed description.
+ *
+ * @param cols
+ * the names of the numerical columns
+ * @param probabilities
+ *   a list of quantile probabilities. Each number must belong to [0, 1]. For example, 0 is
+ *   the minimum, 0.5 is the median, and 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+ * exact quantiles are computed, which could be very expensive. Note that values greater than
+ * 1 are accepted but give the same result as 1.
+ * @return
+ * the approximate quantiles at the given probabilities of each column
+ *
+ * @note
+ * null and NaN values will be ignored in numerical columns before calculation. For columns
+ * only containing null or NaN values, an empty array is returned.
+ *
+ * @since 3.4.0
+ */
+ def approxQuantile(
+ cols: Array[String],
+ probabilities: Array[Double],
+ relativeError: Double): Array[Array[Double]] = {
+ require(
+ probabilities.forall(p => p >= 0.0 && p <= 1.0),
+ "percentile should be in the range [0.0, 1.0]")
+ require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError")
+ sparkSession
+ .newDataset(approxQuantileResultEncoder) { builder =>
+ val approxQuantileBuilder = builder.getApproxQuantileBuilder
+ .setInput(root)
+ .setRelativeError(relativeError)
+ cols.foreach(approxQuantileBuilder.addCols)
+ probabilities.foreach(approxQuantileBuilder.addProbabilities)
+ }
+ .head()
+ }
+
+ /**
+ * Calculate the sample covariance of two numerical columns of a DataFrame.
+ * @param col1
+ * the name of the first column
+ * @param col2
+ * the name of the second column
+ * @return
+ * the covariance of the two columns.
+ *
+ * {{{
+ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+ * .withColumn("rand2", rand(seed=27))
+ * df.stat.cov("rand1", "rand2")
+ * res1: Double = 0.065...
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def cov(col1: String, col2: String): Double = {
+ sparkSession
+ .newDataset(PrimitiveDoubleEncoder) { builder =>
+ builder.getCovBuilder.setInput(root).setCol1(col1).setCol2(col2)
+ }
+ .head()
+ }
+
+ /**
+ * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+ * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+ * MLlib's Statistics.
+ *
+ * @param col1
+ * the name of the column
+ * @param col2
+ * the name of the column to calculate the correlation against
+ * @return
+ * The Pearson Correlation Coefficient as a Double.
+ *
+ * {{{
+ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+ * .withColumn("rand2", rand(seed=27))
+ *    df.stat.corr("rand1", "rand2", "pearson")
+ * res1: Double = 0.613...
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def corr(col1: String, col2: String, method: String): Double = {
+ require(
+ method == "pearson",
+ "Currently only the calculation of the Pearson Correlation " +
+ "coefficient is supported.")
+ sparkSession
+ .newDataset(PrimitiveDoubleEncoder) { builder =>
+ builder.getCorrBuilder.setInput(root).setCol1(col1).setCol2(col2)
+ }
+ .head()
+ }
+
+ /**
+ * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
+ *
+ * @param col1
+ * the name of the column
+ * @param col2
+ * the name of the column to calculate the correlation against
+ * @return
+ * The Pearson Correlation Coefficient as a Double.
+ *
+ * {{{
+ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+ * .withColumn("rand2", rand(seed=27))
+ *    df.stat.corr("rand1", "rand2")
+ * res1: Double = 0.613...
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def corr(col1: String, col2: String): Double = {
+ corr(col1, col2, "pearson")
+ }
+
+ /**
+ * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+ * The first column of each row will be the distinct values of `col1` and the column names will
+ * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts
+ * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts.
+ * Null elements will be replaced by "null", and backticks will be dropped from elements if
+ * they exist.
+ *
+ * @param col1
+ * The name of the first column. Distinct items will make the first item of each row.
+ * @param col2
+ * The name of the second column. Distinct items will make the column names of the DataFrame.
+ * @return
+ *   A DataFrame containing the contingency table.
+ *
+ * {{{
+ * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3)))
+ * .toDF("key", "value")
+ * val ct = df.stat.crosstab("key", "value")
+ * ct.show()
+ * +---------+---+---+---+
+ * |key_value| 1| 2| 3|
+ * +---------+---+---+---+
+ * | 2| 2| 0| 1|
+ * | 1| 1| 1| 0|
+ * | 3| 0| 1| 1|
+ * +---------+---+---+---+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def crosstab(col1: String, col2: String): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2)
+ }
+ }
+
+ /**
+ * Finds frequent items for columns, possibly with false positives. Uses the frequent element
+ * count algorithm described in <a href="https://doi.org/10.1145/762471.762473">this paper</a>
+ * by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @param support
+ * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * {{{
+ * val rows = Seq.tabulate(100) { i =>
+ * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
+ * }
+ * val df = spark.createDataFrame(rows).toDF("a", "b")
+ * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns
+ * // "a" and "b"
+ * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4)
+ * freqSingles.show()
+ * +-----------+-------------+
+ * |a_freqItems| b_freqItems|
+ * +-----------+-------------+
+ * | [1, 99]|[-1.0, -99.0]|
+ * +-----------+-------------+
+ * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b"
+ * val pairDf = df.select(struct("a", "b").as("a-b"))
+ * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1)
+ * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
+ * +----------+
+ * | freq_ab|
+ * +----------+
+ * | [1,-1.0]|
+ * | ... |
+ * +----------+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Array[String], support: Double): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support)
+ cols.foreach(freqItemsBuilder.addCols)
+ }
+ }
+
+ /**
+ * Finds frequent items for columns, possibly with false positives. Uses the frequent element
+ * count algorithm described in <a href="https://doi.org/10.1145/762471.762473">this paper</a>
+ * by Karp, Schenker, and Papadimitriou. Uses a default support of 1%.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
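+ * {{{
+ *   // A minimal sketch (illustrative): items appearing in roughly 1% or more of the rows
+ *   // of column "a" of a hypothetical DataFrame `df` (possibly with false positives).
+ *   val frequent = df.stat.freqItems(Array("a"))
+ * }}}
+ *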
+ * @since 3.4.0
+ */
+ def freqItems(cols: Array[String]): DataFrame = {
+ freqItems(cols, 0.01)
+ }
+
+ /**
+ * (Scala-specific) Finds frequent items for columns, possibly with false positives. Uses the
+ * frequent element count algorithm described in <a
+ * href="https://doi.org/10.1145/762471.762473">this paper</a> by Karp, Schenker, and
+ * Papadimitriou.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * {{{
+ * val rows = Seq.tabulate(100) { i =>
+ * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
+ * }
+ * val df = spark.createDataFrame(rows).toDF("a", "b")
+ * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns
+ * // "a" and "b"
+ * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4)
+ * freqSingles.show()
+ * +-----------+-------------+
+ * |a_freqItems| b_freqItems|
+ * +-----------+-------------+
+ * | [1, 99]|[-1.0, -99.0]|
+ * +-----------+-------------+
+ * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b"
+ * val pairDf = df.select(struct("a", "b").as("a-b"))
+ * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1)
+ * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
+ * +----------+
+ * | freq_ab|
+ * +----------+
+ * | [1,-1.0]|
+ * | ... |
+ * +----------+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Seq[String], support: Double): DataFrame = {
+ freqItems(cols.toArray, support)
+ }
+
+ /**
+ * (Scala-specific) Finds frequent items for columns, possibly with false positives. Uses the
+ * frequent element count algorithm described in <a
+ * href="https://doi.org/10.1145/762471.762473">this paper</a> by Karp, Schenker, and
+ * Papadimitriou. Uses a default support of 1%.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Seq[String]): DataFrame = {
+ freqItems(cols.toArray, 0.01)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * {{{
+ * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2),
+ * (3, 3))).toDF("key", "value")
+ * val fractions = Map(1 -> 1.0, 3 -> 0.5)
+ * df.stat.sampleBy("key", fractions, 36L).show()
+ * +---+-----+
+ * |key|value|
+ * +---+-----+
+ * | 1| 1|
+ * | 1| 2|
+ * | 3| 2|
+ * +---+-----+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+ sampleBy(Column(col), fractions, seed)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
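+ * {{{
+ *   // A minimal sketch (illustrative) of this overload with a java.util.Map, e.g. when
+ *   // calling from Java; `df` is a hypothetical DataFrame with a string "name" column.
+ *   val fractions = new java.util.HashMap[String, java.lang.Double]()
+ *   fractions.put("Alice", 0.5)
+ *   df.stat.sampleBy("name", fractions, 7L)
+ * }}}
+ *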
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * The stratified sample can be performed over multiple columns:
+ * {{{
+ * import org.apache.spark.sql.Row
+ * import org.apache.spark.sql.functions.struct
+ *
+ * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17),
+ * ("Alice", 10))).toDF("name", "age")
+ * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
+ * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
+ * +-----+---+
+ * | name|age|
+ * +-----+---+
+ * | Nico| 8|
+ * |Alice| 10|
+ * +-----+---+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
+ require(
+ fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+ s"Fractions must be in [0, 1], but got $fractions.")
+ sparkSession.newDataFrame { builder =>
+ val sampleByBuilder = builder.getSampleByBuilder
+ .setInput(root)
+ .setCol(col.expr)
+ .setSeed(seed)
+ fractions.foreach { case (k, v) =>
+ sampleByBuilder.addFractions(
+ StatSampleBy.Fraction
+ .newBuilder()
+ .setStratum(lit(k).expr.getLiteral)
+ .setFraction(v))
+ }
+ }
+ }
+
+ /**
+ * (Java-specific) Returns a stratified sample without replacement based on the fraction given
+ * on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param colName
+ * name of the column over which the sketch is built
+ * @param depth
+ * depth of the sketch
+ * @param width
+ * width of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
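+ *
+ * {{{
+ *   // A minimal sketch (illustrative): build a sketch over column "id" of a hypothetical
+ *   // DataFrame `df`, then query an approximate count. The depth/width parameters map to
+ *   // eps = 2.0 / width and confidence = 1 - 2^(-depth) internally.
+ *   val cms = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
+ *   val approxCount = cms.estimateCount(1L)
+ * }}}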
+ * @since 3.4.0
+ */
+ def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
+ countMinSketch(Column(colName), depth, width, seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param colName
+ * name of the column over which the sketch is built
+ * @param eps
+ * relative error of the sketch
+ * @param confidence
+ * confidence of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
+ * @since 3.4.0
+ */
+ def countMinSketch(
+ colName: String,
+ eps: Double,
+ confidence: Double,
+ seed: Int): CountMinSketch = {
+ countMinSketch(Column(colName), eps, confidence, seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param col
+ * the column over which the sketch is built
+ * @param depth
+ * depth of the sketch
+ * @param width
+ * width of the sketch
+ * @param seed
+ * random seed
+ * @return
+ *   a `CountMinSketch` over column `col`
+ * @since 3.4.0
+ */
+ def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
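+    // Translate (depth, width) into the equivalent (eps, confidence)
+    // parameterization: eps = 2 / width, confidence = 1 - 2^(-depth).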
+ countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param col
+ * the column over which the sketch is built
+ * @param eps
+ * relative error of the sketch
+ * @param confidence
+ * confidence of the sketch
+ * @param seed
+ * random seed
+ * @return
+ *   a `CountMinSketch` over column `col`
+ * @since 3.4.0
+ */
+ def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+ val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed))
+ val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
+ builder.getProjectBuilder
+ .setInput(root)
+ .addExpressions(agg.expr)
+ }
+ CountMinSketch.readFrom(ds.head())
+ }
+}
+
+private object DataFrameStatFunctions {
+ private val approxQuantileResultEncoder: ArrayEncoder[Array[Double]] =
+ ArrayEncoder(ArrayEncoder(PrimitiveDoubleEncoder, containsNull = false), containsNull = false)
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index e264f1c0c0c..588e62768ac 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -534,6 +534,18 @@ class Dataset[T] private[sql] (
}
}
+ /**
+ * Returns a [[DataFrameStatFunctions]] for working with statistic functions.
+ * {{{
+ * // Finding frequent items in column with name 'a'.
+ * ds.stat.freqItems(Seq("a"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot)
+
private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
sparkSession.newDataFrame { builder =>
val joinBuilder = builder.getJoinBuilder
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
new file mode 100644
index 00000000000..aea31005f3b
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Random
+
+import io.grpc.StatusRuntimeException
+import org.scalatest.matchers.must.Matchers._
+
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+
+class DataFrameStatSuite extends RemoteSparkSession {
+ private def toLetter(i: Int): String = (i + 97).toChar.toString
+
+ test("approxQuantile") {
+ val session = spark
+ import session.implicits._
+
+ val n = 1000
+ val df = Seq.tabulate(n + 1)(i => (i, 2.0 * i)).toDF("singles", "doubles")
+
+ val q1 = 0.5
+ val q2 = 0.8
+ val epsilons = List(0.1, 0.05, 0.001)
+
+ for (epsilon <- epsilons) {
+ val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon)
+ val Array(double2) = df.stat.approxQuantile("doubles", Array(q2), epsilon)
+ // Also make sure there is no regression by computing multiple quantiles at once.
+ val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon)
+ val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon)
+
+ val errorSingle = 1000 * epsilon
+ val errorDouble = 2.0 * errorSingle
+
+ assert(math.abs(single1 - q1 * n) <= errorSingle)
+ assert(math.abs(double2 - 2 * q2 * n) <= errorDouble)
+ assert(math.abs(s1 - q1 * n) <= errorSingle)
+ assert(math.abs(s2 - q2 * n) <= errorSingle)
+ assert(math.abs(d1 - 2 * q1 * n) <= errorDouble)
+ assert(math.abs(d2 - 2 * q2 * n) <= errorDouble)
+
+ // Multiple columns
+ val Array(Array(ms1, ms2), Array(md1, md2)) =
+ df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
+
+ assert(math.abs(ms1 - q1 * n) <= errorSingle)
+ assert(math.abs(ms2 - q2 * n) <= errorSingle)
+ assert(math.abs(md1 - 2 * q1 * n) <= errorDouble)
+ assert(math.abs(md2 - 2 * q2 * n) <= errorDouble)
+ }
+
+ // quantile should be in the range [0.0, 1.0]
+ val e = intercept[IllegalArgumentException] {
+ df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
+ }
+ assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]"))
+
+ // relativeError should be non-negative
+ val e2 = intercept[IllegalArgumentException] {
+ df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0)
+ }
+ assert(e2.getMessage.contains("Relative Error must be non-negative"))
+ }
+
+ test("covariance") {
+ val session = spark
+ import session.implicits._
+
+ val df =
+ Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
+
+ val results = df.stat.cov("singles", "doubles")
+ assert(math.abs(results - 55.0 / 3) < 1e-12)
+ intercept[StatusRuntimeException] {
+ df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
+ }
+ val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
+ val decimalRes = decimalData.stat.cov("a", "b")
+ assert(math.abs(decimalRes) < 1e-12)
+ }
+
+ test("correlation") {
+ val session = spark
+ import session.implicits._
+
+ val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
+ val corr1 = df.stat.corr("a", "b", "pearson")
+ assert(math.abs(corr1 - 1.0) < 1e-12)
+ val corr2 = df.stat.corr("a", "c", "pearson")
+ assert(math.abs(corr2 + 1.0) < 1e-12)
+ val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
+ val corr3 = df2.stat.corr("a", "b", "pearson")
+ assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
+ }
+
+ test("crosstab") {
+ val session = spark
+ import session.implicits._
+
+ val rng = new Random()
+ val data = Seq.tabulate(25)(_ => (rng.nextInt(5), rng.nextInt(10)))
+ val df = data.toDF("a", "b")
+ val crosstab = df.stat.crosstab("a", "b")
+ val columnNames = crosstab.schema.fieldNames
+ assert(columnNames(0) === "a_b")
+ // reduce by key
+ val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length)
+ val rows = crosstab.collect()
+ rows.foreach { row =>
+ val i = row.getString(0).toInt
+ for (col <- 1 until columnNames.length) {
+ val j = columnNames(col).toInt
+ assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong)
+ }
+ }
+ }
+
+ test("freqItems") {
+ val session = spark
+ import session.implicits._
+
+ val rows = Seq.tabulate(1000) { i =>
+ if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
+ }
+ val df = rows.toDF("numbers", "letters", "negDoubles")
+
+ val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
+ val items = results.collect().head
+ assert(items.getSeq[Int](0).contains(1))
+ assert(items.getSeq[String](1).contains(toLetter(1)))
+
+ val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
+ val items2 = singleColResults.collect().head
+ assert(items2.getSeq[Double](0).contains(-1.0))
+ }
+
+ test("sampleBy") {
+ val session = spark
+ import session.implicits._
+ val df = Seq("Bob", "Alice", "Nico", "Bob", "Alice").toDF("name")
+ val fractions = Map("Alice" -> 0.3, "Nico" -> 1.0)
+ val sampled = df.stat.sampleBy("name", fractions, 36L)
+ val rows = sampled.groupBy("name").count().orderBy("name").collect()
+ assert(rows.length == 1)
+ val row0 = rows(0)
+ assert(row0.getString(0) == "Nico")
+ assert(row0.getLong(1) == 1L)
+ }
+
+ test("countMinSketch") {
+ val df = spark.range(1000)
+
+ val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
+ assert(sketch1.totalCount() === 1000)
+ assert(sketch1.depth() === 10)
+ assert(sketch1.width() === 20)
+
+ val sketch = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
+ assert(sketch.totalCount() === 1000)
+ assert(sketch.relativeError() === 0.001)
+ assert(sketch.confidence() === 0.99 +- 5e-3)
+ }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 3b69b02df5c..b59c783aec9 100755
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2046,4 +2046,16 @@ class PlanGenerationTestSuite
.build()
simple.select(Column(com.google.protobuf.Any.pack(extension)))
}
+
+ test("crosstab") {
+ simple.stat.crosstab("a", "b")
+ }
+
+ test("freqItems") {
+ simple.stat.freqItems(Array("id", "a"), 0.1)
+ }
+
+ test("sampleBy") {
+ simple.stat.sampleBy("id", Map(0 -> 0.1, 1 -> 0.2), 0L)
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 548608f50b5..b684721243a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -109,6 +109,7 @@ object CheckConnectJvmClientCompatibility {
IncludeByName("org.apache.spark.sql.ColumnName.*"),
IncludeByName("org.apache.spark.sql.DataFrame.*"),
IncludeByName("org.apache.spark.sql.DataFrameReader.*"),
+ IncludeByName("org.apache.spark.sql.DataFrameStatFunctions.*"),
IncludeByName("org.apache.spark.sql.DataFrameWriter.*"),
IncludeByName("org.apache.spark.sql.DataFrameWriterV2.*"),
IncludeByName("org.apache.spark.sql.Dataset.*"),
@@ -133,6 +134,10 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.jdbc"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameWriter.jdbc"),
+ // DataFrameStatFunctions
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.this"),
+
// Dataset
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.ofRows"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_TAG"),
@@ -144,7 +149,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.na"),
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.stat"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.joinWith"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.select"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/crosstab.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/crosstab.explain
new file mode 100644
index 00000000000..a30cd136e8d
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/crosstab.explain
@@ -0,0 +1,5 @@
+Project [a_b#0]
++- Project [a_b#0]
+ +- Aggregate [a_b#0], [a_b#0, pivotfirst(__pivot_col#0, count(1) AS count#0L, 0, 0) AS __pivot_count(1) AS count AS `count(1) AS count`#0]
+ +- Aggregate [CASE WHEN isnull(a#0) THEN null ELSE cast(a#0 as string) END, CASE WHEN isnull(b#0) THEN null ELSE regexp_replace(cast(b#0 as string), `, , 1) END], [CASE WHEN isnull(a#0) THEN null ELSE cast(a#0 as string) END AS a_b#0, CASE WHEN isnull(b#0) THEN null ELSE regexp_replace(cast(b#0 as string), `, , 1) END AS __pivot_col#0, count(1) AS count(1) AS count#0L]
+ +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/freqItems.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/freqItems.explain
new file mode 100644
index 00000000000..31ef46e2424
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/freqItems.explain
@@ -0,0 +1,2 @@
+Aggregate [collect_frequent_items(id#0L, 10, 0, 0) AS id_freqItems#0, collect_frequent_items(a#0, 10, 0, 0) AS a_freqItems#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/sampleBy.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/sampleBy.explain
new file mode 100644
index 00000000000..64abbcf1b53
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/sampleBy.explain
@@ -0,0 +1,2 @@
+Filter UDF(id#0L, rand(0))
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/crosstab.json b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.json
new file mode 100644
index 00000000000..755a6fa4dd2
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.json
@@ -0,0 +1,17 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "crosstab": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "col1": "a",
+ "col2": "b"
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/crosstab.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.proto.bin
new file mode 100644
index 00000000000..c664cedb01c
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/freqItems.json b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.json
new file mode 100644
index 00000000000..8734722b354
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.json
@@ -0,0 +1,17 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "freqItems": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "cols": ["id", "a"],
+ "support": 0.1
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/freqItems.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.proto.bin
new file mode 100644
index 00000000000..717f3d61ae9
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.json b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.json
new file mode 100644
index 00000000000..03fdd100753
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.json
@@ -0,0 +1,32 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "sampleBy": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "col": {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "id"
+ }
+ },
+ "fractions": [{
+ "stratum": {
+ "integer": 0
+ },
+ "fraction": 0.1
+ }, {
+ "stratum": {
+ "integer": 1
+ },
+ "fraction": 0.2
+ }],
+ "seed": "0"
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin
new file mode 100644
index 00000000000..29773f18e0e
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin differ