You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/06/30 08:27:40 UTC

git commit: SPARK-897: preemptively serialize closures

Repository: spark
Updated Branches:
  refs/heads/master 66135a341 -> a484030da


SPARK-897: preemptively serialize closures

These commits cause `ClosureCleaner.clean` to attempt to serialize the cleaned closure with the default closure serializer and throw a `SparkException` if doing so fails.  This behavior is enabled by default but can be disabled at individual call sites of `SparkContext.clean`.

Commit 98e01ae8 fixes some no-op assertions in `GraphSuite` that this work exposed; I'm happy to put that in a separate PR if that would be more appropriate.

Author: William Benton <wi...@redhat.com>

Closes #143 from willb/spark-897 and squashes the following commits:

bceab8a [William Benton] Commented DStream corner cases for serializability checking.
64d04d2 [William Benton] FailureSuite now checks both messages and causes.
3b3f74a [William Benton] Stylistic and doc cleanups.
b215dea [William Benton] Fixed spurious failures in ImplicitOrderingSuite
be1ecd6 [William Benton] Don't check serializability of DStream transforms.
abe816b [William Benton] Make proactive serializability checking optional.
5bfff24 [William Benton] Adds proactive closure-serializability checking
ed2ccf0 [William Benton] Test cases for SPARK-897.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a484030d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a484030d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a484030d

Branch: refs/heads/master
Commit: a484030dae9d0d7e4b97cc6307e9e928c07490dc
Parents: 66135a3
Author: William Benton <wi...@redhat.com>
Authored: Sun Jun 29 23:27:34 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Sun Jun 29 23:27:34 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/SparkContext.scala   | 12 ++-
 .../org/apache/spark/util/ClosureCleaner.scala  | 16 +++-
 .../scala/org/apache/spark/FailureSuite.scala   | 14 ++-
 .../apache/spark/ImplicitOrderingSuite.scala    | 75 +++++++++++-----
 .../ProactiveClosureSerializationSuite.scala    | 90 ++++++++++++++++++++
 .../spark/streaming/dstream/DStream.scala       | 25 ++++--
 6 files changed, 196 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a484030d/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f9476ff..8819e73 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1203,9 +1203,17 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
+   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+   * check whether <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
+   * if not. Returns <tt>f</tt> itself.
+   *
+   * @param f the closure to clean
+   * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
+   * @throws SparkException if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
+   *   serializable
    */
-  private[spark] def clean[F <: AnyRef](f: F): F = {
-    ClosureCleaner.clean(f)
+  private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
+    ClosureCleaner.clean(f, checkSerializable)
     f
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a484030d/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 4916d9b..e3f52f6 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkEnv, SparkException}
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  def clean(func: AnyRef) {
+  def clean(func: AnyRef, checkSerializable: Boolean = true) {
     // TODO: cache outerClasses / innerClasses / accessedFields
     val outerClasses = getOuterClasses(func)
     val innerClasses = getInnerClasses(func)
@@ -153,6 +153,18 @@ private[spark] object ClosureCleaner extends Logging {
       field.setAccessible(true)
       field.set(func, outer)
     }
+    // Optionally fail fast if the cleaned closure cannot be serialized.
+    if (checkSerializable) {
+      ensureSerializable(func)
+    }
+  }
+
+  private def ensureSerializable(func: AnyRef) {
+    val env = SparkEnv.get  // may be null if no SparkEnv is set: nothing to check against
+    if (env != null) {
+      try env.closureSerializer.newInstance().serialize(func)
+      catch { case ex: Exception => throw new SparkException("Task not serializable", ex) }
+    }
   }
 
   private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {

http://git-wip-us.apache.org/repos/asf/spark/blob/a484030d/core/src/test/scala/org/apache/spark/FailureSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 12dbebc..e755d2e 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -22,6 +22,8 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.NonSerializable
 
+import java.io.NotSerializableException
+
 // Common state shared by FailureSuite-launched tasks. We use a global object
 // for this because any local variables used in the task closures will rightfully
 // be copied for each task, so there's no other way for them to share state.
@@ -102,7 +104,8 @@ class FailureSuite extends FunSuite with LocalSparkContext {
       results.collect()
     }
     assert(thrown.getClass === classOf[SparkException])
-    assert(thrown.getMessage.contains("NotSerializableException"))
+    assert(thrown.getMessage.contains("NotSerializableException") ||
+      Option(thrown.getCause).exists(_.getClass == classOf[NotSerializableException]))
 
     FailureSuiteState.clear()
   }
@@ -116,21 +119,24 @@ class FailureSuite extends FunSuite with LocalSparkContext {
       sc.parallelize(1 to 10, 2).map(x => a).count()
     }
     assert(thrown.getClass === classOf[SparkException])
-    assert(thrown.getMessage.contains("NotSerializableException"))
+    assert(thrown.getMessage.contains("NotSerializableException") ||
+      Option(thrown.getCause).exists(_.getClass == classOf[NotSerializableException]))
 
     // Non-serializable closure in an earlier stage
     val thrown1 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
     }
     assert(thrown1.getClass === classOf[SparkException])
