You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/07/10 05:53:32 UTC

spark git commit: [SPARK-21100][SQL][FOLLOWUP] cleanup code and add more comments for Dataset.summary

Repository: spark
Updated Branches:
  refs/heads/master 457dc9ccb -> 0e80ecae3


[SPARK-21100][SQL][FOLLOWUP] cleanup code and add more comments for Dataset.summary

## What changes were proposed in this pull request?

Some code cleanup, plus added comments to make the code more readable. Changed how the result rows are generated to make the logic clearer.

## How was this patch tested?

existing tests

Author: Wenchen Fan <we...@databricks.com>

Closes #18570 from cloud-fan/summary.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e80ecae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e80ecae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e80ecae

Branch: refs/heads/master
Commit: 0e80ecae300f3e2033419b2d98da8bf092c105bb
Parents: 457dc9c
Author: Wenchen Fan <we...@databricks.com>
Authored: Sun Jul 9 22:53:27 2017 -0700
Committer: Xiao Li <ga...@gmail.com>
Committed: Sun Jul 9 22:53:27 2017 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    |   9 --
 .../sql/execution/stat/StatFunctions.scala      | 129 ++++++++-----------
 .../org/apache/spark/sql/DataFrameSuite.scala   |   2 +-
 3 files changed, 56 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0e80ecae/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5326b45..dfb5119 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -224,15 +224,6 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def aggregatableColumns: Seq[Expression] = {
-    schema.fields
-      .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
-      .map { n =>
-        queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver)
-          .get
-      }
-  }
-
   /**
    * Compose the string representing rows for output
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/0e80ecae/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 436e18f..a75cfb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.execution.stat
 
+import java.util.Locale
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -228,90 +231,68 @@ object StatFunctions extends Logging {
     val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
     val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
 
-    val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
-    val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
-      val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
-      val percentiles = pStrings.map { p =>
-        try {
-          p.stripSuffix("%").toDouble / 100.0
-        } catch {
-          case e: NumberFormatException =>
-            throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
-        }
+    val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
+      try {
+        p.stripSuffix("%").toDouble / 100.0
+      } catch {
+        case e: NumberFormatException =>
+          throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
       }
-      require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
-      (percentiles, pStrings, rest)
-    } else {
-      (Seq(), Seq(), selectedStatistics)
     }
+    require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
-
-    // The list of summary statistics to compute, in the form of expressions.
-    val availableStatistics = Map[String, Expression => Expression](
-      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
-      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
-      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
-      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
-      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
-
-    val statisticFns = remainingAggregates.map { agg =>
-      require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
-      agg -> availableStatistics(agg)
-    }
-
-    def percentileAgg(child: Expression): Expression =
-      new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
-        .toAggregateExpression()
-
-    val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList
-
-    val ret: Seq[Row] = if (outputCols.nonEmpty) {
-      var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
-        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
-      }
-      if (hasPercentiles) {
-        aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
+    var percentileIndex = 0
+    val statisticFns = selectedStatistics.map { stats =>
+      if (stats.endsWith("%")) {
+        val index = percentileIndex
+        percentileIndex += 1
+        (child: Expression) =>
+          GetArrayItem(
+            new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(),
+            Literal(index))
+      } else {
+        stats.toLowerCase(Locale.ROOT) match {
+          case "count" => (child: Expression) => Count(child).toAggregateExpression()
+          case "mean" => (child: Expression) => Average(child).toAggregateExpression()
+          case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression()
+          case "min" => (child: Expression) => Min(child).toAggregateExpression()
+          case "max" => (child: Expression) => Max(child).toAggregateExpression()
+          case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic")
+        }
       }
+    }
 
-      val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+    val selectedCols = ds.logicalPlan.output
+      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
 
-      // Pivot the data so each summary is one row
-      val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
+    val aggExprs = statisticFns.flatMap { func =>
+      selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
+    }
 
-      val basicStats = if (hasPercentiles) grouped.tail else grouped
+    // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
+    lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
 
-      val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
-      }
+    // We will have one row for each selected statistic in the result.
+    val result = Array.fill[InternalRow](selectedStatistics.length) {
+      // each row has the statistic name, and statistic values of each selected column.
+      new GenericInternalRow(selectedCols.length + 1)
+    }
 
-      if (hasPercentiles) {
-        def nullSafeString(x: Any) = if (x == null) null else x.toString
-        val percentileRows = grouped.head
-          .map {
-            case a: Seq[Any] => a
-            case _ => Seq.fill(percentiles.length)(null: Any)
-          }
-          .transpose
-          .zip(percentileNames)
-          .map { case (values: Seq[Any], name) =>
-            Row(name :: values.map(nullSafeString).toList: _*)
-          }
-        (rows ++ percentileRows)
-          .sortWith((left, right) =>
-            selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
-      } else {
-        rows
+    var rowIndex = 0
+    while (rowIndex < result.length) {
+      val statsName = selectedStatistics(rowIndex)
+      result(rowIndex).update(0, UTF8String.fromString(statsName))
+      for (colIndex <- selectedCols.indices) {
+        val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
+        result(rowIndex).update(colIndex + 1, statsValue)
       }
-    } else {
-      // If there are no output columns, just output a single column that contains the stats.
-      selectedStatistics.map(Row(_))
+      rowIndex += 1
     }
 
     // All columns are string type
-    val schema = StructType(
-      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
-    // `toArray` forces materialization to make the seq serializable
-    Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
-  }
+    val output = AttributeReference("summary", StringType)() +:
+      selectedCols.map(c => AttributeReference(c.name, StringType)())
 
+    Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0e80ecae/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2c7051b..b2219b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -770,7 +770,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val fooE = intercept[IllegalArgumentException] {
       person2.summary("foo")
     }
-    assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")
+    assert(fooE.getMessage === "foo is not a recognised statistic")
 
     val parseE = intercept[IllegalArgumentException] {
       person2.summary("foo%")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org