Posted to commits@spark.apache.org by we...@apache.org on 2017/07/08 05:47:48 UTC
spark git commit: [SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas
Repository: spark
Updated Branches:
refs/heads/master a0fe32a21 -> e1a172c20
[SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas
## What changes were proposed in this pull request?
Adds a `summary` method that allows the user to specify which statistics and percentiles to calculate. By default it includes the existing statistics from `describe` plus quartiles (the 25th, 50th, and 75th percentiles), similar to Pandas. Also changes the implementation of `describe` to delegate to `summary`.
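For example (a usage sketch based on the Scaladoc examples added in this patch; `ds` is a hypothetical Dataset with numeric `age` and `height` columns):

    // Default: count, mean, stddev, min, approximate quartiles (25%, 50%, 75%), and max
    ds.summary().show()

    // Pick specific statistics, including arbitrary approximate percentiles
    ds.summary("count", "min", "25%", "75%", "max").show()

    // describe() keeps its previous five statistics, now delegating to summary()
    ds.describe("age", "height").show()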
## How was this patch tested?
Additional unit tests.
Author: Andrew Ray <ra...@gmail.com>
Closes #18307 from aray/SPARK-21100.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1a172c2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1a172c2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1a172c2
Branch: refs/heads/master
Commit: e1a172c201d68406faa53b113518b10c879f1ff6
Parents: a0fe32a
Author: Andrew Ray <ra...@gmail.com>
Authored: Sat Jul 8 13:47:41 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Sat Jul 8 13:47:41 2017 +0800
----------------------------------------------------------------------
.../scala/org/apache/spark/sql/Dataset.scala | 113 ++++++++++++-------
.../sql/execution/stat/StatFunctions.scala | 98 +++++++++++++++-
.../org/apache/spark/sql/DataFrameSuite.scala | 112 ++++++++++++++----
3 files changed, 258 insertions(+), 65 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e1a172c2/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b1638a2..5326b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.CatalogRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
-import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
+import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -224,7 +224,7 @@ class Dataset[T] private[sql](
}
}
- private def aggregatableColumns: Seq[Expression] = {
+ private[sql] def aggregatableColumns: Seq[Expression] = {
schema.fields
.filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
.map { n =>
@@ -2161,9 +2161,9 @@ class Dataset[T] private[sql](
}
/**
- * Computes statistics for numeric and string columns, including count, mean, stddev, min, and
- * max. If no columns are given, this function computes statistics for all numerical or string
- * columns.
+ * Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
+ * and max. If no columns are given, this function computes statistics for all numerical or
+ * string columns.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
@@ -2181,47 +2181,80 @@ class Dataset[T] private[sql](
* // max 92.0 192.0
* }}}
*
+ * Use [[summary]] for expanded statistics and control over which statistics to compute.
+ *
+ * @param cols Columns to compute statistics on.
+ *
* @group action
* @since 1.6.0
*/
@scala.annotation.varargs
- def describe(cols: String*): DataFrame = withPlan {
-
- // The list of summary statistics to compute, in the form of expressions.
- val statistics = List[(String, Expression => Expression)](
- "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
- "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
- "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
- "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
- "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
-
- val outputCols =
- (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList
-
- val ret: Seq[Row] = if (outputCols.nonEmpty) {
- val aggExprs = statistics.flatMap { case (_, colToAgg) =>
- outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
- }
-
- val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
-
- // Pivot the data so each summary is one row
- row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
- Row(statistic :: aggregation.toList: _*)
- }
- } else {
- // If there are no output columns, just output a single column that contains the stats.
- statistics.map { case (name, _) => Row(name) }
- }
-
- // All columns are string type
- val schema = StructType(
- StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
- // `toArray` forces materialization to make the seq serializable
- LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)
+ def describe(cols: String*): DataFrame = {
+ val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*)
+ selected.summary("count", "mean", "stddev", "min", "max")
}
/**
+ * Computes specified statistics for numeric and string columns. Available statistics are:
+ *
+ * - count
+ * - mean
+ * - stddev
+ * - min
+ * - max
+ * - arbitrary approximate percentiles specified as a percentage (e.g., 75%)
+ *
+ * If no statistics are given, this function computes count, mean, stddev, min,
+ * approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting Dataset. If you want to
+ * programmatically compute summary statistics, use the `agg` function instead.
+ *
+ * {{{
+ * ds.summary().show()
+ *
+ * // output:
+ * // summary age height
+ * // count 10.0 10.0
+ * // mean 53.3 178.05
+ * // stddev 11.6 15.7
+ * // min 18.0 163.0
+ * // 25% 24.0 176.0
+ * // 50% 24.0 176.0
+ * // 75% 32.0 180.0
+ * // max 92.0 192.0
+ * }}}
+ *
+ * {{{
+ * ds.summary("count", "min", "25%", "75%", "max").show()
+ *
+ * // output:
+ * // summary age height
+ * // count 10.0 10.0
+ * // min 18.0 163.0
+ * // 25% 24.0 176.0
+ * // 75% 32.0 180.0
+ * // max 92.0 192.0
+ * }}}
+ *
+ * To do a summary for specific columns, first select them:
+ *
+ * {{{
+ * ds.select("age", "height").summary().show()
+ * }}}
+ *
+ * See also [[describe]] for basic statistics.
+ *
+ * @param statistics Statistics from the above list to be computed.
+ *
+ * @group action
+ * @since 2.3.0
+ */
+ @scala.annotation.varargs
+ def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
+
+ /**
* Returns the first `n` rows.
*
* @note this method should only be used if the resulting array is expected to be small, as
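With the delegation above, a `describe` call is equivalent to selecting the columns and requesting the original five statistics from `summary`. A sketch (`ds`, `age`, and `height` are hypothetical):

    // These two produce the same result under the new implementation:
    ds.describe("age", "height")
    ds.select("age", "height").summary("count", "mean", "stddev", "min", "max")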
http://git-wip-us.apache.org/repos/asf/spark/blob/e1a172c2/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 1debad0..436e18f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -220,4 +221,97 @@ object StatFunctions extends Logging {
Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
+
+ /** Calculate selected summary statistics for a dataset */
+ def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
+
+ val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+ val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+
+ val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
+ val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
+ val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
+ val percentiles = pStrings.map { p =>
+ try {
+ p.stripSuffix("%").toDouble / 100.0
+ } catch {
+ case e: NumberFormatException =>
+ throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
+ }
+ }
+ require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
+ (percentiles, pStrings, rest)
+ } else {
+ (Seq(), Seq(), selectedStatistics)
+ }
+
+
+ // The list of summary statistics to compute, in the form of expressions.
+ val availableStatistics = Map[String, Expression => Expression](
+ "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
+ "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
+ "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
+ "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
+ "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
+
+ val statisticFns = remainingAggregates.map { agg =>
+ require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
+ agg -> availableStatistics(agg)
+ }
+
+ def percentileAgg(child: Expression): Expression =
+ new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
+ .toAggregateExpression()
+
+ val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList
+
+ val ret: Seq[Row] = if (outputCols.nonEmpty) {
+ var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
+ outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
+ }
+ if (hasPercentiles) {
+ aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
+ }
+
+ val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+ // Pivot the data so each summary is one row
+ val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
+
+ val basicStats = if (hasPercentiles) grouped.tail else grouped
+
+ val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
+ Row(statistic :: aggregation.toList: _*)
+ }
+
+ if (hasPercentiles) {
+ def nullSafeString(x: Any) = if (x == null) null else x.toString
+ val percentileRows = grouped.head
+ .map {
+ case a: Seq[Any] => a
+ case _ => Seq.fill(percentiles.length)(null: Any)
+ }
+ .transpose
+ .zip(percentileNames)
+ .map { case (values: Seq[Any], name) =>
+ Row(name :: values.map(nullSafeString).toList: _*)
+ }
+ (rows ++ percentileRows)
+ .sortWith((left, right) =>
+ selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
+ } else {
+ rows
+ }
+ } else {
+ // If there are no output columns, just output a single column that contains the stats.
+ selectedStatistics.map(Row(_))
+ }
+
+ // All columns are string type
+ val schema = StructType(
+ StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
+ // `toArray` forces materialization to make the seq serializable
+ Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
+ }
+
}
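The percentile handling in `StatFunctions.summary` above reduces to stripping the `%` suffix, scaling into [0, 1], and validating the range. A standalone sketch of that step (`parsePercentile` is a hypothetical helper, not part of the patch):

    def parsePercentile(s: String): Double = {
      // Mirror of the parsing logic above: "25%" -> 0.25
      val p = try {
        s.stripSuffix("%").toDouble / 100.0
      } catch {
        case e: NumberFormatException =>
          throw new IllegalArgumentException(s"Unable to parse $s as a percentile", e)
      }
      require(p >= 0 && p <= 1, "Percentiles must be in the range [0, 1]")
      p
    }

    parsePercentile("25%")    // 0.25
    parsePercentile("50.01%") // 0.5001
    // parsePercentile("foo%") throws IllegalArgumentException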
http://git-wip-us.apache.org/repos/asf/spark/blob/e1a172c2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9ea9951..2c7051b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,8 +28,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -663,13 +662,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
- test("describe") {
- val describeTestData = Seq(
- ("Bob", 16, 176),
- ("Alice", 32, 164),
- ("David", 60, 192),
- ("Amy", 24, 180)).toDF("name", "age", "height")
+ private lazy val person2: DataFrame = Seq(
+ ("Bob", 16, 176),
+ ("Alice", 32, 164),
+ ("David", 60, 192),
+ ("Amy", 24, 180)).toDF("name", "age", "height")
+ test("describe") {
val describeResult = Seq(
Row("count", "4", "4", "4"),
Row("mean", null, "33.0", "178.0"),
@@ -686,32 +685,99 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
- val describeTwoCols = describeTestData.describe("name", "age", "height")
- assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height"))
- checkAnswer(describeTwoCols, describeResult)
- // All aggregate value should have been cast to string
- describeTwoCols.collect().foreach { row =>
- assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
- assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass)
- }
-
- val describeAllCols = describeTestData.describe()
+ val describeAllCols = person2.describe()
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
checkAnswer(describeAllCols, describeResult)
+ // All aggregate values should have been cast to string
+ describeAllCols.collect().foreach { row =>
+ row.toSeq.foreach { value =>
+ if (value != null) {
+ assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+ }
+ }
+ }
- val describeOneCol = describeTestData.describe("age")
+ val describeOneCol = person2.describe("age")
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
- val describeNoCol = describeTestData.select("name").describe()
- assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name"))
- checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} )
+ val describeNoCol = person2.select().describe()
+ assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+ checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} )
- val emptyDescription = describeTestData.limit(0).describe()
+ val emptyDescription = person2.limit(0).describe()
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
checkAnswer(emptyDescription, emptyDescribeResult)
}
+ test("summary") {
+ val summaryResult = Seq(
+ Row("count", "4", "4", "4"),
+ Row("mean", null, "33.0", "178.0"),
+ Row("stddev", null, "19.148542155126762", "11.547005383792516"),
+ Row("min", "Alice", "16", "164"),
+ Row("25%", null, "24.0", "176.0"),
+ Row("50%", null, "24.0", "176.0"),
+ Row("75%", null, "32.0", "180.0"),
+ Row("max", "David", "60", "192"))
+
+ val emptySummaryResult = Seq(
+ Row("count", "0", "0", "0"),
+ Row("mean", null, null, null),
+ Row("stddev", null, null, null),
+ Row("min", null, null, null),
+ Row("25%", null, null, null),
+ Row("50%", null, null, null),
+ Row("75%", null, null, null),
+ Row("max", null, null, null))
+
+ def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
+
+ val summaryAllCols = person2.summary()
+
+ assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height"))
+ checkAnswer(summaryAllCols, summaryResult)
+ // All aggregate values should have been cast to string
+ summaryAllCols.collect().foreach { row =>
+ row.toSeq.foreach { value =>
+ if (value != null) {
+ assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+ }
+ }
+ }
+
+ val summaryOneCol = person2.select("age").summary()
+ assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age"))
+ checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} )
+
+ val summaryNoCol = person2.select().summary()
+ assert(getSchemaAsSeq(summaryNoCol) === Seq("summary"))
+ checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} )
+
+ val emptyDescription = person2.limit(0).summary()
+ assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
+ checkAnswer(emptyDescription, emptySummaryResult)
+ }
+
+ test("summary advanced") {
+ val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
+ val orderMatters = person2.summary(stats: _*)
+ assert(orderMatters.collect().map(_.getString(0)) === stats)
+
+ val onlyPercentiles = person2.summary("0.1%", "99.9%")
+ assert(onlyPercentiles.count() === 2)
+
+ val fooE = intercept[IllegalArgumentException] {
+ person2.summary("foo")
+ }
+ assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")
+
+ val parseE = intercept[IllegalArgumentException] {
+ person2.summary("foo%")
+ }
+ assert(parseE.getMessage === "Unable to parse foo% as a percentile")
+ }
+
test("apply on query results (SPARK-5462)") {
val df = testData.sparkSession.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
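Note that the quartiles asserted in the tests above are produced by `ApproximatePercentile`, so they are approximate by construction. For programmatic quantiles with an explicit error bound, the pre-existing `approxQuantile` API covers the same ground; a sketch using the `person2` data from the test:

    // relativeError = 0.0 requests exact quantiles at the cost of more memory
    val Array(q25, q50, q75) =
      person2.stat.approxQuantile("age", Array(0.25, 0.5, 0.75), 0.0)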