You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/16 05:38:04 UTC

spark git commit: [SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle

Repository: spark
Updated Branches:
  refs/heads/master 13c155958 -> 0ae96495d


[SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle

## What changes were proposed in this pull request?

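`ObjectHashAggregateExec` should override `outputPartitioning` to report its child's output partitioning, so that the planner does not introduce an unnecessary shuffle when the input is already partitioned as required.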

## How was this patch tested?

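Added a test to `DataFrameAggregateSuite` that runs two chained object-hash aggregations over a DataFrame repartitioned by `a`. It checks that the executed plan uses `ObjectHashAggregateExec` (not `SortAggregateExec`) and contains exactly one `ShuffleExchangeExec`.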

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #19501 from viirya/SPARK-22223.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ae96495
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ae96495
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ae96495

Branch: refs/heads/master
Commit: 0ae96495dedb54b3b6bae0bd55560820c5ca29a2
Parents: 13c1559
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Mon Oct 16 13:37:58 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Oct 16 13:37:58 2017 +0800

----------------------------------------------------------------------
 .../aggregate/ObjectHashAggregateExec.scala     |  2 ++
 .../spark/sql/DataFrameAggregateSuite.scala     | 30 ++++++++++++++++++++
 2 files changed, 32 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ae96495/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index ec3f9a0..66955b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -95,6 +95,8 @@ case class ObjectHashAggregateExec(
     }
   }
 
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val aggTime = longMetric("aggTime")

http://git-wip-us.apache.org/repos/asf/spark/blob/0ae96495/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8549eac..06848e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"),
       Seq(Row(3, 4, 9)))
   }
+
+  test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
+    withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+      val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+        .repartition(col("a"))
+
+      val objHashAggDF = df
+        .withColumn("d", expr("(a, b, c)"))
+        .groupBy("a", "b").agg(collect_list("d").as("e"))
+        .withColumn("f", expr("(b, e)"))
+        .groupBy("a").agg(collect_list("f").as("g"))
+      val aggPlan = objHashAggDF.queryExecution.executedPlan
+
+      val sortAggPlans = aggPlan.collect {
+        case sortAgg: SortAggregateExec => sortAgg
+      }
+      assert(sortAggPlans.isEmpty)
+
+      val objHashAggPlans = aggPlan.collect {
+        case objHashAgg: ObjectHashAggregateExec => objHashAgg
+      }
+      assert(objHashAggPlans.nonEmpty)
+
+      val exchangePlans = aggPlan.collect {
+        case shuffle: ShuffleExchangeExec => shuffle
+      }
+      assert(exchangePlans.length == 1)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org