You are viewing a plain-text version of this content. The canonical link ("here") was a hyperlink and is not preserved in this copy.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/16 05:38:04 UTC
spark git commit: [SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle
Repository: spark
Updated Branches:
refs/heads/master 13c155958 -> 0ae96495d
[SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle
## What changes were proposed in this pull request?
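`ObjectHashAggregateExec` should override `outputPartitioning` to report its child's output partitioning, in order to avoid an unnecessary shuffle. Without this override, the planner cannot see that the aggregate's output is still partitioned the way its input was. It therefore inserts an extra shuffle exchange before a downstream aggregation whose grouping keys the existing partitioning already satisfies. An example is grouping by `a` after an aggregation on input repartitioned by `a`.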
## How was this patch tested?
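Added a unit test, "SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle", to `DataFrameAggregateSuite` (run by Jenkins CI). It asserts that the executed plan uses `ObjectHashAggregateExec` and contains exactly one `ShuffleExchangeExec`.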
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #19501 from viirya/SPARK-22223.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ae96495
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ae96495
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ae96495
Branch: refs/heads/master
Commit: 0ae96495dedb54b3b6bae0bd55560820c5ca29a2
Parents: 13c1559
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Mon Oct 16 13:37:58 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Oct 16 13:37:58 2017 +0800
----------------------------------------------------------------------
.../aggregate/ObjectHashAggregateExec.scala | 2 ++
.../spark/sql/DataFrameAggregateSuite.scala | 30 ++++++++++++++++++++
2 files changed, 32 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0ae96495/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index ec3f9a0..66955b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -95,6 +95,8 @@ case class ObjectHashAggregateExec(
}
}
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
val aggTime = longMetric("aggTime")
http://git-wip-us.apache.org/repos/asf/spark/blob/0ae96495/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8549eac..06848e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"),
Seq(Row(3, 4, 9)))
}
+
+ test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
+ withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+ .repartition(col("a"))
+
+ val objHashAggDF = df
+ .withColumn("d", expr("(a, b, c)"))
+ .groupBy("a", "b").agg(collect_list("d").as("e"))
+ .withColumn("f", expr("(b, e)"))
+ .groupBy("a").agg(collect_list("f").as("g"))
+ val aggPlan = objHashAggDF.queryExecution.executedPlan
+
+ val sortAggPlans = aggPlan.collect {
+ case sortAgg: SortAggregateExec => sortAgg
+ }
+ assert(sortAggPlans.isEmpty)
+
+ val objHashAggPlans = aggPlan.collect {
+ case objHashAgg: ObjectHashAggregateExec => objHashAgg
+ }
+ assert(objHashAggPlans.nonEmpty)
+
+ val exchangePlans = aggPlan.collect {
+ case shuffle: ShuffleExchangeExec => shuffle
+ }
+ assert(exchangePlans.length == 1)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org