Posted to commits@spark.apache.org by ma...@apache.org on 2015/09/25 05:39:08 UTC

spark git commit: [SPARK-9852] Let reduce tasks fetch multiple map output partitions

Repository: spark
Updated Branches:
  refs/heads/master 8023242e7 -> 21fd12cb1


[SPARK-9852] Let reduce tasks fetch multiple map output partitions

This makes two changes:

- Allow reduce tasks to fetch multiple map output partitions -- this is a pretty small change to BlockStoreShuffleReader
- Move shuffle locality computation out of DAGScheduler and into ShuffledRDD / MapOutputTracker; this was needed because the DAGScheduler code wouldn't work for RDDs whose reduce tasks each fetch multiple map output partitions

I also added an AdaptiveSchedulingSuite with tests that create RDDs whose reduce tasks depend on multiple map output partitions.
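
To make the first change concrete, here is a minimal sketch drawn from the new
AdaptiveSchedulingSuite (CustomShuffledRDD is the test helper added in this commit;
sc is a live SparkContext): a 3-partition shuffle whose two reduce tasks cover map
output partition ranges [0, 2) and [2, 3).

  val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x))
  val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(3))
  // The second argument gives the start indices of each reduce task's range.
  val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0, 2))
  assert(shuffled.partitions.length == 2)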

Author: Matei Zaharia <ma...@databricks.com>

Closes #8844 from mateiz/spark-9852.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21fd12cb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21fd12cb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21fd12cb

Branch: refs/heads/master
Commit: 21fd12cb17b9e08a0cc49b4fda801af947a4183b
Parents: 8023242
Author: Matei Zaharia <ma...@databricks.com>
Authored: Thu Sep 24 23:39:04 2015 -0400
Committer: Matei Zaharia <ma...@databricks.com>
Committed: Thu Sep 24 23:39:04 2015 -0400

----------------------------------------------------------------------
 .../org/apache/spark/MapOutputTracker.scala     |  79 +++++++++--
 .../org/apache/spark/rdd/ShuffledRDD.scala      |   6 +
 .../apache/spark/scheduler/DAGScheduler.scala   |  33 +----
 .../spark/shuffle/BlockStoreShuffleReader.scala |   9 +-
 .../scheduler/AdaptiveSchedulingSuite.scala     |  47 ++++---
 .../spark/scheduler/CustomShuffledRDD.scala     | 111 +++++++++++++++
 .../spark/scheduler/DAGSchedulerSuite.scala     | 137 +++++++++++--------
 .../shuffle/BlockStoreShuffleReaderSuite.scala  |   2 +-
 .../spark/sql/execution/ShuffledRowRDD.scala    |   6 +
 9 files changed, 306 insertions(+), 124 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 94eb8da..e4cb72e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -134,11 +134,25 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    */
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId")
+    getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+  }
+
+  /**
+   * Called from executors to get the server URIs and output sizes for each shuffle block that
+   * needs to be read from a given range of map output partitions (startPartition is included but
+   * endPartition is excluded from the range).
+   *
+   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+   *         describing the shuffle blocks that are stored at that block manager.
+   */
+  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
     // Synchronize on the returned array because, on the driver, it gets mutated in place
     statuses.synchronized {
-      return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
     }
   }
 
@@ -262,6 +276,21 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
   private var cacheEpoch = epoch
 
+  /** Whether to compute locality preferences for reduce tasks */
+  private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
+
+  // Number of map and reduce tasks above which we do not assign preferred locations based on map
+  // output sizes. We limit the size of jobs for which we assign preferred locations, as computing the
+  // top locations by size becomes expensive.
+  private val SHUFFLE_PREF_MAP_THRESHOLD = 1000
+  // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that
+  private val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000
+
+  // Fraction of total map output that must be at a location for it to be considered a preferred
+  // location for a reduce task. Making this larger will focus on fewer locations where most data
+  // can be read locally, but may lead to more delay in scheduling if those locations are busy.
+  private val REDUCER_PREF_LOCS_FRACTION = 0.2
+
   /**
    * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver,
    * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
@@ -323,6 +352,30 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   }
 
   /**
+   * Return the preferred hosts on which to run the given map output partition in a given shuffle,
+   * i.e. the nodes that have the most outputs for that partition.
+   *
+   * @param dep shuffle dependency object
+   * @param partitionId map output partition that we want to read
+   * @return a sequence of host names
+   */
+  def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
+      : Seq[String] = {
+    if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
+        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
+      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
+        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
+      if (blockManagerIds.nonEmpty) {
+        blockManagerIds.get.map(_.host)
+      } else {
+        Nil
+      }
+    } else {
+      Nil
+    }
+  }
+
+  /**
    * Return a list of locations that each have fraction of map output greater than the specified
    * threshold.
    *
@@ -460,23 +513,25 @@ private[spark] object MapOutputTracker extends Logging {
   }
 
   /**
-   * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block
-   * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that
-   * block manager.
+   * Given an array of map statuses and a range of map output partitions, returns a sequence that,
+   * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes
+   * stored at that block manager.
    *
    * If any of the statuses is null (indicating a missing location due to a failed mapper),
    * throws a FetchFailedException.
    *
    * @param shuffleId Identifier for the shuffle
-   * @param reduceId Identifier for the reduce task
+   * @param startPartition Start of map output partition ID range (included in range)
+   * @param endPartition End of map output partition ID range (excluded from range)
    * @param statuses List of map statuses, indexed by map ID.
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+   *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
    *         describing the shuffle blocks that are stored at that block manager.
    */
   private def convertMapStatuses(
       shuffleId: Int,
-      reduceId: Int,
+      startPartition: Int,
+      endPartition: Int,
       statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     assert (statuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
@@ -484,10 +539,12 @@ private[spark] object MapOutputTracker extends Logging {
       if (status == null) {
         val errorMessage = s"Missing an output location for shuffle $shuffleId"
         logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage)
+        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
       } else {
-        splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
-          ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId)))
+        for (part <- startPartition until endPartition) {
+          splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+            ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
+        }
       }
     }
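
As an illustration of the grouping that the new convertMapStatuses performs for a
partition range, here is a self-contained sketch. It uses simplified stand-in types
(Address, Block, Status) rather than Spark's BlockManagerId, ShuffleBlockId and
MapStatus, and omits the null-status check that raises MetadataFetchFailedException:

  import scala.collection.mutable.{ArrayBuffer, HashMap}

  case class Address(host: String)                          // ~ BlockManagerId
  case class Block(shuffleId: Int, mapId: Int, part: Int)   // ~ ShuffleBlockId
  case class Status(location: Address, sizes: Array[Long])  // ~ MapStatus

  // For every map output and every partition in [start, end), bucket the
  // (block, size) pairs by the address that holds them.
  def convert(shuffleId: Int, start: Int, end: Int, statuses: Array[Status])
      : Seq[(Address, Seq[(Block, Long)])] = {
    val byAddress = new HashMap[Address, ArrayBuffer[(Block, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex; part <- start until end) {
      byAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
        ((Block(shuffleId, mapId, part), status.sizes(part)))
    }
    byAddress.toSeq
  }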
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index cb15d91..a013c3f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -86,6 +86,12 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
     Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
   }
 
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
+    tracker.getPreferredLocationsForShuffle(dep, partition.index)
+  }
+
   override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
     val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 394228b..ade372b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -184,22 +184,6 @@ class DAGScheduler(
   private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
-  // Flag to control if reduce tasks are assigned preferred locations
-  private val shuffleLocalityEnabled =
-    sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
-  // Number of map, reduce tasks above which we do not assign preferred locations
-  // based on map output sizes. We limit the size of jobs for which assign preferred locations
-  // as computing the top locations by size becomes expensive.
-  private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000
-  // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that
-  private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000
-
-  // Fraction of total map output that must be at a location for it to considered as a preferred
-  // location for a reduce task.
-  // Making this larger will focus on fewer locations where most data can be read locally, but
-  // may lead to more delay in scheduling if those locations are busy.
-  private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2
-
   /**
    * Called by the TaskSetManager to report that a task has started.
    */
@@ -1570,25 +1554,10 @@ class DAGScheduler(
             return locs
           }
         }
+
       case _ =>
     }
 
-    // If the RDD has shuffle dependencies and shuffle locality is enabled, pick locations that
-    // have at least REDUCER_PREF_LOCS_FRACTION of data as preferred locations
-    if (shuffleLocalityEnabled && rdd.partitions.length < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      rdd.dependencies.foreach {
-        case s: ShuffleDependency[_, _, _] =>
-          if (s.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD) {
-            // Get the preferred map output locations for this reducer
-            val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId,
-              partition, rdd.partitions.length, REDUCER_PREF_LOCS_FRACTION)
-            if (topLocsForReducer.nonEmpty) {
-              return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId))
-            }
-          }
-        case _ =>
-      }
-    }
     Nil
   }
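
The net user-visible behavior is unchanged: reduce-task locality stays on by default
and the same flag still disables it, now gating the logic in MapOutputTrackerMaster
instead of DAGScheduler. A minimal sketch:

  import org.apache.spark.SparkConf

  // Turn off locality preferences for reduce tasks (default is true).
  val conf = new SparkConf().set("spark.shuffle.reduceLocality.enabled", "false")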
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 6dc9a16..7c3e2b5 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -23,6 +23,10 @@ import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
+/**
+ * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
+ * requesting them from other nodes' block stores.
+ */
 private[spark] class BlockStoreShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
@@ -32,9 +36,6 @@ private[spark] class BlockStoreShuffleReader[K, C](
     mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C] with Logging {
 
-  require(endPartition == startPartition + 1,
-    "Hash shuffle currently only supports fetching one partition")
-
   private val dep = handle.dependency
 
   /** Read the combined key-values for this reduce task */
@@ -43,7 +44,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       context,
       blockManager.shuffleClient,
       blockManager,
-      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
+      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
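
With the require() removed, one reader can serve a contiguous range of partitions.
A hedged sketch of the call site, where dep and context stand for the ShuffleDependency
and TaskContext of a running task (as in CustomShuffledRDD.compute added below):

  // Read map output partitions [0, 3) of the shuffle with a single reader;
  // previously the reader required endPartition == startPartition + 1.
  val iter = SparkEnv.get.shuffleManager
    .getReader(dep.shuffleHandle, 0, 3, context)
    .read()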
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala
index 3fe2802..e0f474a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.scheduler
 
-import org.apache.spark.rdd.{ShuffledRDDPartition, RDD, ShuffledRDD}
 import org.apache.spark._
 
 object AdaptiveSchedulingSuiteState {
@@ -28,26 +27,10 @@ object AdaptiveSchedulingSuiteState {
   }
 }
 
-/** A special ShuffledRDD where we can pass a ShuffleDependency object to use */
-class CustomShuffledRDD[K, V, C](@transient dep: ShuffleDependency[K, V, C])
-  extends RDD[(K, C)](dep.rdd.context, Seq(dep)) {
-
-  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
-    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
-    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
-      .read()
-      .asInstanceOf[Iterator[(K, C)]]
-  }
-
-  override def getPartitions: Array[Partition] = {
-    Array.tabulate[Partition](dep.partitioner.numPartitions)(i => new ShuffledRDDPartition(i))
-  }
-}
-
 class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext {
   test("simple use of submitMapStage") {
     try {
-      sc = new SparkContext("local[1,2]", "test")
+      sc = new SparkContext("local", "test")
       val rdd = sc.parallelize(1 to 3, 3).map { x =>
         AdaptiveSchedulingSuiteState.tasksRun += 1
         (x, x)
@@ -62,4 +45,32 @@ class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext {
       AdaptiveSchedulingSuiteState.clear()
     }
   }
+
+  test("fetching multiple map output partitions per reduce") {
+    sc = new SparkContext("local", "test")
+    val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x))
+    val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(3))
+    val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0, 2))
+    assert(shuffled.partitions.length === 2)
+    assert(shuffled.glom().map(_.toSet).collect().toSet == Set(Set((0, 0), (1, 1)), Set((2, 2))))
+  }
+
+  test("fetching all map output partitions in one reduce") {
+    sc = new SparkContext("local", "test")
+    val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x))
+    // Also create lots of hash partitions so that some of them are empty
+    val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(5))
+    val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0))
+    assert(shuffled.partitions.length === 1)
+    assert(shuffled.collect().toSet == Set((0, 0), (1, 1), (2, 2)))
+  }
+
+  test("more reduce tasks than map output partitions") {
+    sc = new SparkContext("local", "test")
+    val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x))
+    val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(3))
+    val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0, 0, 0, 1, 1, 1, 2))
+    assert(shuffled.partitions.length === 7)
+    assert(shuffled.collect().toSet == Set((0, 0), (1, 1), (2, 2)))
+  }
 }
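
A note on the last test: repeated entries in partitionStartIndices produce empty
ranges, so only some of the seven reduce tasks actually read data. A quick check in
plain Scala (no Spark required):

  val starts = Array(0, 0, 0, 1, 1, 1, 2)
  val n = 3  // number of parent (map output) partitions
  val ranges = starts.indices.map { i =>
    val end = if (i < starts.length - 1) starts(i + 1) else n
    (starts(i), end)
  }
  // ranges == Vector((0,0), (0,0), (0,1), (1,1), (1,1), (1,2), (2,3)):
  // only tasks 2, 5 and 6 cover a non-empty range; the rest fetch nothing.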

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
new file mode 100644
index 0000000..d8d818c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Arrays
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+
+/**
+ * A Partitioner that might group together one or more partitions from the parent.
+ *
+ * @param parent a parent partitioner
+ * @param partitionStartIndices indices of partitions in parent that should create new partitions
+ *   in child (this should be an array of increasing partition IDs). For example, if we have a
+ *   parent with 5 partitions, and partitionStartIndices is [0, 2, 4], we get three output
+ *   partitions, corresponding to partition ranges [0, 1], [2, 3] and [4] of the parent partitioner.
+ */
+class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: Array[Int])
+  extends Partitioner {
+
+  @transient private lazy val parentPartitionMapping: Array[Int] = {
+    val n = parent.numPartitions
+    val result = new Array[Int](n)
+    for (i <- 0 until partitionStartIndices.length) {
+      val start = partitionStartIndices(i)
+      val end = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n
+      for (j <- start until end) {
+        result(j) = i
+      }
+    }
+    result
+  }
+
+  override def numPartitions: Int = partitionStartIndices.size
+
+  override def getPartition(key: Any): Int = {
+    parentPartitionMapping(parent.getPartition(key))
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case c: CoalescedPartitioner =>
+      c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices)
+    case _ =>
+      false
+  }
+}
+
+private[spark] class CustomShuffledRDDPartition(
+    val index: Int, val startIndexInParent: Int, val endIndexInParent: Int)
+  extends Partition {
+
+  override def hashCode(): Int = index
+}
+
+/**
+ * A special ShuffledRDD that supports passing in a ShuffleDependency object from outside and
+ * launching reduce tasks that read multiple map output partitions.
+ */
+class CustomShuffledRDD[K, V, C](
+    var dependency: ShuffleDependency[K, V, C],
+    partitionStartIndices: Array[Int])
+  extends RDD[(K, C)](dependency.rdd.context, Seq(dependency)) {
+
+  def this(dep: ShuffleDependency[K, V, C]) = {
+    this(dep, (0 until dep.partitioner.numPartitions).toArray)
+  }
+
+  override def getDependencies: Seq[Dependency[_]] = List(dependency)
+
+  override val partitioner = {
+    Some(new CoalescedPartitioner(dependency.partitioner, partitionStartIndices))
+  }
+
+  override def getPartitions: Array[Partition] = {
+    val n = dependency.partitioner.numPartitions
+    Array.tabulate[Partition](partitionStartIndices.length) { i =>
+      val startIndex = partitionStartIndices(i)
+      val endIndex = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n
+      new CustomShuffledRDDPartition(i, startIndex, endIndex)
+    }
+  }
+
+  override def compute(p: Partition, context: TaskContext): Iterator[(K, C)] = {
+    val part = p.asInstanceOf[CustomShuffledRDDPartition]
+    SparkEnv.get.shuffleManager.getReader(
+      dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context)
+      .read()
+      .asInstanceOf[Iterator[(K, C)]]
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    dependency = null
+  }
+}
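
Working through the example in the CoalescedPartitioner doc comment, a short usage
sketch (assuming access to the test class above):

  import org.apache.spark.HashPartitioner

  val parent = new HashPartitioner(5)
  val coalesced = new CoalescedPartitioner(parent, Array(0, 2, 4))
  assert(coalesced.numPartitions == 3)
  // parentPartitionMapping is [0, 0, 1, 1, 2]: parent partitions 0-1 feed
  // child 0, 2-3 feed child 1, and 4 feeds child 2.
  assert(coalesced.getPartition(9) == 2)  // Int key 9 maps to parent partition 4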

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6b5bcf0..697c195 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -49,19 +49,39 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
  * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
  * preferredLocations (if any) that are passed to them. They are deliberately not executable
  * so we can test that DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
  */
 class MyRDD(
     sc: SparkContext,
     numPartitions: Int,
     dependencies: List[Dependency[_]],
-    locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+    locations: Seq[Seq[String]] = Nil,
+    @transient tracker: MapOutputTrackerMaster = null)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
   override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
     throw new RuntimeException("should not be reached")
+
   override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition {
     override def index: Int = i
   }).toArray
-  override def getPreferredLocations(split: Partition): Seq[String] =
-    if (locations.isDefinedAt(split.index)) locations(split.index) else Nil
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    if (locations.isDefinedAt(partition.index)) {
+      locations(partition.index)
+    } else if (tracker != null && dependencies.size == 1 &&
+        dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) {
+      // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality
+      val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]]
+      tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    } else {
+      Nil
+    }
+  }
+
   override def toString: String = "DAGSchedulerSuiteRDD " + id
 }
 
@@ -351,7 +371,8 @@ class DAGSchedulerSuite
    */
   test("getMissingParentStages should consider all ancestor RDDs' cache statuses") {
     val rddA = new MyRDD(sc, 1, Nil)
-    val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null)))
+    val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, new HashPartitioner(1))),
+      tracker = mapOutputTracker)
     val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache()
     val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC)))
     cacheLocations(rddC.id -> 0) =
@@ -458,9 +479,9 @@ class DAGSchedulerSuite
 
   test("run trivial shuffle") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
@@ -474,9 +495,9 @@ class DAGSchedulerSuite
 
   test("run trivial shuffle with fetch failure") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
     complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
@@ -590,9 +611,8 @@ class DAGSchedulerSuite
 
     val parts = 8
     val shuffleMapRdd = new MyRDD(sc, parts, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
-    val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, (0 until parts).toArray)
 
     completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts)
@@ -625,9 +645,8 @@ class DAGSchedulerSuite
     setupStageAbortTest(sc)
 
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
-    val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
 
     for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) {
@@ -668,10 +687,10 @@ class DAGSchedulerSuite
     setupStageAbortTest(sc)
 
     val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache()
-    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache()
-    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
-    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2))
+    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker).cache()
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1))
+    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker)
     submit(finalRdd, Array(0))
 
     // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations,
@@ -717,10 +736,10 @@ class DAGSchedulerSuite
     setupStageAbortTest(sc)
 
     val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache()
-    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache()
-    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
-    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2))
+    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker).cache()
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1))
+    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker)
     submit(finalRdd, Array(0))
 
     // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times.
@@ -777,9 +796,9 @@ class DAGSchedulerSuite
 
   test("trivial shuffle with multiple fetch failures") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
     complete(taskSets(0), Seq(
       (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
@@ -818,9 +837,9 @@ class DAGSchedulerSuite
    */
   test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
 
     val mapStageId = 0
@@ -886,9 +905,9 @@ class DAGSchedulerSuite
   test("extremely late fetch failures don't cause multiple concurrent attempts for " +
       "the same stage") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
 
     def countSubmittedReduceStageAttempts(): Int = {
@@ -949,9 +968,9 @@ class DAGSchedulerSuite
 
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
 
     // pretend we were told hostA went away
@@ -1018,8 +1037,8 @@ class DAGSchedulerSuite
 
   test("run shuffle with map stage failure") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
 
     // Fail the map stage.  This should cause the entire job to fail.
@@ -1221,12 +1240,12 @@ class DAGSchedulerSuite
    */
   test("failure of stage used by two jobs") {
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
-    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, null)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
     val shuffleMapRdd2 = new MyRDD(sc, 2, Nil)
-    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, null)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
 
-    val reduceRdd1 = new MyRDD(sc, 2, List(shuffleDep1))
-    val reduceRdd2 = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2))
+    val reduceRdd1 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+    val reduceRdd2 = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2), tracker = mapOutputTracker)
 
     // We need to make our own listeners for this test, since by default submit uses the same
     // listener for all jobs, and here we want to capture the failure for each job separately.
@@ -1258,9 +1277,9 @@ class DAGSchedulerSuite
 
   test("run trivial shuffle with out-of-band failure and retry") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     // blockManagerMaster.removeExecutor("exec-hostA")
     // pretend we were told hostA went away
@@ -1281,10 +1300,10 @@ class DAGSchedulerSuite
 
   test("recursive shuffle failures") {
     val shuffleOneRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne))
-    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
-    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2))
+    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker)
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1))
+    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker)
     submit(finalRdd, Array(0))
     // have the first stage complete normally
     complete(taskSets(0), Seq(
@@ -1310,10 +1329,10 @@ class DAGSchedulerSuite
 
   test("cached post-shuffle") {
     val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache()
-    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache()
-    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
-    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2))
+    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker).cache()
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1))
+    val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker)
     submit(finalRdd, Array(0))
     cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
     cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
@@ -1419,9 +1438,9 @@ class DAGSchedulerSuite
   test("reduce tasks should be placed locally with map output") {
    // Create a shuffleMapRdd with 1 partition
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1))))
@@ -1440,9 +1459,9 @@ class DAGSchedulerSuite
     val numMapTasks = 4
    // Create a shuffleMapRdd with more partitions
     val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
 
     val statuses = (1 to numMapTasks).map { i =>
@@ -1464,10 +1483,10 @@ class DAGSchedulerSuite
     // Create an RDD that has both a shuffle dependency and a narrow dependency (e.g. for a join)
     val rdd1 = new MyRDD(sc, 1, Nil)
     val rdd2 = new MyRDD(sc, 1, Nil, locations = Seq(Seq("hostB")))
-    val shuffleDep = new ShuffleDependency(rdd1, null)
+    val shuffleDep = new ShuffleDependency(rdd1, new HashPartitioner(1))
     val narrowDep = new OneToOneDependency(rdd2)
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep))
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
       (Success, makeMapStatus("hostA", 1))))
@@ -1500,7 +1519,8 @@ class DAGSchedulerSuite
   test("simple map stage submission") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
 
     // Submit a map stage by itself
     submitMapStage(shuffleDep)
@@ -1526,7 +1546,8 @@ class DAGSchedulerSuite
   test("map stage submission with reduce stage also depending on the data") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
-    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
 
     // Submit the map stage by itself
     submitMapStage(shuffleDep)
@@ -1555,7 +1576,7 @@ class DAGSchedulerSuite
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
-    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
 
     // Submit a map stage by itself
     submitMapStage(shuffleDep)
@@ -1604,9 +1625,9 @@ class DAGSchedulerSuite
   test("map stage submission with multiple shared stages and failures") {
     val rdd1 = new MyRDD(sc, 2, Nil)
     val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2))
-    val rdd2 = new MyRDD(sc, 2, List(dep1))
+    val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker)
     val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2))
-    val rdd3 = new MyRDD(sc, 2, List(dep2))
+    val rdd3 = new MyRDD(sc, 2, List(dep2), tracker = mapOutputTracker)
 
     val listener1 = new SimpleListener
     val listener2 = new SimpleListener
@@ -1712,7 +1733,7 @@ class DAGSchedulerSuite
     assertDataStructuresEmpty()
 
     // Also test that a reduce stage using this shuffled data can immediately run
-    val reduceRDD = new MyRDD(sc, 2, List(shuffleDep))
+    val reduceRDD = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     results.clear()
     submit(reduceRDD, Array(0, 1))
     complete(taskSets(2), Seq((Success, 42), (Success, 43)))

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index a5eafb1..26a372d 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -114,7 +114,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
     // shuffle data to read.
     val mapOutputTracker = mock(classOf[MapOutputTracker])
-    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn {
+    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
       // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
       // for the code to read data over the network.
       val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>

http://git-wip-us.apache.org/repos/asf/spark/blob/21fd12cb/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 88f5b13..743c99a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -65,6 +65,12 @@ class ShuffledRowRDD(
     Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
   }
 
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    tracker.getPreferredLocationsForShuffle(dep, partition.index)
+  }
+
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org