-    assert(thrown1.getMessage.contains("NotSerializableException"))
+    assert(thrown1.getMessage.contains("NotSerializableException") ||
+      Option(thrown1.getCause).exists(_.getClass == classOf[NotSerializableException]))
 
     // Non-serializable closure in foreach function
     val thrown2 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).foreach(x => println(a))
     }
     assert(thrown2.getClass === classOf[SparkException])
-    assert(thrown2.getMessage.contains("NotSerializableException"))
+    assert(thrown2.getMessage.contains("NotSerializableException") ||
+      Option(thrown2.getCause).exists(_.getClass == classOf[NotSerializableException]))
 
     FailureSuiteState.clear()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a484030d/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
index 4bd8891..8e4a9e2 100644
--- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
@@ -19,9 +19,29 @@ package org.apache.spark
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 
 class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
+  // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
+  test("basic inference of Orderings") {
+    sc = new SparkContext("local", "test")
+    val rdd = sc.parallelize(1 to 10)
+
+    // The expectations are built in the companion object so that the unserializable
+    // ScalaTest Engine won't be reachable from the closures they create.
+
+    // Infer orderings after basic maps to particular types
+    val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd)
+    basicMapExpectations.foreach { case (met, explain) => assert(met, explain) }
+
+    // Infer orderings for other RDD methods
+    val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd)
+    otherRDDMethodExpectations.foreach { case (met, explain) => assert(met, explain) }
+  }
+}
+
+private object ImplicitOrderingSuite {
   class NonOrderedClass {}
 
   class ComparableClass extends Comparable[ComparableClass] {
@@ -31,27 +51,36 @@ class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
   class OrderedClass extends Ordered[OrderedClass] {
     override def compare(o: OrderedClass): Int = ???
   }
-
-  // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
-  test("basic inference of Orderings"){
-    sc = new SparkContext("local", "test")
-    val rdd = sc.parallelize(1 to 10)
-
-    // Infer orderings after basic maps to particular types
-    assert(rdd.map(x => (x, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (1, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (x.toString, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (null, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty)
-    assert(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined)
-
-    // Infer orderings for other RDD methods
-    assert(rdd.groupBy(x => x).keyOrdering.isDefined)
-    assert(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty)
-    assert(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined)
-    assert(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined)
-    assert(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined)
-    assert(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined)
+
+  def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = {
+    List((rdd.map(x => (x, x)).keyOrdering.isDefined,
+           "rdd.map(x => (x, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (1, x)).keyOrdering.isDefined,
+           "rdd.map(x => (1, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (x.toString, x)).keyOrdering.isDefined,
+           "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (null, x)).keyOrdering.isDefined,
+           "rdd.map(x => (null, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty,
+           "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"),
+         (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined,
+           "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined,
+           "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined"))
   }
-}
+
+  def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = {
+    List((rdd.groupBy(x => x).keyOrdering.isDefined,
+           "rdd.groupBy(x => x).keyOrdering.isDefined"),
+         (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty,
+           "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"),
+         (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined,
+           "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"),
+         (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined,
+           "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"),
+         (rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined,
+           "rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined"),
+         (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined,
+           "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined"))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a484030d/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
new file mode 100644
index 0000000..5d15a68
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
+import org.apache.spark.SharedSparkContext
+
+/** A trivial (but unserializable) container for trivial functions. */
+class UnserializableClass {
+  def op[T](x: T) = x.toString
+
+  def pred[T](x: T) = x.toString.length % 2 == 0
+}
+
+class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
+
+  /** An RDD of the strings "0" to "999", paired with a fresh unserializable object. */
+  def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
+
+  test("throws expected serialization exceptions on actions") {
+    val (data, uc) = fixture
+
+    val ex = intercept[SparkException] {
+      data.map(uc.op(_)).count()
+    }
+
+    assert(ex.getMessage.contains("Task not serializable"))
+  }
+
+  // There is probably a cleaner way to eliminate boilerplate here, but we're
+  // iterating over a map from transformation names to functions that perform that
+  // transformation on a given RDD, creating one test case for each
+
+  for (transformation <-
+      Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
+          "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
+          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
+          "mapPartitionsWithContext" -> xmapPartitionsWithContext _,
+          "filterWith" -> xfilterWith _)) {
+    val (name, xf) = transformation
+
+    test(s"$name transformations throw proactive serialization exceptions") {
+      val (data, uc) = fixture
+
+      val ex = intercept[SparkException] {
+        xf(data, uc)
+      }
+
+      assert(ex.getMessage.contains("Task not serializable"),
+        s"RDD.$name doesn't proactively throw NotSerializableException")
+    }
+  }
+
+  // Each helper applies one RDD transformation through a closure that captures `uc`.
+  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.map(y => uc.op(y))
+  private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapWith(index => index.toString)((elem, a) => elem + uc.op(a))
+  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.flatMap(y => Seq(uc.op(y)))
+  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filter(y => uc.pred(y))
+  private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filterWith(index => index.toString)((_, a) => uc.pred(a))
+  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitions(_.map(y => uc.op(y)))
+  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
+  private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a484030d/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 4709a62..e05db23 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -532,7 +532,10 @@ abstract class DStream[T: ClassTag] (
    * 'this' DStream will be registered as an output stream and therefore materialized.
    */
   def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
-    new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
+    // Skip the proactive serializability check: the closure can reach this DStream through
+    // its outer object, and DStreams can't be serialized with closures.
+    val cleanedF = context.sparkContext.clean(foreachFunc, checkSerializable = false)
+    new ForEachDStream(this, cleanedF).register()
   }
 
   /**
@@ -540,7 +543,10 @@ abstract class DStream[T: ClassTag] (
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
-    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
+    // Skip the proactive serializability check: the closure can reach this DStream through
+    // its outer object, and DStreams can't be serialized with closures.
+    val cleanedF = context.sparkContext.clean(transformFunc, checkSerializable = false)
+    transform((r: RDD[T], t: Time) => cleanedF(r))
   }
 
   /**
@@ -548,7 +554,10 @@ abstract class DStream[T: ClassTag] (
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
-    val cleanedF = context.sparkContext.clean(transformFunc)
+    // Skip the proactive serializability check: the closure can reach this DStream
+    // through its outer object, and DStreams can't be serialized with closures, so
+    // the check can't be performed here.
+    val cleanedF = context.sparkContext.clean(transformFunc, checkSerializable = false)
     val realTransformFunc =  (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 1)
       cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -563,7 +572,10 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc)
+    // Skip the proactive serializability check: the closure can reach this DStream
+    // through its outer object, and DStreams can't be serialized with closures, so
+    // the check can't be performed here.
+    val cleanedF = ssc.sparkContext.clean(transformFunc, checkSerializable = false)
     transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
   }
 
@@ -574,7 +586,10 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc)
+    // Skip the proactive serializability check: the closure can reach this DStream
+    // through its outer object, and DStreams can't be serialized with closures, so
+    // the check can't be performed here.
+    val cleanedF = ssc.sparkContext.clean(transformFunc, checkSerializable = false)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 2)
       val rdd1 = rdds(0).asInstanceOf[RDD[T]]