Posted to commits@spark.apache.org by jo...@apache.org on 2014/07/31 07:41:15 UTC

git commit: [SPARK-2737] Add retag() method for changing RDDs' ClassTags.

Repository: spark
Updated Branches:
  refs/heads/master a7c305b86 -> 4fb259353


[SPARK-2737] Add retag() method for changing RDDs' ClassTags.

The Java API's use of fake ClassTags doesn't seem to cause any problems for Java users, but it can lead to issues when passing JavaRDDs' underlying RDDs to Scala code (e.g. in the MLlib Java API wrapper code). If we call collect() on a Scala RDD with an incorrect ClassTag, the RDD allocates an array of the wrong runtime type, and the cast back to the expected array type throws a ClassCastException (for example, see SPARK-2197).
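For illustration (not part of this patch), a minimal Scala sketch of the failure mode; SomeCustomClass, jsc, and data are hypothetical names:

    import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
    import org.apache.spark.rdd.RDD

    // JavaSparkContext.parallelize tags the underlying RDD with a fake
    // ClassTag (effectively ClassTag[AnyRef]), so collect() allocates
    // an Object[] instead of a SomeCustomClass[]:
    val jrdd: JavaRDD[SomeCustomClass] = jsc.parallelize(data)
    val rdd: RDD[SomeCustomClass] = jrdd.rdd
    // The runtime array is an Object[]; the cast back to the statically
    // expected Array[SomeCustomClass] throws a ClassCastException:
    val collected: Array[SomeCustomClass] = rdd.collect()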

There are a few possible fixes here. An API-breaking fix would be to remove the fake ClassTags entirely, requiring Java API users to pass java.lang.Class instances to every parallelize() call and to add returnClass fields to every Function implementation. This would be extremely verbose.

Instead, this patch adds internal APIs to "repair" a Scala RDD with an incorrect ClassTag by wrapping it and overriding its ClassTag. This should be okay for cases where the Scala code that calls collect() knows what type of array should be allocated, which is the case in the MLlib wrappers.
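A sketch of the intended internal usage (reusing the hypothetical names from the sketch above); note that retag() is private[spark], so only Spark's own wrapper code can call it:

    // Repair the fake tag using the element's Class, then collect:
    val repaired: RDD[SomeCustomClass] =
      jrdd.rdd.retag(classOf[SomeCustomClass])
    // collect() now allocates a SomeCustomClass[] as expected:
    val collected: Array[SomeCustomClass] = repaired.collect()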

Author: Josh Rosen <jo...@apache.org>

Closes #1639 from JoshRosen/SPARK-2737 and squashes the following commits:

572b4c8 [Josh Rosen] Replace newRDD[T] with mapPartitions().
469d941 [Josh Rosen] Preserve partitioner in retag().
af78816 [Josh Rosen] Allow retag() to get classTag implicitly.
d1d54e6 [Josh Rosen] [SPARK-2737] Add retag() method for changing RDDs' ClassTags.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4fb25935
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4fb25935
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4fb25935

Branch: refs/heads/master
Commit: 4fb259353f616822c32537e3f031944a6d2a09a8
Parents: a7c305b
Author: Josh Rosen <jo...@apache.org>
Authored: Wed Jul 30 22:40:57 2014 -0700
Committer: Josh Rosen <jo...@apache.org>
Committed: Wed Jul 30 22:40:57 2014 -0700

----------------------------------------------------------------------
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 17 +++++++++++++++++
 .../test/java/org/apache/spark/JavaAPISuite.java   | 17 +++++++++++++++++
 .../test/scala/org/apache/spark/rdd/RDDSuite.scala |  8 ++++++++
 3 files changed, 42 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4fb25935/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 74ac970..e1c49e3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1236,6 +1236,23 @@ abstract class RDD[T: ClassTag](
   /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
+  /**
+   * Private API for changing an RDD's ClassTag.
+   * Used for internal Java <-> Scala API compatibility.
+   */
+  private[spark] def retag(cls: Class[T]): RDD[T] = {
+    val classTag: ClassTag[T] = ClassTag.apply(cls)
+    this.retag(classTag)
+  }
+
+  /**
+   * Private API for changing an RDD's ClassTag.
+   * Used for internal Java <-> Scala API compatibility.
+   */
+  private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = {
+    this.mapPartitions(identity, preservesPartitioning = true)(classTag)
+  }
+
   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
   @transient private var doCheckpointCalled = false
 

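An implementation aside (explanatory sketch, not part of the patch): mapPartitions(identity, preservesPartitioning = true) builds a new RDD[T] that carries the supplied ClassTag, leaves every partition's contents untouched, and keeps the original partitioner. A rough standalone equivalent, with the helper name retagged being hypothetical:

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Hypothetical helper mirroring retag() from outside the RDD class:
    def retagged[T](prev: RDD[T], cls: Class[T]): RDD[T] = {
      implicit val tag: ClassTag[T] = ClassTag(cls)
      // identity leaves the elements unchanged; only the ClassTag (and
      // thus the array type that collect() allocates) differs.
      prev.mapPartitions(identity, preservesPartitioning = true)
    }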
http://git-wip-us.apache.org/repos/asf/spark/blob/4fb25935/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e8bd65f..fab64a5 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1245,4 +1245,21 @@ public class JavaAPISuite implements Serializable {
     Assert.assertTrue(worExactCounts.get(0) == 2);
     Assert.assertTrue(worExactCounts.get(1) == 4);
   }
+
+  private static class SomeCustomClass implements Serializable {
+    public SomeCustomClass() {
+      // Intentionally left blank
+    }
+  }
+
+  @Test
+  public void collectUnderlyingScalaRDD() {
+    List<SomeCustomClass> data = new ArrayList<SomeCustomClass>();
+    for (int i = 0; i < 100; i++) {
+      data.add(new SomeCustomClass());
+    }
+    JavaRDD<SomeCustomClass> rdd = sc.parallelize(data);
+    SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
+    Assert.assertEquals(data.size(), collected.length);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4fb25935/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ae6e525..b31e3a0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.scalatest.FunSuite
@@ -26,6 +27,7 @@ import org.apache.spark._
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.Utils
 
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
 
 class RDDSuite extends FunSuite with SharedSparkContext {
@@ -718,6 +720,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ids.length === n)
   }
 
+  test("retag with implicit ClassTag") {
+    val jsc: JavaSparkContext = new JavaSparkContext(sc)
+    val jrdd: JavaRDD[String] = jsc.parallelize(Seq("A", "B", "C").asJava)
+    jrdd.rdd.retag.collect()
+  }
+
   test("getNarrowAncestors") {
     val rdd1 = sc.parallelize(1 to 100, 4)
     val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1)