Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/06 02:13:13 UTC

[spark] branch branch-3.4 updated: [SPARK-42558][CONNECT] Implement `DataFrameStatFunctions` except `bloomFilter` functions

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 548baea5ee5 [SPARK-42558][CONNECT] Implement `DataFrameStatFunctions` except `bloomFilter` functions
548baea5ee5 is described below

commit 548baea5ee50d5a2b3b48a84c94e23127f5192bf
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Sun Mar 5 22:12:51 2023 -0400

    [SPARK-42558][CONNECT] Implement `DataFrameStatFunctions` except `bloomFilter` functions
    
    ### What changes were proposed in this pull request?
    This PR partially implements `DataFrameStatFunctions`, covering `approxQuantile`, `cov`, `corr`, `crosstab`, `freqItems`, `sampleBy` and `countMinSketch`. The `bloomFilter` functions are not implemented in this PR because the required protobuf message is not yet available.
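    
    For illustration, a minimal usage sketch of the new client-side API, assuming a Spark Connect `SparkSession` value named `spark` (the data and column names here are made up); the methods mirror the existing `DataFrameStatFunctions` signatures:
    
    ```scala
    import spark.implicits._
    
    val df = Seq((1, 1.0), (2, 4.0), (3, 9.0), (4, 16.0)).toDF("key", "value")
    
    // Approximate median of `value` with a 1% relative error bound.
    val Array(median) = df.stat.approxQuantile("value", Array(0.5), 0.01)
    
    // Pearson correlation and sample covariance of two numeric columns.
    val r = df.stat.corr("key", "value")
    val c = df.stat.cov("key", "value")
    
    // Contingency table, frequent items, and a stratified sample.
    val ct = df.stat.crosstab("key", "value")
    val freq = df.stat.freqItems(Array("key"), 0.4)
    val sampled = df.stat.sampleBy("key", Map(1 -> 1.0, 3 -> 0.5), seed = 7L)
    ```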
    
    ### Why are the changes needed?
    Improves Spark Connect Scala client API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds the `DataFrameStatFunctions` API to the Spark Connect Scala client.
    
    ### How was this patch tested?
    
    - Added new tests.
    - Manually checked `connect` and `connect-client-jvm` with Scala 2.13.
    
    Closes #40255 from LuciferYang/SPARK-42558.
    
    Authored-by: yangjie01 <ya...@baidu.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit 9308b0c75d9414ebf75c1dc3c576a545133eaaf0)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../apache/spark/sql/DataFrameStatFunctions.scala  | 592 +++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  12 +
 .../org/apache/spark/sql/DataFrameStatSuite.scala  | 179 +++++++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  12 +
 .../CheckConnectJvmClientCompatibility.scala       |   6 +-
 .../query-tests/explain-results/crosstab.explain   |   5 +
 .../query-tests/explain-results/freqItems.explain  |   2 +
 .../query-tests/explain-results/sampleBy.explain   |   2 +
 .../resources/query-tests/queries/crosstab.json    |  17 +
 .../query-tests/queries/crosstab.proto.bin         | Bin 0 -> 55 bytes
 .../resources/query-tests/queries/freqItems.json   |  17 +
 .../query-tests/queries/freqItems.proto.bin        | Bin 0 -> 65 bytes
 .../resources/query-tests/queries/sampleBy.json    |  32 ++
 .../query-tests/queries/sampleBy.proto.bin         | Bin 0 -> 89 bytes
 14 files changed, 875 insertions(+), 1 deletion(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
new file mode 100644
index 00000000000..0d4372b8738
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -0,0 +1,592 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.{lang => jl, util => ju}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto.{Relation, StatSampleBy}
+import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.util.sketch.CountMinSketch
+
+/**
+ * Statistic functions for `DataFrame`s.
+ *
+ * @since 3.4.0
+ */
+final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) {
+
+  /**
+   * Calculates the approximate quantiles of a numerical column of a DataFrame.
+   *
+   * The result of this algorithm has the following deterministic bound: If the DataFrame has N
+   * elements and if we request the quantile at probability `p` up to error `err`, then the
+   * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is
+   * close to (p * N). More precisely,
+   *
+   * {{{
+   *   floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
+   * }}}
+   *
+   * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+   * optimizations). The algorithm was first present in <a
+   * href="https://doi.org/10.1145/375663.375670"> Space-efficient Online Computation of Quantile
+   * Summaries</a> by Greenwald and Khanna.
+   *
+   * @param col
+   *   the name of the numerical column
+   * @param probabilities
+   *   a list of quantile probabilities. Each number must belong to [0, 1]. For example, 0 is
+   *   the minimum, 0.5 is the median, 1 is the maximum.
+   * @param relativeError
+   *   The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+   *   exact quantiles are computed, which could be very expensive. Note that values greater than
+   *   1 are accepted but give the same result as 1.
+   * @return
+   *   the approximate quantiles at the given probabilities
+   *
+   * @note
+   *   null and NaN values will be removed from the numerical column before calculation. If the
+   *   dataframe is empty or the column only contains null or NaN, an empty array is returned.
+   *
+   * @since 3.4.0
+   */
+  def approxQuantile(
+      col: String,
+      probabilities: Array[Double],
+      relativeError: Double): Array[Double] = {
+    approxQuantile(Array(col), probabilities, relativeError).head
+  }
+
+  /**
+   * Calculates the approximate quantiles of numerical columns of a DataFrame.
+   * @see
+   *   the single-column `approxQuantile` overload for a detailed description.
+   *
+   * @param cols
+   *   the names of the numerical columns
+   * @param probabilities
+   *   a list of quantile probabilities. Each number must belong to [0, 1]. For example, 0 is
+   *   the minimum, 0.5 is the median, 1 is the maximum.
+   * @param relativeError
+   *   The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+   *   exact quantiles are computed, which could be very expensive. Note that values greater than
+   *   1 are accepted but give the same result as 1.
+   * @return
+   *   the approximate quantiles at the given probabilities of each column
+   *
+   * @note
+   *   null and NaN values will be ignored in numerical columns before calculation. For columns
+   *   only containing null or NaN values, an empty array is returned.
+   *
+   * @since 3.4.0
+   */
+  def approxQuantile(
+      cols: Array[String],
+      probabilities: Array[Double],
+      relativeError: Double): Array[Array[Double]] = {
+    require(
+      probabilities.forall(p => p >= 0.0 && p <= 1.0),
+      "percentile should be in the range [0.0, 1.0]")
+    require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError")
+    sparkSession
+      .newDataset(approxQuantileResultEncoder) { builder =>
+        val approxQuantileBuilder = builder.getApproxQuantileBuilder
+          .setInput(root)
+          .setRelativeError(relativeError)
+        cols.foreach(approxQuantileBuilder.addCols)
+        probabilities.foreach(approxQuantileBuilder.addProbabilities)
+      }
+      .head()
+  }
+
+  /**
+   * Calculate the sample covariance of two numerical columns of a DataFrame.
+   * @param col1
+   *   the name of the first column
+   * @param col2
+   *   the name of the second column
+   * @return
+   *   the covariance of the two columns.
+   *
+   * {{{
+   *    val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+   *      .withColumn("rand2", rand(seed=27))
+   *    df.stat.cov("rand1", "rand2")
+   *    res1: Double = 0.065...
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def cov(col1: String, col2: String): Double = {
+    sparkSession
+      .newDataset(PrimitiveDoubleEncoder) { builder =>
+        builder.getCovBuilder.setInput(root).setCol1(col1).setCol2(col2)
+      }
+      .head()
+  }
+
+  /**
+   * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+   * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+   * MLlib's Statistics.
+   *
+   * @param col1
+   *   the name of the column
+   * @param col2
+   *   the name of the column to calculate the correlation against
+   * @return
+   *   The Pearson Correlation Coefficient as a Double.
+   *
+   * {{{
+   *    val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+   *      .withColumn("rand2", rand(seed=27))
+   *    df.stat.corr("rand1", "rand2")
+   *    res1: Double = 0.613...
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def corr(col1: String, col2: String, method: String): Double = {
+    require(
+      method == "pearson",
+      "Currently only the calculation of the Pearson Correlation " +
+        "coefficient is supported.")
+    sparkSession
+      .newDataset(PrimitiveDoubleEncoder) { builder =>
+        builder.getCorrBuilder.setInput(root).setCol1(col1).setCol2(col2)
+      }
+      .head()
+  }
+
+  /**
+   * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
+   *
+   * @param col1
+   *   the name of the column
+   * @param col2
+   *   the name of the column to calculate the correlation against
+   * @return
+   *   The Pearson Correlation Coefficient as a Double.
+   *
+   * {{{
+   *    val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+   *      .withColumn("rand2", rand(seed=27))
+   *    df.stat.corr("rand1", "rand2", "pearson")
+   *    res1: Double = 0.613...
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def corr(col1: String, col2: String): Double = {
+    corr(col1, col2, "pearson")
+  }
+
+  /**
+   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+   * The first column of each row will be the distinct values of `col1` and the column names will
+   * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts
+   * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts.
+   * Null elements will be replaced by "null", and back ticks will be dropped from elements if
+   * they exist.
+   *
+   * @param col1
+   *   The name of the first column. Distinct items will make the first item of each row.
+   * @param col2
+   *   The name of the second column. Distinct items will make the column names of the DataFrame.
+   * @return
+   *   A DataFrame containing the contingency table.
+   *
+   * {{{
+   *    val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3)))
+   *      .toDF("key", "value")
+   *    val ct = df.stat.crosstab("key", "value")
+   *    ct.show()
+   *    +---------+---+---+---+
+   *    |key_value|  1|  2|  3|
+   *    +---------+---+---+---+
+   *    |        2|  2|  0|  1|
+   *    |        1|  1|  1|  0|
+   *    |        3|  0|  1|  1|
+   *    +---------+---+---+---+
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def crosstab(col1: String, col2: String): DataFrame = {
+    sparkSession.newDataFrame { builder =>
+      builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2)
+    }
+  }
+
+  /**
+   * Finding frequent items for columns, possibly with false positives. Using the frequent element
+   * count algorithm described in <a href="https://doi.org/10.1145/762471.762473">here</a>,
+   * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting `DataFrame`.
+   *
+   * @param cols
+   *   the names of the columns to search frequent items in.
+   * @param support
+   *   The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4.
+   * @return
+   *   A Local DataFrame with the Array of frequent items for each column.
+   *
+   * {{{
+   *    val rows = Seq.tabulate(100) { i =>
+   *      if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
+   *    }
+   *    val df = spark.createDataFrame(rows).toDF("a", "b")
+   *    // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns
+   *    // "a" and "b"
+   *    val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4)
+   *    freqSingles.show()
+   *    +-----------+-------------+
+   *    |a_freqItems|  b_freqItems|
+   *    +-----------+-------------+
+   *    |    [1, 99]|[-1.0, -99.0]|
+   *    +-----------+-------------+
+   *    // find the pair of items with a frequency greater than 0.1 in columns "a" and "b"
+   *    val pairDf = df.select(struct("a", "b").as("a-b"))
+   *    val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1)
+   *    freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
+   *    +----------+
+   *    |   freq_ab|
+   *    +----------+
+   *    |  [1,-1.0]|
+   *    |   ...    |
+   *    +----------+
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def freqItems(cols: Array[String], support: Double): DataFrame = {
+    sparkSession.newDataFrame { builder =>
+      val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support)
+      cols.foreach(freqItemsBuilder.addCols)
+    }
+  }
+
+  /**
+   * Finding frequent items for columns, possibly with false positives. Using the frequent element
+   * count algorithm described in <a href="https://doi.org/10.1145/762471.762473">here</a>,
+   * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting `DataFrame`.
+   *
+   * @param cols
+   *   the names of the columns to search frequent items in.
+   * @return
+   *   A Local DataFrame with the Array of frequent items for each column.
+   *
+   * @since 3.4.0
+   */
+  def freqItems(cols: Array[String]): DataFrame = {
+    freqItems(cols, 0.01)
+  }
+
+  /**
+   * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
+   * frequent element count algorithm described in <a
+   * href="https://doi.org/10.1145/762471.762473">here</a>, proposed by Karp, Schenker, and
+   * Papadimitriou.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting `DataFrame`.
+   *
+   * @param cols
+   *   the names of the columns to search frequent items in.
+   * @return
+   *   A Local DataFrame with the Array of frequent items for each column.
+   *
+   * {{{
+   *    val rows = Seq.tabulate(100) { i =>
+   *      if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
+   *    }
+   *    val df = spark.createDataFrame(rows).toDF("a", "b")
+   *    // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns
+   *    // "a" and "b"
+   *    val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4)
+   *    freqSingles.show()
+   *    +-----------+-------------+
+   *    |a_freqItems|  b_freqItems|
+   *    +-----------+-------------+
+   *    |    [1, 99]|[-1.0, -99.0]|
+   *    +-----------+-------------+
+   *    // find the pair of items with a frequency greater than 0.1 in columns "a" and "b"
+   *    val pairDf = df.select(struct("a", "b").as("a-b"))
+   *    val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1)
+   *    freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
+   *    +----------+
+   *    |   freq_ab|
+   *    +----------+
+   *    |  [1,-1.0]|
+   *    |   ...    |
+   *    +----------+
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def freqItems(cols: Seq[String], support: Double): DataFrame = {
+    freqItems(cols.toArray, support)
+  }
+
+  /**
+   * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
+   * frequent element count algorithm described in <a
+   * href="https://doi.org/10.1145/762471.762473">here</a>, proposed by Karp, Schenker, and
+   * Papadimitriou. Uses a `default` support of 1%.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting `DataFrame`.
+   *
+   * @param cols
+   *   the names of the columns to search frequent items in.
+   * @return
+   *   A Local DataFrame with the Array of frequent items for each column.
+   *
+   * @since 3.4.0
+   */
+  def freqItems(cols: Seq[String]): DataFrame = {
+    freqItems(cols.toArray, 0.01)
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col
+   *   column that defines strata
+   * @param fractions
+   *   sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+   *   zero.
+   * @param seed
+   *   random seed
+   * @tparam T
+   *   stratum type
+   * @return
+   *   a new `DataFrame` that represents the stratified sample
+   *
+   * {{{
+   *    val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2),
+   *      (3, 3))).toDF("key", "value")
+   *    val fractions = Map(1 -> 1.0, 3 -> 0.5)
+   *    df.stat.sampleBy("key", fractions, 36L).show()
+   *    +---+-----+
+   *    |key|value|
+   *    +---+-----+
+   *    |  1|    1|
+   *    |  1|    2|
+   *    |  3|    2|
+   *    +---+-----+
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+    sampleBy(Column(col), fractions, seed)
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col
+   *   column that defines strata
+   * @param fractions
+   *   sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+   *   zero.
+   * @param seed
+   *   random seed
+   * @tparam T
+   *   stratum type
+   * @return
+   *   a new `DataFrame` that represents the stratified sample
+   *
+   * @since 3.4.0
+   */
+  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col
+   *   column that defines strata
+   * @param fractions
+   *   sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+   *   zero.
+   * @param seed
+   *   random seed
+   * @tparam T
+   *   stratum type
+   * @return
+   *   a new `DataFrame` that represents the stratified sample
+   *
+   * The stratified sample can be performed over multiple columns:
+   * {{{
+   *    import org.apache.spark.sql.Row
+   *    import org.apache.spark.sql.functions.struct
+   *
+   *    val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17),
+   *      ("Alice", 10))).toDF("name", "age")
+   *    val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
+   *    df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
+   *    +-----+---+
+   *    | name|age|
+   *    +-----+---+
+   *    | Nico|  8|
+   *    |Alice| 10|
+   *    +-----+---+
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
+    require(
+      fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    sparkSession.newDataFrame { builder =>
+      val sampleByBuilder = builder.getSampleByBuilder
+        .setInput(root)
+        .setCol(col.expr)
+        .setSeed(seed)
+      fractions.foreach { case (k, v) =>
+        sampleByBuilder.addFractions(
+          StatSampleBy.Fraction
+            .newBuilder()
+            .setStratum(lit(k).expr.getLiteral)
+            .setFraction(v))
+      }
+    }
+  }
+
+  /**
+   * (Java-specific) Returns a stratified sample without replacement based on the fraction given
+   * on each stratum.
+   * @param col
+   *   column that defines strata
+   * @param fractions
+   *   sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+   *   zero.
+   * @param seed
+   *   random seed
+   * @tparam T
+   *   stratum type
+   * @return
+   *   a new `DataFrame` that represents the stratified sample
+   *
+   * @since 3.4.0
+   */
+  def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param colName
+   *   name of the column over which the sketch is built
+   * @param depth
+   *   depth of the sketch
+   * @param width
+   *   width of the sketch
+   * @param seed
+   *   random seed
+   * @return
+   *   a `CountMinSketch` over column `colName`
+   * @since 3.4.0
+   */
+  def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
+    countMinSketch(Column(colName), depth, width, seed)
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param colName
+   *   name of the column over which the sketch is built
+   * @param eps
+   *   relative error of the sketch
+   * @param confidence
+   *   confidence of the sketch
+   * @param seed
+   *   random seed
+   * @return
+   *   a `CountMinSketch` over column `colName`
+   * @since 3.4.0
+   */
+  def countMinSketch(
+      colName: String,
+      eps: Double,
+      confidence: Double,
+      seed: Int): CountMinSketch = {
+    countMinSketch(Column(colName), eps, confidence, seed)
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param col
+   *   the column over which the sketch is built
+   * @param depth
+   *   depth of the sketch
+   * @param width
+   *   width of the sketch
+   * @param seed
+   *   random seed
+   * @return
+   *   a `CountMinSketch` over column `col`
+   * @since 3.4.0
+   */
+  def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
+    countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed)
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param col
+   *   the column over which the sketch is built
+   * @param eps
+   *   relative error of the sketch
+   * @param confidence
+   *   confidence of the sketch
+   * @param seed
+   *   random seed
+   * @return
+   *   a `CountMinSketch` over column `col`
+   * @since 3.4.0
+   */
+  def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+    val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed))
+    val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
+      builder.getProjectBuilder
+        .setInput(root)
+        .addExpressions(agg.expr)
+    }
+    CountMinSketch.readFrom(ds.head())
+  }
+}
+
+private object DataFrameStatFunctions {
+  private val approxQuantileResultEncoder: ArrayEncoder[Array[Double]] =
+    ArrayEncoder(ArrayEncoder(PrimitiveDoubleEncoder, containsNull = false), containsNull = false)
+}
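
The `approxQuantile` and `countMinSketch` methods above carry no inline usage examples. A minimal sketch follows, assuming a `DataFrame` named `df` with a numeric `id` column (for instance `spark.range(1000).toDF("id")`). As a worked instance of the deterministic bound documented above: with N = 1000 rows, p = 0.5 and relativeError = 0.05, the returned value's exact rank is guaranteed to lie between floor(0.45 * 1000) = 450 and ceil(0.55 * 1000) = 550.

```scala
// Approximate median with a 5% relative error bound.
val Array(approxMedian) = df.stat.approxQuantile("id", Array(0.5), 0.05)

// Count-min sketch with explicit dimensions...
val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)

// ...or with dimensions derived from a target relative error and confidence.
val sketch2 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)

// Query the sketch for the estimated frequency of the value 42.
val estimate: Long = sketch2.estimateCount(Long.box(42L))
```
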
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index e264f1c0c0c..588e62768ac 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -534,6 +534,18 @@ class Dataset[T] private[sql] (
     }
   }
 
+  /**
+   * Returns a [[DataFrameStatFunctions]] for working with statistic functions.
+   * {{{
+   *   // Finding frequent items in column with name 'a'.
+   *   ds.stat.freqItems(Seq("a"))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot)
+
   private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
     sparkSession.newDataFrame { builder =>
       val joinBuilder = builder.getJoinBuilder
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
new file mode 100644
index 00000000000..aea31005f3b
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Random
+
+import io.grpc.StatusRuntimeException
+import org.scalatest.matchers.must.Matchers._
+
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+
+class DataFrameStatSuite extends RemoteSparkSession {
+  private def toLetter(i: Int): String = (i + 97).toChar.toString
+
+  test("approxQuantile") {
+    val session = spark
+    import session.implicits._
+
+    val n = 1000
+    val df = Seq.tabulate(n + 1)(i => (i, 2.0 * i)).toDF("singles", "doubles")
+
+    val q1 = 0.5
+    val q2 = 0.8
+    val epsilons = List(0.1, 0.05, 0.001)
+
+    for (epsilon <- epsilons) {
+      val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon)
+      val Array(double2) = df.stat.approxQuantile("doubles", Array(q2), epsilon)
+      // Also make sure there is no regression by computing multiple quantiles at once.
+      val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon)
+      val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon)
+
+      val errorSingle = 1000 * epsilon
+      val errorDouble = 2.0 * errorSingle
+
+      assert(math.abs(single1 - q1 * n) <= errorSingle)
+      assert(math.abs(double2 - 2 * q2 * n) <= errorDouble)
+      assert(math.abs(s1 - q1 * n) <= errorSingle)
+      assert(math.abs(s2 - q2 * n) <= errorSingle)
+      assert(math.abs(d1 - 2 * q1 * n) <= errorDouble)
+      assert(math.abs(d2 - 2 * q2 * n) <= errorDouble)
+
+      // Multiple columns
+      val Array(Array(ms1, ms2), Array(md1, md2)) =
+        df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
+
+      assert(math.abs(ms1 - q1 * n) <= errorSingle)
+      assert(math.abs(ms2 - q2 * n) <= errorSingle)
+      assert(math.abs(md1 - 2 * q1 * n) <= errorDouble)
+      assert(math.abs(md2 - 2 * q2 * n) <= errorDouble)
+    }
+
+    // quantile should be in the range [0.0, 1.0]
+    val e = intercept[IllegalArgumentException] {
+      df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
+    }
+    assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]"))
+
+    // relativeError should be non-negative
+    val e2 = intercept[IllegalArgumentException] {
+      df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0)
+    }
+    assert(e2.getMessage.contains("Relative Error must be non-negative"))
+  }
+
+  test("covariance") {
+    val session = spark
+    import session.implicits._
+
+    val df =
+      Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
+
+    val results = df.stat.cov("singles", "doubles")
+    assert(math.abs(results - 55.0 / 3) < 1e-12)
+    intercept[StatusRuntimeException] {
+      df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
+    }
+    val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
+    val decimalRes = decimalData.stat.cov("a", "b")
+    assert(math.abs(decimalRes) < 1e-12)
+  }
+
+  test("correlation") {
+    val session = spark
+    import session.implicits._
+
+    val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
+    val corr1 = df.stat.corr("a", "b", "pearson")
+    assert(math.abs(corr1 - 1.0) < 1e-12)
+    val corr2 = df.stat.corr("a", "c", "pearson")
+    assert(math.abs(corr2 + 1.0) < 1e-12)
+    val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
+    val corr3 = df2.stat.corr("a", "b", "pearson")
+    assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
+  }
+
+  test("crosstab") {
+    val session = spark
+    import session.implicits._
+
+    val rng = new Random()
+    val data = Seq.tabulate(25)(_ => (rng.nextInt(5), rng.nextInt(10)))
+    val df = data.toDF("a", "b")
+    val crosstab = df.stat.crosstab("a", "b")
+    val columnNames = crosstab.schema.fieldNames
+    assert(columnNames(0) === "a_b")
+    // reduce by key
+    val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length)
+    val rows = crosstab.collect()
+    rows.foreach { row =>
+      val i = row.getString(0).toInt
+      for (col <- 1 until columnNames.length) {
+        val j = columnNames(col).toInt
+        assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong)
+      }
+    }
+  }
+
+  test("freqItems") {
+    val session = spark
+    import session.implicits._
+
+    val rows = Seq.tabulate(1000) { i =>
+      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
+    }
+    val df = rows.toDF("numbers", "letters", "negDoubles")
+
+    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
+    val items = results.collect().head
+    assert(items.getSeq[Int](0).contains(1))
+    assert(items.getSeq[String](1).contains(toLetter(1)))
+
+    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
+    val items2 = singleColResults.collect().head
+    assert(items2.getSeq[Double](0).contains(-1.0))
+  }
+
+  test("sampleBy") {
+    val session = spark
+    import session.implicits._
+    val df = Seq("Bob", "Alice", "Nico", "Bob", "Alice").toDF("name")
+    val fractions = Map("Alice" -> 0.3, "Nico" -> 1.0)
+    val sampled = df.stat.sampleBy("name", fractions, 36L)
+    val rows = sampled.groupBy("name").count().orderBy("name").collect()
+    assert(rows.length == 1)
+    val row0 = rows(0)
+    assert(row0.getString(0) == "Nico")
+    assert(row0.getLong(1) == 1L)
+  }
+
+  test("countMinSketch") {
+    val df = spark.range(1000)
+
+    val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
+    assert(sketch1.totalCount() === 1000)
+    assert(sketch1.depth() === 10)
+    assert(sketch1.width() === 20)
+
+    val sketch = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
+    assert(sketch.totalCount() === 1000)
+    assert(sketch.relativeError() === 0.001)
+    assert(sketch.confidence() === 0.99 +- 5e-3)
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 3b69b02df5c..b59c783aec9 100755
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2046,4 +2046,16 @@ class PlanGenerationTestSuite
       .build()
     simple.select(Column(com.google.protobuf.Any.pack(extension)))
   }
+
+  test("crosstab") {
+    simple.stat.crosstab("a", "b")
+  }
+
+  test("freqItems") {
+    simple.stat.freqItems(Array("id", "a"), 0.1)
+  }
+
+  test("sampleBy") {
+    simple.stat.sampleBy("id", Map(0 -> 0.1, 1 -> 0.2), 0L)
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 548608f50b5..b684721243a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -109,6 +109,7 @@ object CheckConnectJvmClientCompatibility {
       IncludeByName("org.apache.spark.sql.ColumnName.*"),
       IncludeByName("org.apache.spark.sql.DataFrame.*"),
       IncludeByName("org.apache.spark.sql.DataFrameReader.*"),
+      IncludeByName("org.apache.spark.sql.DataFrameStatFunctions.*"),
       IncludeByName("org.apache.spark.sql.DataFrameWriter.*"),
       IncludeByName("org.apache.spark.sql.DataFrameWriterV2.*"),
       IncludeByName("org.apache.spark.sql.Dataset.*"),
@@ -133,6 +134,10 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.jdbc"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameWriter.jdbc"),
 
+      // DataFrameStatFunctions
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.this"),
+
       // Dataset
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.ofRows"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_TAG"),
@@ -144,7 +149,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.na"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.stat"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.joinWith"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.select"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/crosstab.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/crosstab.explain
new file mode 100644
index 00000000000..a30cd136e8d
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/crosstab.explain
@@ -0,0 +1,5 @@
+Project [a_b#0]
++- Project [a_b#0]
+   +- Aggregate [a_b#0], [a_b#0, pivotfirst(__pivot_col#0, count(1) AS count#0L, 0, 0) AS __pivot_count(1) AS count AS `count(1) AS count`#0]
+      +- Aggregate [CASE WHEN isnull(a#0) THEN null ELSE cast(a#0 as string) END, CASE WHEN isnull(b#0) THEN null ELSE regexp_replace(cast(b#0 as string), `, , 1) END], [CASE WHEN isnull(a#0) THEN null ELSE cast(a#0 as string) END AS a_b#0, CASE WHEN isnull(b#0) THEN null ELSE regexp_replace(cast(b#0 as string), `, , 1) END AS __pivot_col#0, count(1) AS count(1) AS count#0L]
+         +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/freqItems.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/freqItems.explain
new file mode 100644
index 00000000000..31ef46e2424
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/freqItems.explain
@@ -0,0 +1,2 @@
+Aggregate [collect_frequent_items(id#0L, 10, 0, 0) AS id_freqItems#0, collect_frequent_items(a#0, 10, 0, 0) AS a_freqItems#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/sampleBy.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/sampleBy.explain
new file mode 100644
index 00000000000..64abbcf1b53
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/sampleBy.explain
@@ -0,0 +1,2 @@
+Filter UDF(id#0L, rand(0))
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/crosstab.json b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.json
new file mode 100644
index 00000000000..755a6fa4dd2
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.json
@@ -0,0 +1,17 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "crosstab": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "col1": "a",
+    "col2": "b"
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/crosstab.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.proto.bin
new file mode 100644
index 00000000000..c664cedb01c
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/crosstab.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/freqItems.json b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.json
new file mode 100644
index 00000000000..8734722b354
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.json
@@ -0,0 +1,17 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "freqItems": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "cols": ["id", "a"],
+    "support": 0.1
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/freqItems.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.proto.bin
new file mode 100644
index 00000000000..717f3d61ae9
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/freqItems.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.json b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.json
new file mode 100644
index 00000000000..03fdd100753
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.json
@@ -0,0 +1,32 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "sampleBy": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "col": {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "id"
+      }
+    },
+    "fractions": [{
+      "stratum": {
+        "integer": 0
+      },
+      "fraction": 0.1
+    }, {
+      "stratum": {
+        "integer": 1
+      },
+      "fraction": 0.2
+    }],
+    "seed": "0"
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin
new file mode 100644
index 00000000000..29773f18e0e
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin differ

