You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2013/11/26 19:27:51 UTC

git commit: Merge pull request #201 from rxin/mappartitions

Updated Branches:
  refs/heads/branch-0.8 994956183 -> be9c176a8


Merge pull request #201 from rxin/mappartitions

Use the proper partition index in mapPartitionsWithIndex

mapPartitionsWithIndex uses TaskContext.partitionId as the partition index. TaskContext.partitionId used to be identical to the partition index in an RDD. However, pull request #186 introduced a scenario (with partition pruning) in which the two can differ. This pull request uses the correct partition index in all mapPartitionsWithIndex-related calls.

Also removed the extra MapPartitionsWithContextRDD and put all the mapPartitions-related functionality in MapPartitionsRDD.

(cherry picked from commit 14bb465bb3d65f5b1034ada85cfcad7460034073)
Signed-off-by: Reynold Xin <rx...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/be9c176a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/be9c176a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/be9c176a

Branch: refs/heads/branch-0.8
Commit: be9c176a81fba5975c35937ca582a3ef26bda4f5
Parents: 9949561
Author: Matei Zaharia <ma...@eecs.berkeley.edu>
Authored: Mon Nov 25 18:50:18 2013 -0800
Committer: Reynold Xin <rx...@apache.org>
Committed: Tue Nov 26 10:27:41 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/rdd/MapPartitionsRDD.scala | 10 ++---
 .../spark/rdd/MapPartitionsWithContextRDD.scala | 41 --------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 39 +++++++++----------
 .../org/apache/spark/CheckpointSuite.scala      |  2 -
 4 files changed, 22 insertions(+), 70 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/be9c176a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 203179c..ae70d55 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -20,18 +20,16 @@ package org.apache.spark.rdd
 import org.apache.spark.{Partition, TaskContext}
 
 
-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: Iterator[T] => Iterator[U],
+    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
     preservesPartitioning: Boolean = false)
   extends RDD[U](prev) {
 
-  override val partitioner =
-    if (preservesPartitioning) firstParent[T].partitioner else None
+  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
 
   override def getPartitions: Array[Partition] = firstParent[T].partitions
 
   override def compute(split: Partition, context: TaskContext) =
-    f(firstParent[T].iterator(split, context))
+    f(context, split.index, firstParent[T].iterator(split, context))
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/be9c176a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
deleted file mode 100644
index aea08ff..0000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
-    prev: RDD[T],
-    f: (TaskContext, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean
-  ) extends RDD[U](prev) {
-
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
-
-  override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
-  override def compute(split: Partition, context: TaskContext) =
-    f(context, firstParent[T].iterator(split, context))
-}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/be9c176a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6e88be6..852c131 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest](
   def pipe(command: String, env: Map[String, String]): RDD[String] =
     new PipedRDD(this, command, env)
 
-
   /**
    * Return an RDD created by piping elements to a forked external process.
    * The print behavior can be customized by providing two functions.
@@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest](
    */
   def mapPartitions[U: ClassManifest](
       f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest](
    */
   def mapPartitionsWithIndex[U: ClassManifest](
       f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
-    new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest](
   def mapPartitionsWithContext[U: ClassManifest](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest](
   def mapWith[A: ClassManifest, U: ClassManifest]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.map(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }
 
   /**
@@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest](
   def flatMapWith[A: ClassManifest, U: ClassManifest]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.flatMap(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }
 
   /**
@@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest](
    * partition with the index of that partition.
    */
   def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex { (index, iter) =>
+      val a = constructA(index)
       iter.map(t => {f(t, a); t})
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+    }.foreach(_ => {})
   }
 
   /**
@@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest](
    * partition with the index of that partition.
    */
   def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.filter(t => p(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+    }, preservesPartitioning = true)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/be9c176a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f26c44d..d2226aa 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     testCheckpointing(_.sample(false, 0.5, 0))
     testCheckpointing(_.glom())
     testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(r => new MapPartitionsWithContextRDD(r,
-      (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
     testCheckpointing(_.pipe(Seq("cat")))