You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/08 05:47:48 UTC

spark git commit: [SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas

Repository: spark
Updated Branches:
  refs/heads/master a0fe32a21 -> e1a172c20


[SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas

## What changes were proposed in this pull request?

Adds a method `summary` that allows the user to specify which statistics and percentiles to calculate. By default it includes the existing statistics from `describe` plus quartiles (25th, 50th, and 75th percentiles), similar to Pandas. Also changes the implementation of `describe` to delegate to `summary`.

## How was this patch tested?

Additional unit tests

Author: Andrew Ray <ra...@gmail.com>

Closes #18307 from aray/SPARK-21100.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1a172c2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1a172c2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1a172c2

Branch: refs/heads/master
Commit: e1a172c201d68406faa53b113518b10c879f1ff6
Parents: a0fe32a
Author: Andrew Ray <ra...@gmail.com>
Authored: Sat Jul 8 13:47:41 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Sat Jul 8 13:47:41 2017 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 113 ++++++++++++-------
 .../sql/execution/stat/StatFunctions.scala      |  98 +++++++++++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 112 ++++++++++++++----
 3 files changed, 258 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e1a172c2/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b1638a2..5326b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.CatalogRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
-import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
+import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -224,7 +224,7 @@ class Dataset[T] private[sql](
     }
   }
 
-  private def aggregatableColumns: Seq[Expression] = {
+  private[sql] def aggregatableColumns: Seq[Expression] = {
     schema.fields
       .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
       .map { n =>
@@ -2161,9 +2161,9 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Computes statistics for numeric and string columns, including count, mean, stddev, min, and
-   * max. If no columns are given, this function computes statistics for all numerical or string
-   * columns.
+   * Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
+   * and max. If no columns are given, this function computes statistics for all numerical or
+   * string columns.
    *
    * This function is meant for exploratory data analysis, as we make no guarantee about the
    * backward compatibility of the schema of the resulting Dataset. If you want to
@@ -2181,47 +2181,80 @@ class Dataset[T] private[sql](
    *   // max     92.0  192.0
    * }}}
    *
+   * Use [[summary]] for expanded statistics and control over which statistics to compute.
+   *
+   * @param cols Columns to compute statistics on.
+   *
    * @group action
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def describe(cols: String*): DataFrame = withPlan {
-
-    // The list of summary statistics to compute, in the form of expressions.
-    val statistics = List[(String, Expression => Expression)](
-      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
-      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
-      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
-      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
-      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
-
-    val outputCols =
-      (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList
-
-    val ret: Seq[Row] = if (outputCols.nonEmpty) {
-      val aggExprs = statistics.flatMap { case (_, colToAgg) =>
-        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
-      }
-
-      val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
-
-      // Pivot the data so each summary is one row
-      row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
-      }
-    } else {
-      // If there are no output columns, just output a single column that contains the stats.
-      statistics.map { case (name, _) => Row(name) }
-    }
-
-    // All columns are string type
-    val schema = StructType(
-      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
-    // `toArray` forces materialization to make the seq serializable
-    LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)
+  def describe(cols: String*): DataFrame = {
+    val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*)
+    selected.summary("count", "mean", "stddev", "min", "max")
   }
 
   /**
+   * Computes specified statistics for numeric and string columns. Available statistics are:
+   *
+   * - count
+   * - mean
+   * - stddev
+   * - min
+   * - max
+   * - arbitrary approximate percentiles specified as a percentage (e.g., 75%)
+   *
+   * If no statistics are given, this function computes count, mean, stddev, min,
+   * approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting Dataset. If you want to
+   * programmatically compute summary statistics, use the `agg` function instead.
+   *
+   * {{{
+   *   ds.summary().show()
+   *
+   *   // output:
+   *   // summary age   height
+   *   // count   10.0  10.0
+   *   // mean    53.3  178.05
+   *   // stddev  11.6  15.7
+   *   // min     18.0  163.0
+   *   // 25%     24.0  176.0
+   *   // 50%     24.0  176.0
+   *   // 75%     32.0  180.0
+   *   // max     92.0  192.0
+   * }}}
+   *
+   * {{{
+   *   ds.summary("count", "min", "25%", "75%", "max").show()
+   *
+   *   // output:
+   *   // summary age   height
+   *   // count   10.0  10.0
+   *   // min     18.0  163.0
+   *   // 25%     24.0  176.0
+   *   // 75%     32.0  180.0
+   *   // max     92.0  192.0
+   * }}}
+   *
+   * To compute a summary for specific columns, first select them:
+   *
+   * {{{
+   *   ds.select("age", "height").summary().show()
+   * }}}
+   *
+   * See also [[describe]] for basic statistics.
+   *
+   * @param statistics Statistics from the above list to be computed.
+   *
+   * @group action
+   * @since 2.3.0
+   */
+  @scala.annotation.varargs
+  def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
+
+  /**
    * Returns the first `n` rows.
    *
    * @note this method should only be used if the resulting array is expected to be small, as

http://git-wip-us.apache.org/repos/asf/spark/blob/e1a172c2/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 1debad0..436e18f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -220,4 +221,97 @@ object StatFunctions extends Logging {
 
     Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
   }
+
+  /** Calculate selected summary statistics for a dataset */
+  def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
+
+    val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+    val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+
+    val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
+    val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
+      val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
+      val percentiles = pStrings.map { p =>
+        try {
+          p.stripSuffix("%").toDouble / 100.0
+        } catch {
+          case e: NumberFormatException =>
+            throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
+        }
+      }
+      require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
+      (percentiles, pStrings, rest)
+    } else {
+      (Seq(), Seq(), selectedStatistics)
+    }
+
+
+    // The list of summary statistics to compute, in the form of expressions.
+    val availableStatistics = Map[String, Expression => Expression](
+      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
+      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
+      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
+      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
+      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
+
+    val statisticFns = remainingAggregates.map { agg =>
+      require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
+      agg -> availableStatistics(agg)
+    }
+
+    def percentileAgg(child: Expression): Expression =
+      new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
+        .toAggregateExpression()
+
+    val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList
+
+    val ret: Seq[Row] = if (outputCols.nonEmpty) {
+      var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
+        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
+      }
+      if (hasPercentiles) {
+        aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
+      }
+
+      val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+      // Pivot the data so each summary is one row
+      val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
+
+      val basicStats = if (hasPercentiles) grouped.tail else grouped
+
+      val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
+        Row(statistic :: aggregation.toList: _*)
+      }
+
+      if (hasPercentiles) {
+        def nullSafeString(x: Any) = if (x == null) null else x.toString
+        val percentileRows = grouped.head
+          .map {
+            case a: Seq[Any] => a
+            case _ => Seq.fill(percentiles.length)(null: Any)
+          }
+          .transpose
+          .zip(percentileNames)
+          .map { case (values: Seq[Any], name) =>
+            Row(name :: values.map(nullSafeString).toList: _*)
+          }
+        (rows ++ percentileRows)
+          .sortWith((left, right) =>
+            selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
+      } else {
+        rows
+      }
+    } else {
+      // If there are no output columns, just output a single column that contains the stats.
+      selectedStatistics.map(Row(_))
+    }
+
+    // All columns are string type
+    val schema = StructType(
+      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
+    // `toArray` forces materialization to make the seq serializable
+    Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e1a172c2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9ea9951..2c7051b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,8 +28,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -663,13 +662,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
-  test("describe") {
-    val describeTestData = Seq(
-      ("Bob", 16, 176),
-      ("Alice", 32, 164),
-      ("David", 60, 192),
-      ("Amy", 24, 180)).toDF("name", "age", "height")
+  private lazy val person2: DataFrame = Seq(
+    ("Bob", 16, 176),
+    ("Alice", 32, 164),
+    ("David", 60, 192),
+    ("Amy", 24, 180)).toDF("name", "age", "height")
 
+  test("describe") {
     val describeResult = Seq(
       Row("count", "4", "4", "4"),
       Row("mean", null, "33.0", "178.0"),
@@ -686,32 +685,99 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
     def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
-    val describeTwoCols = describeTestData.describe("name", "age", "height")
-    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height"))
-    checkAnswer(describeTwoCols, describeResult)
-    // All aggregate value should have been cast to string
-    describeTwoCols.collect().foreach { row =>
-      assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
-      assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass)
-    }
-
-    val describeAllCols = describeTestData.describe()
+    val describeAllCols = person2.describe()
     assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
     checkAnswer(describeAllCols, describeResult)
+    // All aggregate values should have been cast to string
+    describeAllCols.collect().foreach { row =>
+      row.toSeq.foreach { value =>
+        if (value != null) {
+          assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+        }
+      }
+    }
 
-    val describeOneCol = describeTestData.describe("age")
+    val describeOneCol = person2.describe("age")
     assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
     checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
 
-    val describeNoCol = describeTestData.select("name").describe()
-    assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name"))
-    checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} )
+    val describeNoCol = person2.select().describe()
+    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} )
 
-    val emptyDescription = describeTestData.limit(0).describe()
+    val emptyDescription = person2.limit(0).describe()
     assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
     checkAnswer(emptyDescription, emptyDescribeResult)
   }
 
