Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/31 06:30:17 UTC

git commit: [SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs

Repository: spark
Updated Branches:
  refs/heads/master e96628440 -> 894d48ffb


[SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs

Author: Reynold Xin <rx...@apache.org>

Closes #1675 from rxin/unionrdd and squashes the following commits:

941d316 [Reynold Xin] Clear RDDs for checkpointing.
c9f05f2 [Reynold Xin] [SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs
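
Background: Spark ships a Partition object inside every serialized task, so a
UnionPartition that holds a non-transient reference to its parent RDD drags the
whole parent lineage (and any closures it captures) into each task. A minimal
driver-side sketch of the symptom, assuming a live SparkContext sc, mirroring
the regression test added in this commit:

    import org.apache.spark.SparkEnv

    val largeVariable = new Array[Byte](1000 * 1000)  // ~1 MB captured by rdd1's map closure
    val rdd1 = sc.parallelize(1 to 10, 2).map(i => largeVariable.length)
    val rdd2 = sc.parallelize(1 to 10, 3)
    val union = rdd1.union(rdd2)

    val ser = SparkEnv.get.closureSerializer.newInstance()
    // The UnionRDD itself is expected to serialize large (it owns the lineage),
    // but with the parent RDD reference marked @transient each partition stays small.
    println(ser.serialize(union).limit())
    println(ser.serialize(union.partitions.head).limit())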


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/894d48ff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/894d48ff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/894d48ff

Branch: refs/heads/master
Commit: 894d48ffb8c91e347ab60c58de983e1aaf181188
Parents: e966284
Author: Reynold Xin <rx...@apache.org>
Authored: Wed Jul 30 21:30:13 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Wed Jul 30 21:30:13 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/rdd/UnionRDD.scala   | 41 ++++++++++++++------
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 12 ++++++
 2 files changed, 42 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/894d48ff/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 21c6e07..197167e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -25,21 +25,32 @@ import scala.reflect.ClassTag
 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 
-private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int)
+/**
+ * Partition for UnionRDD.
+ *
+ * @param idx index of the partition
+ * @param rdd the parent RDD this partition refers to
+ * @param parentRddIndex index of the parent RDD this partition refers to
+ * @param parentRddPartitionIndex index of the partition within the parent RDD
+ *                                this partition refers to
+ */
+private[spark] class UnionPartition[T: ClassTag](
+    idx: Int,
+    @transient rdd: RDD[T],
+    val parentRddIndex: Int,
+    @transient parentRddPartitionIndex: Int)
   extends Partition {
 
-  var split: Partition = rdd.partitions(splitIndex)
-
-  def iterator(context: TaskContext) = rdd.iterator(split, context)
+  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)
 
-  def preferredLocations() = rdd.preferredLocations(split)
+  def preferredLocations() = rdd.preferredLocations(parentPartition)
 
   override val index: Int = idx
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream) {
     // Update the reference to parent split at the time of task serialization
-    split = rdd.partitions(splitIndex)
+    parentPartition = rdd.partitions(parentRddPartitionIndex)
     oos.defaultWriteObject()
   }
 }
@@ -47,14 +58,14 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd
 @DeveloperApi
 class UnionRDD[T: ClassTag](
     sc: SparkContext,
-    @transient var rdds: Seq[RDD[T]])
+    var rdds: Seq[RDD[T]])
   extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies
 
   override def getPartitions: Array[Partition] = {
     val array = new Array[Partition](rdds.map(_.partitions.size).sum)
     var pos = 0
-    for (rdd <- rdds; split <- rdd.partitions) {
-      array(pos) = new UnionPartition(pos, rdd, split.index)
+    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
+      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
       pos += 1
     }
     array
@@ -70,9 +81,17 @@ class UnionRDD[T: ClassTag](
     deps
   }
 
-  override def compute(s: Partition, context: TaskContext): Iterator[T] =
-    s.asInstanceOf[UnionPartition[T]].iterator(context)
+  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+    val part = s.asInstanceOf[UnionPartition[T]]
+    val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]]
+    parentRdd.iterator(part.parentPartition, context)
+  }
 
   override def getPreferredLocations(s: Partition): Seq[String] =
     s.asInstanceOf[UnionPartition[T]].preferredLocations()
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdds = null
+  }
 }
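
With this change the partition only carries the small parentRddIndex; compute()
re-resolves the parent RDD through the dependencies list instead of a captured
reference, and the new clearDependencies() override lets a checkpointed UnionRDD
drop its rdds field as well. A short usage sketch of that checkpoint interaction,
assuming a live SparkContext sc and a made-up checkpoint directory:

    sc.setCheckpointDir("/tmp/spark-checkpoints")  // hypothetical path for the sketch

    val parent1 = sc.parallelize(1 to 10, 2)
    val parent2 = sc.parallelize(11 to 20, 3)
    val union = parent1.union(parent2)

    union.checkpoint()
    union.count()  // the first action materializes the checkpoint; Spark then calls
                   // clearDependencies(), which now also nulls out rdds, so the union
                   // no longer pins its parent RDDs (and their lineage) in memory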

http://git-wip-us.apache.org/repos/asf/spark/blob/894d48ff/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 8966eed..ae6e525 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -121,6 +121,18 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(union.partitioner === nums1.partitioner)
   }
 
+  test("UnionRDD partition serialized size should be small") {
+    val largeVariable = new Array[Byte](1000 * 1000)
+    val rdd1 = sc.parallelize(1 to 10, 2).map(i => largeVariable.length)
+    val rdd2 = sc.parallelize(1 to 10, 3)
+
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val union = rdd1.union(rdd2)
+    // The UnionRDD itself should be large, but each individual partition should be small.
+    assert(ser.serialize(union).limit() > 2000)
+    assert(ser.serialize(union.partitions.head).limit() < 2000)
+  }
+
   test("aggregate") {
     val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
     type StringMap = HashMap[String, Int]