Posted to commits@spark.apache.org by we...@apache.org on 2018/01/08 11:41:51 UTC

spark git commit: [SPARK-21865][SQL] simplify the distribution semantic of Spark SQL

Repository: spark
Updated Branches:
  refs/heads/master 2c73d2a94 -> eb45b52e8


[SPARK-21865][SQL] simplify the distribution semantic of Spark SQL

## What changes were proposed in this pull request?

**The current shuffle planning logic**

1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface.
2. Each operator specifies its output partitioning, via the `Partitioning` interface.
3. `Partitioning.satisfies` determines whether a `Partitioning` can satisfy a `Distribution`.
4. For each operator, check each of its children, and add a shuffle node above any child whose partitioning cannot satisfy the required distribution.
5. For each operator, check if its children's output partitionings are compatible with each other, via `Partitioning.compatibleWith`.
6. If the check in step 5 fails, add a shuffle above each child.
7. Try to eliminate the shuffles added in step 6, via `Partitioning.guarantees`.

This design has a major problem with the definition of "compatible".

`Partitioning.compatibleWith` is not well defined. In principle, a `Partitioning` cannot know whether it is compatible with another `Partitioning` without more information from the operator. For example, for `t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible with `HashPartitioning(b, 10)`, but the partitioning itself doesn't know that.
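
The join example can be sketched with a toy model. This is only an illustration under simplifying assumptions: `ToyHashPartitioning`, `compatibleWithoutContext`, and `coPartitionedForJoin` are hypothetical names that do not exist in Spark, and partitioning keys are plain column names rather than Catalyst expressions.

```scala
object CompatibilityDemo {
  // Hypothetical stand-in for HashPartitioning: keys are plain column names.
  final case class ToyHashPartitioning(keys: Seq[String], numPartitions: Int)

  // Without operator context, the only safe answer is structural equality,
  // so hash(a, 10) and hash(b, 10) look incompatible.
  def compatibleWithoutContext(l: ToyHashPartitioning, r: ToyHashPartitioning): Boolean =
    l == r

  // With the join's key pairs (known only to the operator), the two sides are
  // co-partitioned when they have the same number of partitions and the i-th
  // left key is joined with the i-th right key.
  def coPartitionedForJoin(
      l: ToyHashPartitioning,
      r: ToyHashPartitioning,
      joinKeys: Seq[(String, String)]): Boolean =
    l.numPartitions == r.numPartitions &&
      l.keys.length == r.keys.length &&
      l.keys.zip(r.keys).forall(pair => joinKeys.contains(pair))
}
```

With `joinKeys = Seq("a" -> "b")`, the second check accepts `hash(a, 10)` and `hash(b, 10)`, while the context-free check rejects them. That missing context is the gap described above.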

As a result, `Partitioning.compatibleWith` currently always returns false except for literals, which makes it almost useless. This also means that if an operator has distribution requirements for multiple children, Spark always adds shuffle nodes to all the children (although some of them can be eliminated). However, there is no guarantee that the children's output partitionings are compatible with each other after adding these shuffles; we just assume that the operator will only specify `ClusteredDistribution` for multiple children.

I think it's very hard to guarantee that children are co-partitioned for all kinds of operators, and we cannot even give a clear definition of co-partitioning between distributions like `ClusteredDistribution(a,b)` and `ClusteredDistribution(c)`.

I think we should drop the "compatible" concept from the distribution model, and let the operator achieve the co-partition requirement through special distribution requirements.

**Proposed shuffle planning logic after this PR**
(The first 4 steps are the same as before.)
1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface.
2. Each operator specifies its output partitioning, via the `Partitioning` interface.
3. `Partitioning.satisfies` determines whether a `Partitioning` can satisfy a `Distribution`.
4. For each operator, check each of its children, and add a shuffle node above any child whose partitioning cannot satisfy the required distribution.
5. For each operator, check if its children's output partitionings have the same number of partitions.
6. If the check in step 5 fails, pick the maximum number of partitions among the children's output partitionings, and add a shuffle to each child whose number of partitions doesn't equal that maximum.
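
Steps 5 and 6 can be sketched roughly as follows. This is a simplified model, not the code in `EnsureRequirements`: `Child` and `alignNumPartitions` are hypothetical names, and the sketch assumes every child passed in has a distribution requirement other than `UnspecifiedDistribution` or `BroadcastDistribution`.

```scala
object PartitionAlignment {
  // Hypothetical stand-in for a physical plan node; only the partition count
  // matters here. `shuffled` marks children that received a new shuffle.
  final case class Child(name: String, numPartitions: Int, shuffled: Boolean = false)

  def alignNumPartitions(children: Seq[Child]): Seq[Child] = {
    val counts = children.map(_.numPartitions).toSet
    if (counts.size <= 1) {
      // Step 5 passes: all children already agree on the partition count.
      children
    } else {
      // Step 6: take the maximum and reshuffle only the children that differ.
      val target = counts.max
      children.map { c =>
        if (c.numPartitions == target) c
        else c.copy(numPartitions = target, shuffled = true)
      }
    }
  }
}
```

For children with 10 and 20 partitions, only the 10-partition child is reshuffled, to 20 partitions. The child that already has the maximum is left untouched.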

The new distribution model is very simple: we only have one kind of relationship, which is `Partitioning.satisfies`. For multiple children, Spark only guarantees that they have the same number of partitions, and it's the operator's responsibility to leverage this guarantee to achieve more complicated requirements. For example, non-broadcast joins can use the newly added `HashClusteredDistribution` to achieve co-partitioning.
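
A minimal sketch of the single "satisfies" relationship, loosely following the `HashPartitioning.satisfies` logic in the diff. The types are toy stand-ins (`Dist`, `Clustered`, `HashClustered`, `HashPart` are hypothetical names), keys are plain column names, and the optional required number of partitions is left out for brevity.

```scala
object SatisfyDemo {
  sealed trait Dist
  case object Unspecified extends Dist
  case object AllTuples extends Dist
  // Rows with equal values for `keys` must share a partition.
  final case class Clustered(keys: Seq[String]) extends Dist
  // Rows must be placed by hashing exactly `keys`, in this order.
  final case class HashClustered(keys: Seq[String]) extends Dist

  final case class HashPart(keys: Seq[String], numPartitions: Int) {
    def satisfies(required: Dist): Boolean = required match {
      case Unspecified => true
      case AllTuples => numPartitions == 1
      // Stricter: same keys, same order.
      case HashClustered(ks) => keys == ks
      // Looser: hashing on any subset of the clustering keys is enough.
      case Clustered(ks) => keys.forall(k => ks.contains(k))
    }
  }
}
```

Here `HashPart(Seq("b", "a"), 10)` satisfies `Clustered(Seq("a", "b"))` but not `HashClustered(Seq("a", "b"))`. The stricter requirement pins down which partition each row lands in, so two children that each satisfy it on their join keys and have the same number of partitions end up co-partitioned.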

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <we...@databricks.com>

Closes #19080 from cloud-fan/exchange.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb45b52e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb45b52e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb45b52e

Branch: refs/heads/master
Commit: eb45b52e826ea9cea48629760db35ef87f91fea0
Parents: 2c73d2a
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Jan 8 19:41:41 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jan 8 19:41:41 2018 +0800

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  | 286 +++++++------------
 .../spark/sql/catalyst/PartitioningSuite.scala  |  55 ----
 .../apache/spark/sql/execution/SparkPlan.scala  |  16 +-
 .../execution/exchange/EnsureRequirements.scala | 120 +++-----
 .../execution/joins/ShuffledHashJoinExec.scala  |   2 +-
 .../sql/execution/joins/SortMergeJoinExec.scala |   2 +-
 .../apache/spark/sql/execution/objects.scala    |   2 +-
 .../spark/sql/execution/PlannerSuite.scala      |  81 ++----
 8 files changed, 194 insertions(+), 370 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e57c842..0189bd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
  *  - Intra-partition ordering of data: In this case the distribution describes guarantees made
  *    about how tuples are distributed within a single partition.
  */
-sealed trait Distribution
+sealed trait Distribution {
+  /**
+   * The required number of partitions for this distribution. If it's None, then any number of
+   * partitions is allowed for this distribution.
+   */
+  def requiredNumPartitions: Option[Int]
+
+  /**
+   * Creates a default partitioning for this distribution, which can satisfy this distribution while
+   * matching the given number of partitions.
+   */
+  def createPartitioning(numPartitions: Int): Partitioning
+}
 
 /**
  * Represents a distribution where no promises are made about co-location of data.
  */
-case object UnspecifiedDistribution extends Distribution
+case object UnspecifiedDistribution extends Distribution {
+  override def requiredNumPartitions: Option[Int] = None
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.")
+  }
+}
 
 /**
  * Represents a distribution that only has a single partition and all tuples of the dataset
  * are co-located.
  */
-case object AllTuples extends Distribution
+case object AllTuples extends Distribution {
+  override def requiredNumPartitions: Option[Int] = Some(1)
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.")
+    SinglePartition
+  }
+}
 
 /**
  * Represents data where tuples that share the same values for the `clustering`
@@ -51,12 +76,41 @@ case object AllTuples extends Distribution
  */
 case class ClusteredDistribution(
     clustering: Seq[Expression],
-    numPartitions: Option[Int] = None) extends Distribution {
+    requiredNumPartitions: Option[Int] = None) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
+      s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
+        s"the actual number of partitions is $numPartitions.")
+    HashPartitioning(clustering, numPartitions)
+  }
+}
+
+/**
+ * Represents data where tuples have been clustered according to the hash of the given
+ * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * [[HashPartitioning]] can satisfy this distribution.
+ *
+ * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
+ * number of partitions, this distribution strictly requires which partition the tuple should be in.
+ */
+case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution {
+  require(
+    expressions != Nil,
+    "The expressions for hash of a HashPartitionedDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+
+  override def requiredNumPartitions: Option[Int] = None
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    HashPartitioning(expressions, numPartitions)
+  }
 }
 
 /**
@@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
 
-  // TODO: This is not really valid...
-  def clustering: Set[Expression] = ordering.map(_.child).toSet
+  override def requiredNumPartitions: Option[Int] = None
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    RangePartitioning(ordering, numPartitions)
+  }
 }
 
 /**
  * Represents data where tuples are broadcasted to every node. It is quite common that the
  * entire set of tuples is transformed into different data structure.
  */
-case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
+case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
+  override def requiredNumPartitions: Option[Int] = Some(1)
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(numPartitions == 1,
+      "The default partitioning of BroadcastDistribution can only have 1 partition.")
+    BroadcastPartitioning(mode)
+  }
+}
 
 /**
- * Describes how an operator's output is split across partitions. The `compatibleWith`,
- * `guarantees`, and `satisfies` methods describe relationships between child partitionings,
- * target partitionings, and [[Distribution]]s. These relations are described more precisely in
- * their individual method docs, but at a high level:
- *
- *  - `satisfies` is a relationship between partitionings and distributions.
- *  - `compatibleWith` is relationships between an operator's child output partitionings.
- *  - `guarantees` is a relationship between a child's existing output partitioning and a target
- *     output partitioning.
- *
- *  Diagrammatically:
- *
- *            +--------------+
- *            | Distribution |
- *            +--------------+
- *                    ^
- *                    |
- *               satisfies
- *                    |
- *            +--------------+                  +--------------+
- *            |    Child     |                  |    Target    |
- *       +----| Partitioning |----guarantees--->| Partitioning |
- *       |    +--------------+                  +--------------+
- *       |            ^
- *       |            |
- *       |     compatibleWith
- *       |            |
- *       +------------+
- *
+ * Describes how an operator's output is split across partitions. It has 2 major properties:
+ *   1. number of partitions.
+ *   2. if it can satisfy a given distribution.
  */
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
@@ -123,113 +162,35 @@ sealed trait Partitioning {
    * to satisfy the partitioning scheme mandated by the `required` [[Distribution]],
    * i.e. the current dataset does not need to be re-partitioned for the `required`
    * Distribution (it is possible that tuples within a partition need to be reorganized).
-   */
-  def satisfies(required: Distribution): Boolean
-
-  /**
-   * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
-   * guarantees the same partitioning scheme described by `other`.
-   *
-   * Compatibility of partitionings is only checked for operators that have multiple children
-   * and that require a specific child output [[Distribution]], such as joins.
-   *
-   * Intuitively, partitionings are compatible if they route the same partitioning key to the same
-   * partition. For instance, two hash partitionings are only compatible if they produce the same
-   * number of output partitionings and hash records according to the same hash function and
-   * same partitioning key schema.
-   *
-   * Put another way, two partitionings are compatible with each other if they satisfy all of the
-   * same distribution guarantees.
-   */
-  def compatibleWith(other: Partitioning): Boolean
-
-  /**
-   * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees
-   * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning
-   * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance
-   * optimization to allow the exchange planner to avoid redundant repartitionings. By default,
-   * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number
-   * of partitions, same strategy (range or hash), etc).
-   *
-   * In order to enable more aggressive optimization, this strict equality check can be relaxed.
-   * For example, say that the planner needs to repartition all of an operator's children so that
-   * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children
-   * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens
-   * to be hash-partitioned with a single partition then we do not need to re-shuffle this child;
-   * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees`
-   * [[SinglePartition]].
-   *
-   * The SinglePartition example given above is not particularly interesting; guarantees' real
-   * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion
-   * of null-safe partitionings, under which partitionings can specify whether rows whose
-   * partitioning keys contain null values will be grouped into the same partition or whether they
-   * will have an unknown / random distribution. If a partitioning does not require nulls to be
-   * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered
-   * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot
-   * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a
-   * symmetric relation.
    *
-   * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows
-   * produced by `A` could have also been produced by `B`.
+   * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if
+   * the [[Partitioning]] only have one partition. Implementations can overwrite this method with
+   * special logic.
    */
-  def guarantees(other: Partitioning): Boolean = this == other
-}
-
-object Partitioning {
-  def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
-    // Note: this assumes transitivity
-    partitionings.sliding(2).map {
-      case Seq(a) => true
-      case Seq(a, b) =>
-        if (a.numPartitions != b.numPartitions) {
-          assert(!a.compatibleWith(b) && !b.compatibleWith(a))
-          false
-        } else {
-          a.compatibleWith(b) && b.compatibleWith(a)
-        }
-    }.forall(_ == true)
-  }
-}
-
-case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
-  override def satisfies(required: Distribution): Boolean = required match {
+  def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
+    case AllTuples => numPartitions == 1
     case _ => false
   }
-
-  override def compatibleWith(other: Partitioning): Boolean = false
-
-  override def guarantees(other: Partitioning): Boolean = false
 }
 
+case class UnknownPartitioning(numPartitions: Int) extends Partitioning
+
 /**
  * Represents a partitioning where rows are distributed evenly across output partitions
  * by starting from a random target partition number and distributing rows in a round-robin
  * fashion. This partitioning is used when implementing the DataFrame.repartition() operator.
  */
-case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning {
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning): Boolean = false
-
-  override def guarantees(other: Partitioning): Boolean = false
-}
+case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning
 
 case object SinglePartition extends Partitioning {
   val numPartitions = 1
 
   override def satisfies(required: Distribution): Boolean = required match {
     case _: BroadcastDistribution => false
-    case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1)
+    case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1
     case _ => true
   }
-
-  override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
-
-  override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
 }
 
 /**
@@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
 
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
-      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
-        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => this.semanticEquals(o)
-    case _ => false
-  }
-
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => this.semanticEquals(o)
-    case _ => false
+  override def satisfies(required: Distribution): Boolean = {
+    super.satisfies(required) || {
+      required match {
+        case h: HashClusteredDistribution =>
+          expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
+            case (l, r) => l.semanticEquals(r)
+          }
+        case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
+          expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+            (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
+        case _ => false
+      }
+    }
   }
 
   /**
@@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
 
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case OrderedDistribution(requiredOrdering) =>
-      val minSize = Seq(requiredOrdering.size, ordering.size).min
-      requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
-      ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
-        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case o: RangePartitioning => this.semanticEquals(o)
-    case _ => false
-  }
-
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case o: RangePartitioning => this.semanticEquals(o)
-    case _ => false
+  override def satisfies(required: Distribution): Boolean = {
+    super.satisfies(required) || {
+      required match {
+        case OrderedDistribution(requiredOrdering) =>
+          val minSize = Seq(requiredOrdering.size, ordering.size).min
+          requiredOrdering.take(minSize) == ordering.take(minSize)
+        case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
+          ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+            (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
+        case _ => false
+      }
+    }
   }
 }
 
@@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def satisfies(required: Distribution): Boolean =
     partitionings.exists(_.satisfies(required))
 
-  /**
-   * Returns true if any `partitioning` of this collection is compatible with
-   * the given [[Partitioning]].
-   */
-  override def compatibleWith(other: Partitioning): Boolean =
-    partitionings.exists(_.compatibleWith(other))
-
-  /**
-   * Returns true if any `partitioning` of this collection guarantees
-   * the given [[Partitioning]].
-   */
-  override def guarantees(other: Partitioning): Boolean =
-    partitionings.exists(_.guarantees(other))
-
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
@@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
     case BroadcastDistribution(m) if m == mode => true
     case _ => false
   }
-
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case BroadcastPartitioning(m) if m == mode => true
-    case _ => false
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
deleted file mode 100644
index 5b802cc..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
-
-class PartitioningSuite extends SparkFunSuite {
-  test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
-    val expressions = Seq(Literal(2), Literal(3))
-    // Consider two HashPartitionings that have the same _set_ of hash expressions but which are
-    // created with different orderings of those expressions:
-    val partitioningA = HashPartitioning(expressions, 100)
-    val partitioningB = HashPartitioning(expressions.reverse, 100)
-    // These partitionings are not considered equal:
-    assert(partitioningA != partitioningB)
-    // However, they both satisfy the same clustered distribution:
-    val distribution = ClusteredDistribution(expressions)
-    assert(partitioningA.satisfies(distribution))
-    assert(partitioningB.satisfies(distribution))
-    // These partitionings compute different hashcodes for the same input row:
-    def computeHashCode(partitioning: HashPartitioning): Int = {
-      val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty)
-      hashExprProj.apply(InternalRow.empty).hashCode()
-    }
-    assert(computeHashCode(partitioningA) != computeHashCode(partitioningB))
-    // Thus, these partitionings are incompatible:
-    assert(!partitioningA.compatibleWith(partitioningB))
-    assert(!partitioningB.compatibleWith(partitioningA))
-    assert(!partitioningA.guarantees(partitioningB))
-    assert(!partitioningB.guarantees(partitioningA))
-
-    // Just to be sure that we haven't cheated by having these methods always return false,
-    // check that identical partitionings are still compatible with and guarantee each other:
-    assert(partitioningA === partitioningA)
-    assert(partitioningA.guarantees(partitioningA))
-    assert(partitioningA.compatibleWith(partitioningA))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 787c1cf..82300ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /** Specifies how data is partitioned across different nodes in the cluster. */
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
 
-  /** Specifies any partition requirements on the input data for this operator. */
+  /**
+   * Specifies the data distribution requirements of all the children for this operator. By default
+   * it's [[UnspecifiedDistribution]] for each child, which means each child can have any
+   * distribution.
+   *
+   * If an operator overwrites this method, and specifies distribution requirements(excluding
+   * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark
+   * guarantees that the outputs of these children will have same number of partitions, so that the
+   * operator can safely zip partitions of these children's result RDDs. Some operators can leverage
+   * this guarantee to satisfy some interesting requirement, e.g., non-broadcast joins can specify
+   * HashClusteredDistribution(a,b) for its left child, and specify HashClusteredDistribution(c,d)
+   * for its right child, then it's guaranteed that left and right child are co-partitioned by
+   * a,b/c,d, which means tuples of same value are in the partitions of same index, e.g.,
+   * (a=1,b=2) and (c=1,d=2) are both in the second partition of left and right child.
+   */
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c8e236b..e3d2838 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -47,23 +47,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   /**
-   * Given a required distribution, returns a partitioning that satisfies that distribution.
-   * @param requiredDistribution The distribution that is required by the operator
-   * @param numPartitions Used when the distribution doesn't require a specific number of partitions
-   */
-  private def createPartitioning(
-      requiredDistribution: Distribution,
-      numPartitions: Int): Partitioning = {
-    requiredDistribution match {
-      case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering, desiredPartitions) =>
-        HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
-      case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
-      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
-    }
-  }
-
-  /**
    * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled
    * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]].
    */
@@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         // shuffle data when we have more than one children because data generated by
         // these children may not be partitioned in the same way.
         // Please see the comment in withCoordinator for more details.
-        val supportsDistribution =
-          requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+        val supportsDistribution = requiredChildDistributions.forall { dist =>
+          dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
+        }
         children.length > 1 && supportsDistribution
       }
 
@@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
             //
             // It will be great to introduce a new Partitioning to represent the post-shuffle
             // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
-            val targetPartitioning =
-              createPartitioning(distribution, defaultNumPreShufflePartitions)
+            val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions)
             assert(targetPartitioning.isInstanceOf[HashPartitioning])
             ShuffleExchangeExec(targetPartitioning, child, Some(coordinator))
         }
@@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     assert(requiredChildDistributions.length == children.length)
     assert(requiredChildOrderings.length == children.length)
 
-    // Ensure that the operator's children satisfy their output distribution requirements:
+    // Ensure that the operator's children satisfy their output distribution requirements.
     children = children.zip(requiredChildDistributions).map {
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
         child
       case (child, BroadcastDistribution(mode)) =>
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
-        ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+        val numPartitions = distribution.requiredNumPartitions
+          .getOrElse(defaultNumPreShufflePartitions)
+        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
     }
 
-    // If the operator has multiple children and specifies child output distributions (e.g. join),
-    // then the children's output partitionings must be compatible:
-    def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
-      case UnspecifiedDistribution => false
-      case BroadcastDistribution(_) => false
+    // Get the indexes of children which have specified distribution requirements and need to have
+    // same number of partitions.
+    val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
+      case (UnspecifiedDistribution, _) => false
+      case (_: BroadcastDistribution, _) => false
       case _ => true
-    }
-    if (children.length > 1
-        && requiredChildDistributions.exists(requireCompatiblePartitioning)
-        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+    }.map(_._2)
 
-      // First check if the existing partitions of the children all match. This means they are
-      // partitioned by the same partitioning into the same number of partitions. In that case,
-      // don't try to make them match `defaultPartitions`, just use the existing partitioning.
-      val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
-      val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
-        case (child, distribution) =>
-          child.outputPartitioning.guarantees(
-            createPartitioning(distribution, maxChildrenNumPartitions))
+    val childrenNumPartitions =
+      childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
+
+    if (childrenNumPartitions.size > 1) {
+      // Get the number of partitions which is explicitly required by the distributions.
+      val requiredNumPartitions = {
+        val numPartitionsSet = childrenIndexes.flatMap {
+          index => requiredChildDistributions(index).requiredNumPartitions
+        }.toSet
+        assert(numPartitionsSet.size <= 1,
+          s"$operator have incompatible requirements of the number of partitions for its children")
+        numPartitionsSet.headOption
       }
 
-      children = if (useExistingPartitioning) {
-        // We do not need to shuffle any child's output.
-        children
-      } else {
-        // We need to shuffle at least one child's output.
-        // Now, we will determine the number of partitions that will be used by created
-        // partitioning schemes.
-        val numPartitions = {
-          // Let's see if we need to shuffle all child's outputs when we use
-          // maxChildrenNumPartitions.
-          val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
-            case (child, distribution) =>
-              !child.outputPartitioning.guarantees(
-                createPartitioning(distribution, maxChildrenNumPartitions))
-          }
-          // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
-          // number of partitions. Otherwise, we use maxChildrenNumPartitions.
-          if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
-        }
+      val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
 
-        children.zip(requiredChildDistributions).map {
-          case (child, distribution) =>
-            val targetPartitioning = createPartitioning(distribution, numPartitions)
-            if (child.outputPartitioning.guarantees(targetPartitioning)) {
-              child
-            } else {
-              child match {
-                // If child is an exchange, we replace it with
-                // a new one having targetPartitioning.
-                case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c)
-                case _ => ShuffleExchangeExec(targetPartitioning, child)
-              }
+      children = children.zip(requiredChildDistributions).zipWithIndex.map {
+        case ((child, distribution), index) if childrenIndexes.contains(index) =>
+          if (child.outputPartitioning.numPartitions == targetNumPartitions) {
+            child
+          } else {
+            val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
+            child match {
+              // If child is an exchange, we replace it with a new one having defaultPartitioning.
+              case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
+              case _ => ShuffleExchangeExec(defaultPartitioning, child)
+            }
           }
-        }
+
+        case ((child, _), _) => child
       }
     }
 
@@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator @ ShuffleExchangeExec(partitioning, child, _) =>
-      child.children match {
-        case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil =>
-          if (childPartitioning.guarantees(partitioning)) child else operator
+    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
+    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+      child.outputPartitioning match {
+        case lower: HashPartitioning if upper.semanticEquals(lower) => child
         case _ => operator
       }
     case operator: SparkPlan =>
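
The rewritten partition-count check in `EnsureRequirements` above can be read as a small standalone rule. The following is a minimal sketch of that rule, not the Spark implementation: `Child`, `AlignPartitions`, and `alignNumPartitions` are hypothetical names, and the input is assumed to contain only the children whose required distribution is neither `UnspecifiedDistribution` nor `BroadcastDistribution`.

```scala
// Simplified model of the new alignment logic; these types are hypothetical, not Spark APIs.
// `requiredNumPartitions` mirrors Distribution.requiredNumPartitions from the diff.
final case class Child(name: String, numPartitions: Int, requiredNumPartitions: Option[Int])

object AlignPartitions {
  /** Returns the target partition count and the names of the children that need a new shuffle. */
  def alignNumPartitions(children: Seq[Child]): (Int, Seq[String]) = {
    val counts = children.map(_.numPartitions).toSet
    if (counts.size <= 1) {
      // All children already agree on the partition count, so nothing is reshuffled.
      (counts.headOption.getOrElse(0), Seq.empty)
    } else {
      // At most one explicitly required partition count is allowed across the children.
      val required = children.flatMap(_.requiredNumPartitions).toSet
      require(required.size <= 1, "incompatible requirements of the number of partitions")
      // Prefer the explicit requirement, otherwise fall back to the largest existing count.
      val target = required.headOption.getOrElse(counts.max)
      (target, children.filter(_.numPartitions != target).map(_.name))
    }
  }
}
```

With two children of 5 and 8 partitions and no explicit requirement, the sketch picks 8 as the target and reshuffles only the 5-partition child, which matches the `childrenNumPartitions.max` fallback in the diff.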

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 66e8031..897a4da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -46,7 +46,7 @@ case class ShuffledHashJoinExec(
     "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
   private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 9440541..2de2f30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -78,7 +78,7 @@ case class SortMergeJoinExec(
   }
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d1bd8a7..03d1bbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -456,7 +456,7 @@ case class CoGroupExec(
     right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+    HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
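
The switch from `ClusteredDistribution` to `HashClusteredDistribution` in the join and co-group operators above appears to address the case covered by the removed PlannerSuite test: each side can be clustered by a subset of the keys, yet matching rows still end up in different partitions. The sketch below illustrates that problem with a toy partitioner. `CoPartitionDemo` and `partitionOf` are hypothetical, and the sum-based hash is an assumption chosen for determinism, not the hash function Spark uses.

```scala
object CoPartitionDemo {
  // Toy stand-in for hash partitioning: combine the chosen columns and take the result
  // modulo the partition count. Spark's real hash function is different.
  def partitionOf(row: Map[String, Int], cols: Seq[String], numPartitions: Int): Int =
    Math.floorMod(cols.map(row).sum, numPartitions)

  // The same join key (a = 1, b = 2) appears on both sides of a join on (a, b).
  val key: Map[String, Int] = Map("a" -> 1, "b" -> 2)
  val numPartitions = 4

  // Left side partitioned by `a` only, right side by `b` only. Each side is clustered on a
  // subset of the join keys, but the matching rows are not co-located.
  val leftByA: Int = partitionOf(key, Seq("a"), numPartitions)
  val rightByB: Int = partitionOf(key, Seq("b"), numPartitions)

  // Both sides partitioned by the full key list: matching rows share a partition.
  val leftByAB: Int = partitionOf(key, Seq("a", "b"), numPartitions)
  val rightByAB: Int = partitionOf(key, Seq("a", "b"), numPartitions)
}
```

In this toy setup the subset-based partitionings send the matching rows to different partitions, while partitioning both sides by the full key list keeps them together, which is the property a shuffled join or co-group relies on.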

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index b50642d..f8b26f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -260,11 +260,16 @@ class PlannerSuite extends SharedSQLContext {
   // do they satisfy the distribution requirements? As a result, we need at least four test cases.
 
   private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = {
-    if (outputPlan.children.length > 1
-        && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) {
-      val childPartitionings = outputPlan.children.map(_.outputPartitioning)
-      if (!Partitioning.allCompatible(childPartitionings)) {
-        fail(s"Partitionings are not compatible: $childPartitionings")
+    if (outputPlan.children.length > 1) {
+      val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution)
+        .filter {
+          case (_, UnspecifiedDistribution) => false
+          case (_, _: BroadcastDistribution) => false
+          case _ => true
+        }.map(_._1.outputPartitioning)
+
+      if (childPartitionings.map(_.numPartitions).toSet.size > 1) {
+        fail(s"Partitionings doesn't have same number of partitions: $childPartitionings")
       }
     }
     outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
@@ -274,40 +279,7 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("EnsureRequirements with incompatible child partitionings which satisfy distribution") {
-    // Consider an operator that requires inputs that are clustered by two expressions (e.g.
-    // sort merge join where there are multiple columns in the equi-join condition)
-    val clusteringA = Literal(1) :: Nil
-    val clusteringB = Literal(2) :: Nil
-    val distribution = ClusteredDistribution(clusteringA ++ clusteringB)
-    // Say that the left and right inputs are each partitioned by _one_ of the two join columns:
-    val leftPartitioning = HashPartitioning(clusteringA, 1)
-    val rightPartitioning = HashPartitioning(clusteringB, 1)
-    // Individually, each input's partitioning satisfies the clustering distribution:
-    assert(leftPartitioning.satisfies(distribution))
-    assert(rightPartitioning.satisfies(distribution))
-    // However, these partitionings are not compatible with each other, so we still need to
-    // repartition both inputs prior to performing the join:
-    assert(!leftPartitioning.compatibleWith(rightPartitioning))
-    assert(!rightPartitioning.compatibleWith(leftPartitioning))
-    val inputPlan = DummySparkPlan(
-      children = Seq(
-        DummySparkPlan(outputPartitioning = leftPartitioning),
-        DummySparkPlan(outputPartitioning = rightPartitioning)
-      ),
-      requiredChildDistribution = Seq(distribution, distribution),
-      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
-    )
-    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
-    assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) {
-      fail(s"Exchange should have been added:\n$outputPlan")
-    }
-  }
-
   test("EnsureRequirements with child partitionings with different numbers of output partitions") {
-    // This is similar to the previous test, except it checks that partitionings are not compatible
-    // unless they produce the same number of partitions.
     val clustering = Literal(1) :: Nil
     val distribution = ClusteredDistribution(clustering)
     val inputPlan = DummySparkPlan(
@@ -386,18 +358,15 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") {
+  test("EnsureRequirements eliminates Exchange if child has same partitioning") {
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
-    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
-    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
-    assert(!childPartitioning.satisfies(distribution))
-    val inputPlan = ShuffleExchangeExec(finalPartitioning,
-      DummySparkPlan(
-        children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
-        requiredChildDistribution = Seq(distribution),
-        requiredChildOrdering = Seq(Seq.empty)),
-      None)
+    val partitioning = HashPartitioning(Literal(1) :: Nil, 5)
+    assert(partitioning.satisfies(distribution))
 
+    val inputPlan = ShuffleExchangeExec(
+      partitioning,
+      DummySparkPlan(outputPartitioning = partitioning),
+      None)
     val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) {
@@ -407,17 +376,13 @@ class PlannerSuite extends SharedSQLContext {
 
   test("EnsureRequirements does not eliminate Exchange with different partitioning") {
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
-    // Number of partitions differ
-    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8)
-    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
-    assert(!childPartitioning.satisfies(distribution))
-    val inputPlan = ShuffleExchangeExec(finalPartitioning,
-      DummySparkPlan(
-        children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
-        requiredChildDistribution = Seq(distribution),
-        requiredChildOrdering = Seq(Seq.empty)),
-      None)
+    val partitioning = HashPartitioning(Literal(2) :: Nil, 5)
+    assert(!partitioning.satisfies(distribution))
 
+    val inputPlan = ShuffleExchangeExec(
+      partitioning,
+      DummySparkPlan(outputPartitioning = partitioning),
+      None)
     val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
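
The exchange-elimination rule that replaces the old `guarantees` check in `EnsureRequirements.apply` can be sketched outside Spark as follows. This is a simplified model under stated assumptions: `Plan`, `HashPart`, `Leaf`, `Exchange`, and `ElideExchange` are hypothetical types, and plain case-class equality stands in for `HashPartitioning.semanticEquals`.

```scala
// Hypothetical miniature plan model, not Spark classes.
final case class HashPart(keys: Seq[String], numPartitions: Int)

sealed trait Plan { def outputPartitioning: HashPart }
final case class Leaf(outputPartitioning: HashPart) extends Plan
final case class Exchange(target: HashPart, child: Plan) extends Plan {
  // A shuffle produces exactly the partitioning it was asked for.
  def outputPartitioning: HashPart = target
}

object ElideExchange {
  // Drop a shuffle whose target partitioning equals what its child already produces.
  // Structural equality is an assumption standing in for semanticEquals.
  def elide(plan: Plan): Plan = plan match {
    case Exchange(upper, child) if child.outputPartitioning == upper => child
    case other => other
  }
}
```

A shuffle to 5 partitions on key `k` over a child that is already hash-partitioned that way collapses to the child, while a shuffle to 8 partitions is kept.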

