You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2022/08/25 13:46:44 UTC
[spark] branch master updated: [SPARK-40192][SQL][ML] Remove redundant groupby
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dfd4fe95744 [SPARK-40192][SQL][ML] Remove redundant groupby
dfd4fe95744 is described below
commit dfd4fe957442d41f39e8b3f223ee5cc9adfa6b79
Author: Deshan Xiao <de...@microsoft.com>
AuthorDate: Thu Aug 25 08:46:24 2022 -0500
[SPARK-40192][SQL][ML] Remove redundant groupby
### What changes were proposed in this pull request?
Remove redundant `groupBy()` calls that immediately precede `agg()`.
### Why are the changes needed?
Code simplification.
`Dataset.agg()` already calls `groupBy()` internally (`ds.agg(...)` is shorthand for `ds.groupBy().agg(...)`), so there is no need to call `groupBy()` explicitly before `agg()`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing unit tests.
Closes #37628 from deshanxiao/remove-group-by.
Authored-by: Deshan Xiao <de...@microsoft.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
---
.../apache/spark/ml/feature/StringIndexer.scala | 2 +-
.../scala/org/apache/spark/ml/stat/ANOVATest.scala | 3 +--
.../org/apache/spark/ml/stat/ChiSquareTest.scala | 3 +--
.../org/apache/spark/ml/stat/FValueTest.scala | 3 +--
.../apache/spark/sql/DataFrameAggregateSuite.scala | 12 +++++-----
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +-
.../sql/execution/WholeStageCodegenSuite.scala | 2 +-
.../execution/benchmark/AggregateBenchmark.scala | 4 ++--
.../sql/execution/metric/SQLMetricsSuite.scala | 2 +-
.../sql/hive/execution/AggregationQuerySuite.scala | 26 +++++++++++-----------
10 files changed, 28 insertions(+), 31 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 98a42371d29..4f11c58a7dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -200,7 +200,7 @@ class StringIndexer @Since("1.4.0") (
val selectedCols = getSelectedCols(dataset, inputCols)
dataset.select(selectedCols: _*)
.toDF
- .groupBy().agg(aggregator.toColumn)
+ .agg(aggregator.toColumn)
.as[Array[OpenHashMap[String, Long]]]
.collect()(0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
index 7a7e76c457d..d7b13f1bf25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
@@ -75,8 +75,7 @@ private[ml] object ANOVATest {
if (flatten) {
resultDF
} else {
- resultDF.groupBy()
- .agg(collect_list(struct("*")))
+ resultDF.agg(collect_list(struct("*")))
.as[Seq[(Int, Double, Long, Double)]]
.map { seq =>
val results = seq.toArray.sortBy(_._1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
index a38a7c446ac..e97d007a0f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
@@ -89,8 +89,7 @@ object ChiSquareTest {
if (flatten) {
resultDF
} else {
- resultDF.groupBy()
- .agg(collect_list(struct("*")))
+ resultDF.agg(collect_list(struct("*")))
.as[Seq[(Int, Double, Int, Double)]]
.map { seq =>
val results = seq.toArray.sortBy(_._1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala
index f315e92e86d..800c68d3b0d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala
@@ -76,8 +76,7 @@ private[ml] object FValueTest {
if (flatten) {
resultDF
} else {
- resultDF.groupBy()
- .agg(collect_list(struct("*")))
+ resultDF.agg(collect_list(struct("*")))
.as[Seq[(Int, Double, Long, Double)]]
.map { seq =>
val results = seq.toArray.sortBy(_._1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 958b3e3f53c..4ab509b5e01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -191,11 +191,11 @@ class DataFrameAggregateSuite extends QueryTest
)
intercept[AnalysisException] {
- courseSales.groupBy().agg(grouping("course")).explain()
+ courseSales.agg(grouping("course")).explain()
}
intercept[AnalysisException] {
- courseSales.groupBy().agg(grouping_id("course")).explain()
+ courseSales.agg(grouping_id("course")).explain()
}
}
@@ -755,11 +755,11 @@ class DataFrameAggregateSuite extends QueryTest
// explicit global aggregations
val emptyAgg = Map.empty[String, String]
checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
- checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
- checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0)))
+ checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
+ checkAnswer(spark.emptyDataFrame.agg(count("*")), Seq(Row(0)))
+ checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
- checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row()))
- checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0)))
+ checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(count("*")), Seq(Row(0)))
// global aggregation is converted to grouping aggregation:
assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cc7e51abc4e..cbd65ede054 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1923,7 +1923,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
x
})
verifyCallCount(
- df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+ df.agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
verifyCallCount(
df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index eca22b14763..ac710c32296 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -42,7 +42,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
test("HashAggregate should be included in WholeStageCodegen") {
- val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+ val df = spark.range(10).agg(max(col("id")), avg(col("id")))
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index ae4281cd639..b2f1ee31f9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -57,11 +57,11 @@ object AggregateBenchmark extends SqlBasedBenchmark {
val N = 100L << 20
codegenBenchmark("stddev", N) {
- spark.range(N).groupBy().agg("id" -> "stddev").noop()
+ spark.range(N).agg("id" -> "stddev").noop()
}
codegenBenchmark("kurtosis", N) {
- spark.range(N).groupBy().agg("id" -> "kurtosis").noop()
+ spark.range(N).agg("id" -> "kurtosis").noop()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 9a63572069d..f5cfbbf5a65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -206,7 +206,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
// Assume the execution plan is
// ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1)
// -> ObjectHashAggregate(nodeId = 0)
- val df = testData2.groupBy().agg(collect_set($"a")) // 2 partitions
+ val df = testData2.agg(collect_set($"a")) // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))),
1L -> (("Exchange", Map(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index e63cdddd81c..1966e1e64fd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -751,9 +751,9 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("pearson correlation") {
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
- val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ val corr1 = df.repartition(2).agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr1 - 1.0) < 1e-12)
- val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+ val corr2 = df.agg(corr("a", "c")).collect()(0).getDouble(0)
assert(math.abs(corr2 + 1.0) < 1e-12)
// non-trivial example. To reproduce in python, use:
// >>> from scipy.stats import pearsonr
@@ -768,17 +768,17 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// > cor(a, b)
// [1] 0.957233913947585835
val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
- val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ val corr3 = df2.agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
- val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0)
+ val corr4 = df3.agg(corr("a", "b")).collect()(0)
assert(corr4 == Row(null))
val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
- val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
+ val corr5 = df4.repartition(2).agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr5 - 1.0) < 1e-12)
- val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
+ val corr6 = df4.agg(corr("a", "c")).collect()(0).getDouble(0)
assert(math.abs(corr6 + 1.0) < 1e-12)
// Test for udaf_corr in HiveCompatibilitySuite
@@ -855,23 +855,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// >>> np.cov(a, b, bias = 1)[0][1]
// 565.25
val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
- val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+ val cov_samp = df.agg(covar_samp("a", "b")).collect()(0).getDouble(0)
assert(math.abs(cov_samp - 595.0) < 1e-12)
- val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ val cov_pop = df.agg(covar_pop("a", "b")).collect()(0).getDouble(0)
assert(math.abs(cov_pop - 565.25) < 1e-12)
val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
- val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+ val cov_samp2 = df2.agg(covar_samp("a", "b")).collect()(0).getDouble(0)
assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
- val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+ val cov_pop2 = df2.agg(covar_pop("a", "b")).collect()(0).getDouble(0)
assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
// one row test
val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
- checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(null))
- checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0))
+ checkAnswer(df3.agg(covar_samp("a", "b")), Row(null))
+ checkAnswer(df3.agg(covar_pop("a", "b")), Row(0.0))
}
test("no aggregation function (SPARK-11486)") {
@@ -938,7 +938,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
.find(r => r.getInt(0) == 50)
.getOrElse(fail("A row with id 50 should be the expected answer."))
checkAnswer(
- df.groupBy().agg(udaf(allColumns: _*)),
+ df.agg(udaf(allColumns: _*)),
// udaf returns a Row as the output value.
Row(expectedAnswer)
)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org