Posted to commits@spark.apache.org by sr...@apache.org on 2022/08/25 13:46:44 UTC

[spark] branch master updated: [SPARK-40192][SQL][ML] Remove redundant groupby

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dfd4fe95744 [SPARK-40192][SQL][ML] Remove redundant groupby
dfd4fe95744 is described below

commit dfd4fe957442d41f39e8b3f223ee5cc9adfa6b79
Author: Deshan Xiao <de...@microsoft.com>
AuthorDate: Thu Aug 25 08:46:24 2022 -0500

    [SPARK-40192][SQL][ML] Remove redundant groupby
    
    ### What changes were proposed in this pull request?
    Remove redundant `groupBy()` invocations from the code.
    
    ### Why are the changes needed?
    Code cleanup: `Dataset.agg()` already invokes `groupBy()` internally, so there is no need to call `groupBy()` explicitly before `agg()`.
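
    A minimal sketch of the equivalence being relied on (the `SparkSession` setup, data, and column name below are illustrative, not taken from the patch):

    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.max

    object RedundantGroupByExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("redundant-groupby-demo")
          .getOrCreate()
        val df = spark.range(10).toDF("id")

        // Global aggregation with an explicit, now-redundant groupBy() ...
        val withGroupBy = df.groupBy().agg(max("id"))
        // ... and the shorthand: Dataset.agg() calls groupBy() with no columns internally.
        val shorthand = df.agg(max("id"))

        withGroupBy.show()  // max(id) = 9
        shorthand.show()    // max(id) = 9, same plan and result
        spark.stop()
      }
    }
    ```

    Both calls produce the same global aggregation, which is why dropping the explicit `groupBy()` is a pure cleanup with no behavior change.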
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #37628 from deshanxiao/remove-group-by.
    
    Authored-by: Deshan Xiao <de...@microsoft.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../apache/spark/ml/feature/StringIndexer.scala    |  2 +-
 .../scala/org/apache/spark/ml/stat/ANOVATest.scala |  3 +--
 .../org/apache/spark/ml/stat/ChiSquareTest.scala   |  3 +--
 .../org/apache/spark/ml/stat/FValueTest.scala      |  3 +--
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 +++++-----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  2 +-
 .../sql/execution/WholeStageCodegenSuite.scala     |  2 +-
 .../execution/benchmark/AggregateBenchmark.scala   |  4 ++--
 .../sql/execution/metric/SQLMetricsSuite.scala     |  2 +-
 .../sql/hive/execution/AggregationQuerySuite.scala | 26 +++++++++++-----------
 10 files changed, 28 insertions(+), 31 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 98a42371d29..4f11c58a7dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -200,7 +200,7 @@ class StringIndexer @Since("1.4.0") (
     val selectedCols = getSelectedCols(dataset, inputCols)
     dataset.select(selectedCols: _*)
       .toDF
-      .groupBy().agg(aggregator.toColumn)
+      .agg(aggregator.toColumn)
       .as[Array[OpenHashMap[String, Long]]]
       .collect()(0)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
index 7a7e76c457d..d7b13f1bf25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
@@ -75,8 +75,7 @@ private[ml] object ANOVATest {
     if (flatten) {
       resultDF
     } else {
-      resultDF.groupBy()
-        .agg(collect_list(struct("*")))
+      resultDF.agg(collect_list(struct("*")))
         .as[Seq[(Int, Double, Long, Double)]]
         .map { seq =>
           val results = seq.toArray.sortBy(_._1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
index a38a7c446ac..e97d007a0f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
@@ -89,8 +89,7 @@ object ChiSquareTest {
     if (flatten) {
       resultDF
     } else {
-      resultDF.groupBy()
-        .agg(collect_list(struct("*")))
+      resultDF.agg(collect_list(struct("*")))
         .as[Seq[(Int, Double, Int, Double)]]
         .map { seq =>
           val results = seq.toArray.sortBy(_._1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala
index f315e92e86d..800c68d3b0d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala
@@ -76,8 +76,7 @@ private[ml] object FValueTest {
     if (flatten) {
       resultDF
     } else {
-      resultDF.groupBy()
-        .agg(collect_list(struct("*")))
+      resultDF.agg(collect_list(struct("*")))
         .as[Seq[(Int, Double, Long, Double)]]
         .map { seq =>
           val results = seq.toArray.sortBy(_._1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 958b3e3f53c..4ab509b5e01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -191,11 +191,11 @@ class DataFrameAggregateSuite extends QueryTest
     )
 
     intercept[AnalysisException] {
-      courseSales.groupBy().agg(grouping("course")).explain()
+      courseSales.agg(grouping("course")).explain()
     }
 
     intercept[AnalysisException] {
-      courseSales.groupBy().agg(grouping_id("course")).explain()
+      courseSales.agg(grouping_id("course")).explain()
     }
   }
 
@@ -755,11 +755,11 @@ class DataFrameAggregateSuite extends QueryTest
     // explicit global aggregations
     val emptyAgg = Map.empty[String, String]
     checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
-    checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
-    checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0)))
+    checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
+    checkAnswer(spark.emptyDataFrame.agg(count("*")), Seq(Row(0)))
+    checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
     checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
-    checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row()))
-    checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0)))
+    checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(count("*")), Seq(Row(0)))
 
     // global aggregation is converted to grouping aggregation:
     assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cc7e51abc4e..cbd65ede054 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1923,7 +1923,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
         x
       })
       verifyCallCount(
-        df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+        df.agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
 
       verifyCallCount(
         df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index eca22b14763..ac710c32296 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -42,7 +42,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }
 
   test("HashAggregate should be included in WholeStageCodegen") {
-    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
     val plan = df.queryExecution.executedPlan
     assert(plan.exists(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index ae4281cd639..b2f1ee31f9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -57,11 +57,11 @@ object AggregateBenchmark extends SqlBasedBenchmark {
       val N = 100L << 20
 
       codegenBenchmark("stddev", N) {
-        spark.range(N).groupBy().agg("id" -> "stddev").noop()
+        spark.range(N).agg("id" -> "stddev").noop()
       }
 
       codegenBenchmark("kurtosis", N) {
-        spark.range(N).groupBy().agg("id" -> "kurtosis").noop()
+        spark.range(N).agg("id" -> "kurtosis").noop()
       }
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 9a63572069d..f5cfbbf5a65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -206,7 +206,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     // Assume the execution plan is
     // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1)
     // -> ObjectHashAggregate(nodeId = 0)
-    val df = testData2.groupBy().agg(collect_set($"a")) // 2 partitions
+    val df = testData2.agg(collect_set($"a")) // 2 partitions
     testSparkPlanMetrics(df, 1, Map(
       2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))),
       1L -> (("Exchange", Map(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index e63cdddd81c..1966e1e64fd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -751,9 +751,9 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
 
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
-    val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+    val corr1 = df.repartition(2).agg(corr("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(corr1 - 1.0) < 1e-12)
-    val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+    val corr2 = df.agg(corr("a", "c")).collect()(0).getDouble(0)
     assert(math.abs(corr2 + 1.0) < 1e-12)
     // non-trivial example. To reproduce in python, use:
     // >>> from scipy.stats import pearsonr
@@ -768,17 +768,17 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     // > cor(a, b)
     // [1] 0.957233913947585835
     val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
-    val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+    val corr3 = df2.agg(corr("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
 
     val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
-    val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0)
+    val corr4 = df3.agg(corr("a", "b")).collect()(0)
     assert(corr4 == Row(null))
 
     val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
-    val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+    val corr5 = df4.repartition(2).agg(corr("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(corr5 - 1.0) < 1e-12)
-    val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+    val corr6 = df4.agg(corr("a", "c")).collect()(0).getDouble(0)
     assert(math.abs(corr6 + 1.0) < 1e-12)
 
     // Test for udaf_corr in HiveCompatibilitySuite
@@ -855,23 +855,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     // >>> np.cov(a, b, bias = 1)[0][1]
     // 565.25
     val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
-    val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+    val cov_samp = df.agg(covar_samp("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(cov_samp - 595.0) < 1e-12)
 
-    val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    val cov_pop = df.agg(covar_pop("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(cov_pop - 565.25) < 1e-12)
 
     val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
-    val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+    val cov_samp2 = df2.agg(covar_samp("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
 
-    val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    val cov_pop2 = df2.agg(covar_pop("a", "b")).collect()(0).getDouble(0)
     assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
 
     // one row test
     val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
-    checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(null))
-    checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0))
+    checkAnswer(df3.agg(covar_samp("a", "b")), Row(null))
+    checkAnswer(df3.agg(covar_pop("a", "b")), Row(0.0))
   }
 
   test("no aggregation function (SPARK-11486)") {
@@ -938,7 +938,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
           .find(r => r.getInt(0) == 50)
           .getOrElse(fail("A row with id 50 should be the expected answer."))
       checkAnswer(
-        df.groupBy().agg(udaf(allColumns: _*)),
+        df.agg(udaf(allColumns: _*)),
         // udaf returns a Row as the output value.
         Row(expectedAnswer)
       )