+  test("summary") {
+    val summaryResult = Seq(
+      Row("count", "4", "4", "4"),
+      Row("mean", null, "33.0", "178.0"),
+      Row("stddev", null, "19.148542155126762", "11.547005383792516"),
+      Row("min", "Alice", "16", "164"),
+      Row("25%", null, "24.0", "176.0"),
+      Row("50%", null, "24.0", "176.0"),
+      Row("75%", null, "32.0", "180.0"),
+      Row("max", "David", "60", "192"))
+
+    val emptySummaryResult = Seq(
+      Row("count", "0", "0", "0"),
+      Row("mean", null, null, null),
+      Row("stddev", null, null, null),
+      Row("min", null, null, null),
+      Row("25%", null, null, null),
+      Row("50%", null, null, null),
+      Row("75%", null, null, null),
+      Row("max", null, null, null))
+
+    def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
+
+    val summaryAllCols = person2.summary()
+
+    assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height"))
+    checkAnswer(summaryAllCols, summaryResult)
+    // All aggregate values should have been cast to string
+    summaryAllCols.collect().foreach { row =>
+      row.toSeq.foreach { value =>
+        if (value != null) {
+          assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+        }
+      }
+    }
+
+    val summaryOneCol = person2.select("age").summary()
+    assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age"))
+    checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} )
+
+    val summaryNoCol = person2.select().summary()
+    assert(getSchemaAsSeq(summaryNoCol) === Seq("summary"))
+    checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} )
+
+    val emptyDescription = person2.limit(0).summary()
+    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
+    checkAnswer(emptyDescription, emptySummaryResult)
+  }
+
+  test("summary advanced") {
+    val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
+    val orderMatters = person2.summary(stats: _*)
+    assert(orderMatters.collect().map(_.getString(0)) === stats)
+
+    val onlyPercentiles = person2.summary("0.1%", "99.9%")
+    assert(onlyPercentiles.count() === 2)
+
+    val fooE = intercept[IllegalArgumentException] {
+      person2.summary("foo")
+    }
+    assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")
+
+    val parseE = intercept[IllegalArgumentException] {
+      person2.summary("foo%")
+    }
+    assert(parseE.getMessage === "Unable to parse foo% as a percentile")
+  }
+
   test("apply on query results (SPARK-5462)") {
     val df = testData.sparkSession.sql("select key from testData")
     checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org