Posted to commits@spark.apache.org by da...@apache.org on 2016/03/09 21:04:34 UTC

spark git commit: [SPARK-13523] [SQL] Reuse exchanges in a query

Repository: spark
Updated Branches:
  refs/heads/master 0dd06485c -> 3dc9ae2e1


[SPARK-13523] [SQL] Reuse exchanges in a query

## What changes were proposed in this pull request?

A query can contain common parts, for example a self join. It would be good to avoid computing the duplicated part, to save CPU and memory (Broadcast or cache).

An Exchange materializes the underlying RDD by shuffle or collect, so it is a good point at which to check for duplicates and reuse them. Duplicated exchanges are exchanges that generate exactly the same result inside a query.

To find the duplicated exchanges, we need to be able to compare two SparkPlans and check whether they produce the same result. We already have that for LogicalPlan, so we should move it into QueryPlan to make it available for SparkPlan.
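
As a rough illustration of that comparison, here is a minimal sketch on a toy plan model. `Attr`, `Plan`, `Scan`, `Filter` and `SameResult` are hypothetical stand-ins invented for this example, not Spark classes. The idea it shows is the one described above: compare node class, arity and "cleaned" arguments, where attribute references are replaced by their ordinal in the child output so that differing expression ids do not matter.

```scala
// Toy model only: these types are hypothetical stand-ins, not Spark's classes.
final case class Attr(name: String, exprId: Long)

sealed trait Plan {
  def children: Seq[Plan]
  def output: Seq[Attr]
}

final case class Scan(table: String, output: Seq[Attr]) extends Plan {
  val children: Seq[Plan] = Nil
}

final case class Filter(column: Attr, minValue: Long, child: Plan) extends Plan {
  val children: Seq[Plan] = Seq(child)
  def output: Seq[Attr] = child.output
}

object SameResult {
  // Replace an attribute by its position in the input, so that two plans that
  // differ only in expression ids end up with equal arguments.
  private def bind(a: Attr, input: Seq[Attr]): Int =
    input.indexWhere(_.exprId == a.exprId)

  // Arguments with cosmetic differences (names, expression ids) erased.
  private def cleanArgs(p: Plan): Seq[Any] = p match {
    case Scan(table, out) => Seq(table, out.size)
    case Filter(col, min, child) => Seq(bind(col, child.output), min)
  }

  // Conservative check: may return false for plans that are in fact equivalent.
  def sameResult(l: Plan, r: Plan): Boolean =
    l.getClass == r.getClass &&
      l.children.size == r.children.size &&
      cleanArgs(l) == cleanArgs(r) &&
      l.children.zip(r.children).forall { case (a, b) => sameResult(a, b) }
}
```

With this sketch, two `Filter` nodes over the same table compare as equal even if their attributes carry different `exprId` values, while changing `minValue` makes the comparison fail.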

Once we can find the duplicated exchanges, we should replace all of them with the same SparkPlan object (possibly wrapped in ReusedExchange for explain), at which point the plan tree becomes a DAG. Since the planner rules only work with trees, this rule should be the last one in the entire planning.
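
The replacement step can be sketched on a toy tree as follows. All names here (`Node`, `Source`, `Shuffle`, `Join`, `Reused`, `ReuseSketch`) are hypothetical and exist only for this example; structural equality of case classes stands in for the sameResult check, and the shuffle key stands in for the schema used to bucket candidates.

```scala
import scala.collection.mutable

// Hypothetical toy operators; Spark's real SparkPlan API is much richer.
sealed trait Node
final case class Source(name: String) extends Node
final case class Shuffle(key: String, child: Node) extends Node
final case class Join(left: Node, right: Node) extends Node
// Stand-in for the wrapper that points at an exchange seen earlier.
final case class Reused(target: Shuffle) extends Node

object ReuseSketch {
  def reuse(plan: Node): Node = {
    // Exchanges seen so far, bucketed by a cheap key to limit comparisons.
    val seen = mutable.HashMap[String, mutable.ArrayBuffer[Shuffle]]()

    // Bottom-up rewrite: children are handled before their parent.
    def rewrite(n: Node): Node = n match {
      case Join(l, r) => Join(rewrite(l), rewrite(r))
      case Shuffle(key, child) =>
        val current = Shuffle(key, rewrite(child))
        val bucket = seen.getOrElseUpdate(key, mutable.ArrayBuffer.empty[Shuffle])
        // Structural equality stands in for the sameResult check here.
        bucket.find(_ == current) match {
          case Some(first) => Reused(first)
          case None =>
            bucket += current
            current
        }
      case other => other
    }

    rewrite(plan)
  }
}
```

For a self join such as `Join(Shuffle("id", Source("t")), Shuffle("id", Source("t")))`, the right-hand shuffle comes back as a `Reused` node that references the very same object as the left-hand one, so the result is no longer a strict tree.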

After the rule, the plan will look like:

```
WholeStageCodegen
:  +- Project [id#0L]
:     +- BroadcastHashJoin [id#0L], [id#2L], Inner, BuildRight, None
:        :- Project [id#0L]
:        :  +- BroadcastHashJoin [id#0L], [id#1L], Inner, BuildRight, None
:        :     :- Range 0, 1, 4, 1024, [id#0L]
:        :     +- INPUT
:        +- INPUT
:- BroadcastExchange HashedRelationBroadcastMode(true,List(id#1L),List(id#1L))
:  +- WholeStageCodegen
:     :  +- Range 0, 1, 4, 1024, [id#1L]
+- ReusedExchange [id#2L], BroadcastExchange HashedRelationBroadcastMode(true,List(id#1L),List(id#1L))
```

![bjoin](https://cloud.githubusercontent.com/assets/40902/13414787/209e8c5c-df0a-11e5-8a0f-edff69d89e83.png)

For a three-way SortMergeJoin:
```
== Physical Plan ==
WholeStageCodegen
:  +- Project [id#0L]
:     +- SortMergeJoin [id#0L], [id#4L], None
:        :- INPUT
:        +- INPUT
:- WholeStageCodegen
:  :  +- Project [id#0L]
:  :     +- SortMergeJoin [id#0L], [id#3L], None
:  :        :- INPUT
:  :        +- INPUT
:  :- WholeStageCodegen
:  :  :  +- Sort [id#0L ASC], false, 0
:  :  :     +- INPUT
:  :  +- Exchange hashpartitioning(id#0L, 200), None
:  :     +- WholeStageCodegen
:  :        :  +- Range 0, 1, 4, 33554432, [id#0L]
:  +- WholeStageCodegen
:     :  +- Sort [id#3L ASC], false, 0
:     :     +- INPUT
:     +- ReusedExchange [id#3L], Exchange hashpartitioning(id#0L, 200), None
+- WholeStageCodegen
   :  +- Sort [id#4L ASC], false, 0
   :     +- INPUT
   +- ReusedExchange [id#4L], Exchange hashpartitioning(id#0L, 200), None
```
![sjoin](https://cloud.githubusercontent.com/assets/40902/13414790/27aea61c-df0a-11e5-8cbf-fbc985c31d95.png)

If execute()/executeBroadcast() is called on the same ShuffleExchange or BroadcastExchange by different parents, the exchange should cache the RDD/Broadcast and return the same one to all the parents.
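
The caching behaviour can be illustrated with a small stand-alone sketch. `CachingExchange` and `computeCount` are hypothetical names made up for this example and are not part of Spark; the sketch only shows the "compute once, return the same object to every caller" pattern and makes no attempt at thread safety.

```scala
// Hypothetical stand-in for an exchange operator; not Spark's ShuffleExchange.
final class CachingExchange[T](compute: () => T) {
  // Counts how often the underlying computation actually ran.
  var computeCount: Int = 0

  private var cached: Option[T] = None

  // Materialize on the first call, then hand the same result to later callers.
  def execute(): T = cached match {
    case Some(result) => result
    case None =>
      computeCount += 1
      val result = compute()
      cached = Some(result)
      result
  }
}
```

Two parents calling `execute()` on the same `CachingExchange` instance receive the identical result object, and `computeCount` stays at 1.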

## How was this patch tested?

Added some unit tests for this. Also ran manual tests on TPCDS queries Q59 and Q64, where we can see some exchanges being reused (this requires a change in PhysicalRDD for sameResult, which is handled in #11514).

Author: Davies Liu <da...@databricks.com>

Closes #11403 from davies/dedup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3dc9ae2e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3dc9ae2e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3dc9ae2e

Branch: refs/heads/master
Commit: 3dc9ae2e158e5b51df6f799767946fe1d190156b
Parents: 0dd0648
Author: Davies Liu <da...@databricks.com>
Authored: Wed Mar 9 12:04:29 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Mar 9 12:04:29 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/QueryPlan.scala    | 63 +++++++++++++-
 .../catalyst/plans/logical/LogicalPlan.scala    | 55 +-----------
 .../catalyst/plans/physical/broadcastMode.scala |  9 ++
 .../spark/sql/execution/SparkPlanInfo.scala     | 22 ++++-
 .../execution/aggregate/TungstenAggregate.scala |  4 +
 .../spark/sql/execution/basicOperators.scala    |  3 +
 .../execution/exchange/BroadcastExchange.scala  | 10 ++-
 .../spark/sql/execution/exchange/Exchange.scala | 92 ++++++++++++++++++++
 .../execution/exchange/ShuffleExchange.scala    | 29 +++---
 .../sql/execution/joins/HashedRelation.scala    | 15 +++-
 .../spark/sql/execution/ui/SparkPlanGraph.scala | 20 +++--
 .../org/apache/spark/sql/internal/SQLConf.scala |  6 ++
 .../spark/sql/internal/SessionState.scala       |  6 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 38 +++++++-
 .../spark/sql/execution/ExchangeSuite.scala     | 72 ++++++++++++++-
 .../spark/sql/execution/PlannerSuite.scala      | 49 ++++++++++-
 16 files changed, 403 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index c62d5ea..371d72e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.{DataType, StructType}
 
-abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
+abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
   self: PlanType =>
 
   def output: Seq[Attribute]
@@ -237,4 +237,65 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   }
 
   override def innerChildren: Seq[PlanType] = subqueries
+
+  /**
+   * Canonicalized copy of this query plan.
+   */
+  protected lazy val canonicalized: PlanType = this
+
+  /**
+   * Returns true when the given query plan will return the same results as this query plan.
+   *
+   * Since its likely undecidable to generally determine if two given plans will produce the same
+   * results, it is okay for this function to return false, even if the results are actually
+   * the same.  Such behavior will not affect correctness, only the application of performance
+   * enhancements like caching.  However, it is not acceptable to return true if the results could
+   * possibly be different.
+   *
+   * By default this function performs a modified version of equality that is tolerant of cosmetic
+   * differences like attribute naming and or expression id differences. Operators that
+   * can do better should override this function.
+   */
+  def sameResult(plan: PlanType): Boolean = {
+    val canonicalizedLeft = this.canonicalized
+    val canonicalizedRight = plan.canonicalized
+    canonicalizedLeft.getClass == canonicalizedRight.getClass &&
+      canonicalizedLeft.children.size == canonicalizedRight.children.size &&
+      canonicalizedLeft.cleanArgs == canonicalizedRight.cleanArgs &&
+      (canonicalizedLeft.children, canonicalizedRight.children).zipped.forall(_ sameResult _)
+  }
+
+  /**
+   * All the attributes that are used for this plan.
+   */
+  lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output)
+
+  private def cleanExpression(e: Expression): Expression = e match {
+    case a: Alias =>
+      // As the root of the expression, Alias will always take an arbitrary exprId, we need
+      // to erase that for equality testing.
+      val cleanedExprId =
+        Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated)
+      BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
+    case other =>
+      BindReferences.bindReference(other, allAttributes, allowFailures = true)
+  }
+
+  /** Args that have cleaned such that differences in expression id should not affect equality */
+  protected lazy val cleanArgs: Seq[Any] = {
+    def cleanArg(arg: Any): Any = arg match {
+      case e: Expression => cleanExpression(e).canonicalized
+      case other => other
+    }
+
+    productIterator.map {
+      // Children are checked using sameResult above.
+      case tn: TreeNode[_] if containsChild(tn) => null
+      case e: Expression => cleanArg(e)
+      case s: Option[_] => s.map(cleanArg)
+      case s: Seq[_] => s.map(cleanArg)
+      case m: Map[_, _] => m.mapValues(cleanArg)
+      case other => other
+    }.toSeq
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 31e775d..b32c7d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -114,60 +114,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    */
   def childrenResolved: Boolean = children.forall(_.resolved)
 
-  /**
-   * Returns true when the given logical plan will return the same results as this logical plan.
-   *
-   * Since its likely undecidable to generally determine if two given plans will produce the same
-   * results, it is okay for this function to return false, even if the results are actually
-   * the same.  Such behavior will not affect correctness, only the application of performance
-   * enhancements like caching.  However, it is not acceptable to return true if the results could
-   * possibly be different.
-   *
-   * By default this function performs a modified version of equality that is tolerant of cosmetic
-   * differences like attribute naming and or expression id differences.  Logical operators that
-   * can do better should override this function.
-   */
-  def sameResult(plan: LogicalPlan): Boolean = {
-    val cleanLeft = EliminateSubqueryAliases(this)
-    val cleanRight = EliminateSubqueryAliases(plan)
-
-    cleanLeft.getClass == cleanRight.getClass &&
-      cleanLeft.children.size == cleanRight.children.size && {
-      logDebug(
-        s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]")
-      cleanRight.cleanArgs == cleanLeft.cleanArgs
-    } &&
-    (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _)
-  }
-
-  /** Args that have cleaned such that differences in expression id should not affect equality */
-  protected lazy val cleanArgs: Seq[Any] = {
-    val input = children.flatMap(_.output)
-    def cleanExpression(e: Expression) = e match {
-      case a: Alias =>
-        // As the root of the expression, Alias will always take an arbitrary exprId, we need
-        // to erase that for equality testing.
-        val cleanedExprId =
-          Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated)
-        BindReferences.bindReference(cleanedExprId, input, allowFailures = true)
-      case other => BindReferences.bindReference(other, input, allowFailures = true)
-    }
-
-    productIterator.map {
-      // Children are checked using sameResult above.
-      case tn: TreeNode[_] if containsChild(tn) => null
-      case e: Expression => cleanExpression(e)
-      case s: Option[_] => s.map {
-        case e: Expression => cleanExpression(e)
-        case other => other
-      }
-      case s: Seq[_] => s.map {
-        case e: Expression => cleanExpression(e)
-        case other => other
-      }
-      case other => other
-    }.toSeq
-  }
+  override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this)
 
   /**
    * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index e01f69f..9dfdf4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.InternalRow
  */
 trait BroadcastMode {
   def transform(rows: Array[InternalRow]): Any
+
+  /**
+   * Returns true iff this [[BroadcastMode]] generates the same result as `other`.
+   */
+  def compatibleWith(other: BroadcastMode): Boolean
 }
 
 /**
@@ -33,4 +38,8 @@ trait BroadcastMode {
 case object IdentityBroadcastMode extends BroadcastMode {
   // TODO: pack the UnsafeRows into single bytes array.
   override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
+
+  override def compatibleWith(other: BroadcastMode): Boolean = {
+    this eq other
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 9019e5d..247f55d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.exchange.ReusedExchange
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
 import org.apache.spark.util.Utils
 
@@ -31,13 +32,28 @@ class SparkPlanInfo(
     val simpleString: String,
     val children: Seq[SparkPlanInfo],
     val metadata: Map[String, String],
-    val metrics: Seq[SQLMetricInfo])
+    val metrics: Seq[SQLMetricInfo]) {
+
+  override def hashCode(): Int = {
+    // hashCode of simpleString should be good enough to distinguish the plans from each other
+    // within a plan
+    simpleString.hashCode
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case o: SparkPlanInfo =>
+      nodeName == o.nodeName && simpleString == o.simpleString && children == o.children
+    case _ => false
+  }
+}
 
 private[sql] object SparkPlanInfo {
 
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
-
-    val children = plan.children ++ plan.subqueries
+    val children = plan match {
+      case ReusedExchange(_, child) => child :: Nil
+      case _ => plan.children ++ plan.subqueries
+    }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
       new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
         Utils.getFormattedClassName(metric.param))

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f07add8..f856634 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -46,6 +46,10 @@ case class TungstenAggregate(
 
   require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes))
 
+  override lazy val allAttributes: Seq[Attribute] =
+    child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
+      aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4a9e736..4901298 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -166,6 +166,9 @@ case class Range(
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
+  // output attributes should not affect the results
+  override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
+
   override def upstreams(): Seq[RDD[InternalRow]] = {
     sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
       .map(i => InternalRow(i)) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
index 40cad4b..1a5c6a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
@@ -34,12 +34,16 @@ import org.apache.spark.util.ThreadUtils
  */
 case class BroadcastExchange(
     mode: BroadcastMode,
-    child: SparkPlan) extends UnaryNode {
-
-  override def output: Seq[Attribute] = child.output
+    child: SparkPlan) extends Exchange {
 
   override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
 
+  override def sameResult(plan: SparkPlan): Boolean = plan match {
+    case p: BroadcastExchange =>
+      mode.compatibleWith(p.mode) && child.sameResult(p.child)
+    case _ => false
+  }
+
   @transient
   private val timeout: Duration = {
     val timeoutValue = sqlContext.conf.broadcastTimeout

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
new file mode 100644
index 0000000..12513e9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An interface for exchanges.
+ */
+abstract class Exchange extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * A wrapper for reused exchange to have different output, because two exchanges which produce
+ * logically identical output will have distinct sets of output attribute ids, so we need to
+ * preserve the original ids because they're what downstream operators are expecting.
+ */
+case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode {
+
+  override def sameResult(plan: SparkPlan): Boolean = {
+    // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here.
+    plan.sameResult(child)
+  }
+
+  def doExecute(): RDD[InternalRow] = {
+    child.execute()
+  }
+
+  override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    child.executeBroadcast()
+  }
+
+  // Do not repeat the same tree in explain.
+  override def treeChildren: Seq[SparkPlan] = Nil
+}
+
+/**
+ * Find out duplicated exchanges in the spark plan, then use the same exchange for all the
+ * references.
+ */
+private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    if (!sqlContext.conf.exchangeReuseEnabled) {
+      return plan
+    }
+    // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
+    val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
+    plan.transformUp {
+      case exchange: Exchange =>
+        // the exchanges that have same results usually also have same schemas (same column names).
+        val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
+        val samePlan = sameSchema.find { e =>
+          exchange.sameResult(e)
+        }
+        if (samePlan.isDefined) {
+          // Keep the output of this exchange, the following plans require that to resolve
+          // attributes.
+          ReusedExchange(exchange.output, samePlan.get)
+        } else {
+          sameSchema += exchange
+          exchange
+        }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index de21d77..4eb4d9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.MutablePair
 case class ShuffleExchange(
     var newPartitioning: Partitioning,
     child: SparkPlan,
-    @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode {
+    @transient coordinator: Option[ExchangeCoordinator]) extends Exchange {
 
   override def nodeName: String = {
     val extraInfo = coordinator match {
@@ -55,8 +55,6 @@ case class ShuffleExchange(
 
   override def outputPartitioning: Partitioning = newPartitioning
 
-  override def output: Seq[Attribute] = child.output
-
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
 
   override protected def doPrepare(): Unit = {
@@ -103,16 +101,25 @@ case class ShuffleExchange(
     new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
   }
 
+  /**
+   * Caches the created ShuffleRowRDD so we can reuse that.
+   */
+  private var cachedShuffleRDD: ShuffledRowRDD = null
+
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    coordinator match {
-      case Some(exchangeCoordinator) =>
-        val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
-        assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
-        shuffleRDD
-      case None =>
-        val shuffleDependency = prepareShuffleDependency()
-        preparePostShuffleRDD(shuffleDependency)
+    // Returns the same ShuffleRowRDD if this plan is used by multiple plans.
+    if (cachedShuffleRDD == null) {
+      cachedShuffleRDD = coordinator match {
+        case Some(exchangeCoordinator) =>
+          val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+          assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
+          shuffleRDD
+        case None =>
+          val shuffleDependency = prepareShuffleDependency()
+          preparePostShuffleRDD(shuffleDependency)
+      }
     }
+    cachedShuffleRDD
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 9a3cdaf..99f8841 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -681,7 +681,7 @@ private[execution] case class HashedRelationBroadcastMode(
     keys: Seq[Expression],
     attributes: Seq[Attribute]) extends BroadcastMode {
 
-  def transform(rows: Array[InternalRow]): HashedRelation = {
+  override def transform(rows: Array[InternalRow]): HashedRelation = {
     val generator = UnsafeProjection.create(keys, attributes)
     if (canJoinKeyFitWithinLong) {
       LongHashedRelation(rows.iterator, generator, rows.length)
@@ -689,5 +689,18 @@ private[execution] case class HashedRelationBroadcastMode(
       HashedRelation(rows.iterator, generator, rows.length)
     }
   }
+
+  private lazy val canonicalizedKeys: Seq[Expression] = {
+    keys.map { e =>
+      BindReferences.bindReference(e.canonicalized, attributes)
+    }
+  }
+
+  override def compatibleWith(other: BroadcastMode): Boolean = other match {
+    case m: HashedRelationBroadcastMode =>
+      canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
+        canonicalizedKeys == m.canonicalizedKeys
+    case _ => false
+  }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 83372aa..94d318e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -64,7 +64,8 @@ private[sql] object SparkPlanGraph {
     val nodeIdGenerator = new AtomicLong(0)
     val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]()
     val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]()
-    buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null)
+    val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]()
+    buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges)
     new SparkPlanGraph(nodes, edges)
   }
 
@@ -74,7 +75,8 @@ private[sql] object SparkPlanGraph {
       nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
       edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
       parent: SparkPlanGraphNode,
-      subgraph: SparkPlanGraphCluster): Unit = {
+      subgraph: SparkPlanGraphCluster,
+      exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
     planInfo.nodeName match {
       case "WholeStageCodegen" =>
         val cluster = new SparkPlanGraphCluster(
@@ -84,13 +86,14 @@ private[sql] object SparkPlanGraph {
           mutable.ArrayBuffer[SparkPlanGraphNode]())
         nodes += cluster
         buildSparkPlanGraphNode(
-          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges)
       case "InputAdapter" =>
-        buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
-        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
-      case _ =>
+        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+      case name =>
         val metrics = planInfo.metrics.map { metric =>
           SQLPlanMetric(metric.name, metric.accumulatorId,
             SQLMetrics.getMetricParam(metric.metricParam))
@@ -103,12 +106,15 @@ private[sql] object SparkPlanGraph {
         } else {
           subgraph.nodes += node
         }
+        if (name == "ShuffleExchange" || name == "BroadcastExchange") {
+          exchanges += planInfo -> node
+        }
 
         if (parent != null) {
           edges += SparkPlanGraphEdge(node.id, parent.id)
         }
         planInfo.children.foreach(
-          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges))
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 1d1e288..384102e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -504,6 +504,10 @@ object SQLConf {
       " method",
     isPublic = false)
 
+  val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse",
+    defaultValue = Some(true),
+    doc = "When true, the planner will try to find out duplicated exchanges and re-use them",
+    isPublic = false)
 
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -564,6 +568,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
 
   def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)
 
+  def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
+
   def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)
 
   def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 6f81794..98ada4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -24,10 +24,9 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource}
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 import org.apache.spark.sql.util.ExecutionListenerManager
 
-
 /**
  * A class that holds all session-specific state in a given [[SQLContext]].
  */
@@ -94,7 +93,8 @@ private[sql] class SessionState(ctx: SQLContext) {
     override val batches: Seq[Batch] = Seq(
       Batch("Subquery", Once, PlanSubqueries(ctx)),
       Batch("Add exchange", Once, EnsureRequirements(ctx)),
-      Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx))
+      Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)),
+      Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx))
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 55153cd..26775c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -25,9 +25,9 @@ import scala.util.Random
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, OneRowRelation, Union}
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
@@ -1316,6 +1316,40 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
   }
 
+  test("reuse exchange") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") {
+      val df = sqlContext.range(100)
+      val join = df.join(df, "id")
+      val plan = join.queryExecution.executedPlan
+      checkAnswer(join, df)
+      assert(
+        join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1)
+      assert(join.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 1)
+      val broadcasted = broadcast(join)
+      val join2 = join.join(broadcasted, "id").join(broadcasted, "id")
+      checkAnswer(join2, df)
+      assert(
+        join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1)
+      assert(
+        join2.queryExecution.executedPlan.collect { case e: BroadcastExchange => true }.size === 1)
+      assert(
+        join2.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 4)
+    }
+  }
+
+  test("sameResult() on aggregate") {
+    val df = sqlContext.range(100)
+    val agg1 = df.groupBy().count()
+    val agg2 = df.groupBy().count()
+    // Two aggregates that differ only in their ExprIds should have the same result
+    assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan))
+    val agg3 = df.groupBy().sum()
+    assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan))
+    val df2 = sqlContext.range(101)
+    val agg4 = df2.groupBy().count()
+    assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan))
+  }
+
   test("SPARK-12512: support `.` in column name for withColumn()") {
     val df = Seq("a" -> "b").toDF("col.a", "col.b")
     checkAnswer(df.select(df("*")), Row("a", "b"))

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index d4f22de..9f159d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -18,8 +18,10 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange}
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -33,4 +35,70 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
       input.map(Row.fromTuple)
     )
   }
+
+  test("compatible BroadcastMode") {
+    val mode1 = IdentityBroadcastMode
+    val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
+    val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
+
+    assert(mode1.compatibleWith(mode1))
+    assert(!mode1.compatibleWith(mode2))
+    assert(!mode2.compatibleWith(mode1))
+    assert(mode2.compatibleWith(mode2))
+    assert(!mode2.compatibleWith(mode3))
+    assert(mode3.compatibleWith(mode3))
+  }
+
+  test("BroadcastExchange same result") {
+    val df = sqlContext.range(10)
+    val plan = df.queryExecution.executedPlan
+    val output = plan.output
+    assert(plan sameResult plan)
+
+    val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
+    val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
+    val exchange2 = BroadcastExchange(hashMode, plan)
+    val hashMode2 =
+      HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
+    val exchange3 = BroadcastExchange(hashMode2, plan)
+    val exchange4 = ReusedExchange(output, exchange3)
+
+    assert(exchange1 sameResult exchange1)
+    assert(exchange2 sameResult exchange2)
+    assert(exchange3 sameResult exchange3)
+    assert(exchange4 sameResult exchange4)
+
+    assert(!exchange1.sameResult(exchange2))
+    assert(!exchange2.sameResult(exchange3))
+    assert(!exchange3.sameResult(exchange4))
+    assert(exchange4 sameResult exchange3)
+  }
+
+  test("ShuffleExchange same result") {
+    val df = sqlContext.range(10)
+    val plan = df.queryExecution.executedPlan
+    val output = plan.output
+    assert(plan sameResult plan)
+
+    val part1 = HashPartitioning(output, 1)
+    val exchange1 = ShuffleExchange(part1, plan)
+    val exchange2 = ShuffleExchange(part1, plan)
+    val part2 = HashPartitioning(output, 2)
+    val exchange3 = ShuffleExchange(part2, plan)
+    val part3 = HashPartitioning(output ++ output, 2)
+    val exchange4 = ShuffleExchange(part3, plan)
+    val exchange5 = ReusedExchange(output, exchange4)
+
+    assert(exchange1 sameResult exchange1)
+    assert(exchange2 sameResult exchange2)
+    assert(exchange3 sameResult exchange3)
+    assert(exchange4 sameResult exchange4)
+    assert(exchange5 sameResult exchange5)
+
+    assert(exchange1 sameResult exchange2)
+    assert(!exchange2.sameResult(exchange3))
+    assert(!exchange3.sameResult(exchange4))
+    assert(!exchange4.sameResult(exchange5))
+    assert(exchange5 sameResult exchange4)
+  }
 }
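
The two `sameResult` tests above expect an asymmetry: a `ReusedExchange` reports the same result as the exchange it wraps, but the wrapped exchange does not report the same result as the wrapper. The toy model below is one way such behaviour could arise. It is a sketch only: `ToyPlan`, `ToyScan`, `ToyShuffle` and `ToyReused` are hypothetical stand-ins, plain case-class equality stands in for Spark's plan comparison, and the delegation in `ToyReused` is an assumption rather than code taken from this diff.

```scala
// Toy model only: hypothetical stand-ins, not Spark classes.
sealed trait ToyPlan {
  // Default rule: same result when the two plans are structurally equal.
  def sameResult(other: ToyPlan): Boolean = this == other
}
case class ToyScan(rows: Int) extends ToyPlan
case class ToyShuffle(numPartitions: Int, child: ToyPlan) extends ToyPlan
case class ToyReused(child: ToyPlan) extends ToyPlan {
  // Assumption: the wrapper steps aside and lets `other` compare itself with
  // the wrapped plan. Reversing the order also handles `other` being a wrapper.
  override def sameResult(other: ToyPlan): Boolean = other.sameResult(child)
}

object SameResultSketch {
  val shuffle: ToyPlan = ToyShuffle(2, ToyScan(10))
  val reused: ToyPlan = ToyReused(shuffle)

  // Wrapper compared with the plan it wraps.
  val wrapperVsOriginal: Boolean = reused.sameResult(shuffle)
  // Original compared with the wrapper: different node types.
  val originalVsWrapper: Boolean = shuffle.sameResult(reused)

  def main(args: Array[String]): Unit = {
    println(s"reused sameResult shuffle: $wrapperVsOriginal")
    println(s"shuffle sameResult reused: $originalVsWrapper")
  }
}
```

In this model the comparison is deliberately one-directional, which mirrors the pairs `!exchange4.sameResult(exchange5)` and `exchange5 sameResult exchange4` in the test.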

http://git-wip-us.apache.org/repos/asf/spark/blob/3dc9ae2e/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index a733237..ab0a7ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange}
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-
 class PlannerSuite extends SharedSQLContext {
   import testImplicits._
 
@@ -472,6 +471,50 @@ class PlannerSuite extends SharedSQLContext {
   }
 
   // ---------------------------------------------------------------------------------------------
+
+  test("Reuse exchanges") {
+    val distribution = ClusteredDistribution(Literal(1) :: Nil)
+    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
+    assert(!childPartitioning.satisfies(distribution))
+    val shuffle = ShuffleExchange(finalPartitioning,
+      DummySparkPlan(
+        children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
+        requiredChildDistribution = Seq(distribution),
+        requiredChildOrdering = Seq(Seq.empty)),
+      None)
+
+    val inputPlan = SortMergeJoin(
+        Literal(1) :: Nil,
+        Literal(1) :: Nil,
+        None,
+        shuffle,
+        shuffle)
+
+    val outputPlan = ReuseExchange(sqlContext).apply(inputPlan)
+    if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) {
+      fail(s"Should re-use the shuffle:\n$outputPlan")
+    }
+    if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) {
+      fail(s"Should have only one shuffle:\n$outputPlan")
+    }
+
+    // nested exchanges
+    val inputPlan2 = SortMergeJoin(
+      Literal(1) :: Nil,
+      Literal(1) :: Nil,
+      None,
+      ShuffleExchange(finalPartitioning, inputPlan),
+      ShuffleExchange(finalPartitioning, inputPlan))
+
+    val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2)
+    if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) {
+      fail(s"Should re-use the two shuffles:\n$outputPlan2")
+    }
+    if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) {
+      fail(s"Should have only two shuffles:\n$outputPlan2")
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
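
The counts asserted in the "Reuse exchanges" test (one shuffle plus one reused reference for the simple join, two of each for the nested case) can be reproduced with a small deduplication pass over a toy plan tree. This is a minimal sketch under stated assumptions, not the `ReuseExchange` rule from this commit: `ToyNode`, `ToyLeaf`, `ToyExchange`, `ToyJoin`, `ToyReusedRef` and `ReuseSketch` are hypothetical names, the traversal is a simple top-down walk, and case-class equality stands in for `sameResult`.

```scala
import scala.collection.mutable.ArrayBuffer

// Toy plan nodes: hypothetical stand-ins, not Spark classes.
sealed trait ToyNode
case class ToyLeaf(name: String) extends ToyNode
case class ToyExchange(child: ToyNode) extends ToyNode
case class ToyJoin(left: ToyNode, right: ToyNode) extends ToyNode
// Leaf-like reference to the first occurrence of a duplicated exchange.
case class ToyReusedRef(original: ToyExchange) extends ToyNode

object ReuseSketch {
  // Replace every repeated exchange with a reference to its first occurrence.
  def reuse(plan: ToyNode): ToyNode = {
    val seen = ArrayBuffer.empty[ToyExchange]
    def visit(n: ToyNode): ToyNode = n match {
      case e: ToyExchange =>
        seen.find(_ == e) match {
          case Some(first) => ToyReusedRef(first) // duplicate: do not descend
          case None =>
            seen += e
            ToyExchange(visit(e.child))
        }
      case ToyJoin(l, r) => ToyJoin(visit(l), visit(r))
      case other => other
    }
    visit(plan)
  }

  // Count matching nodes; ToyLeaf and ToyReusedRef are treated as leaves.
  def count(plan: ToyNode)(p: ToyNode => Boolean): Int = {
    val here = if (p(plan)) 1 else 0
    val below = plan match {
      case ToyExchange(c) => count(c)(p)
      case ToyJoin(l, r) => count(l)(p) + count(r)(p)
      case _ => 0
    }
    here + below
  }

  def main(args: Array[String]): Unit = {
    val shuffle = ToyExchange(ToyLeaf("t"))
    val simple = ToyJoin(shuffle, shuffle)
    val nested = ToyJoin(ToyExchange(simple), ToyExchange(simple))
    println(reuse(simple))
    println(reuse(nested))
  }
}
```

Because a duplicate is replaced without descending into it, the second outer exchange in the nested plan becomes a single reference, and the inner duplicate is only counted once, inside the first outer exchange.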

