Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/27 07:56:11 UTC

git commit: [SPARK-2361][MLLIB] Use broadcast instead of serializing data directly into task closure

Repository: spark
Updated Branches:
  refs/heads/master b547f69bd -> aaf2b735f


[SPARK-2361][MLLIB] Use broadcast instead of serializing data directly into task closure

We saw task serialization problems with large feature dimensions, which can be avoided if we don't serialize data directly into the task closure but use broadcast variables instead. This PR uses broadcast in both training and prediction and adds tests to make sure the task size stays small.
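
For readers unfamiliar with the pattern, the change is the same throughout: instead of letting a large read-only object (model weights, cluster centers) be captured into every task closure, it is broadcast once and dereferenced inside the tasks. Below is a minimal, self-contained sketch of that pattern; the object name, weight array, and sizes are hypothetical illustrations, not code from this commit.

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastPatternSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("broadcast-sketch"))
    val data = sc.parallelize(Seq(Array(1.0, 2.0), Array(3.0, 4.0)), 2)

    // Hypothetical large read-only state, e.g. a dense weight vector.
    val weights = Array.fill(200000)(0.5)

    // Before: referencing `weights` directly in a closure ships a copy with every task:
    //   data.map(v => v.zip(weights).map { case (x, w) => x * w }.sum)

    // After: broadcast once; each executor fetches a single copy and task closures stay small.
    val bcWeights = sc.broadcast(weights)
    val dotProducts = data.mapPartitions { iter =>
      val w = bcWeights.value // dereference once per partition
      iter.map(v => v.zip(w).map { case (x, wi) => x * wi }.sum)
    }
    dotProducts.collect().foreach(println)
    sc.stop()
  }
}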

Author: Xiangrui Meng <me...@databricks.com>

Closes #1427 from mengxr/broadcast-new and squashes the following commits:

b9a1228 [Xiangrui Meng] style update
b97c184 [Xiangrui Meng] minimal change to LBFGS
9ebadcc [Xiangrui Meng] add task size test to RowMatrix
9427bf0 [Xiangrui Meng] add task size tests to linear methods
e0a5cf2 [Xiangrui Meng] add task size test to GD
28a8411 [Xiangrui Meng] add test for NaiveBayes
380778c [Xiangrui Meng] update KMeans test
bccab92 [Xiangrui Meng] add task size test to LBFGS
02103ba [Xiangrui Meng] remove print
e73d68e [Xiangrui Meng] update tests for k-means
174cb15 [Xiangrui Meng] use local-cluster for test with a small akka.frameSize
1928a5a [Xiangrui Meng] add test for KMeans task size
e00c2da [Xiangrui Meng] use broadcast in GD, KMeans
010d076 [Xiangrui Meng] modify NaiveBayesModel and GLM to use broadcast


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/aaf2b735
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/aaf2b735
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/aaf2b735

Branch: refs/heads/master
Commit: aaf2b735fddbebccd28012006ee4647af3b3624f
Parents: b547f69
Author: Xiangrui Meng <me...@databricks.com>
Authored: Sat Jul 26 22:56:07 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Sat Jul 26 22:56:07 2014 -0700

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala |  8 ++-
 .../apache/spark/mllib/clustering/KMeans.scala  | 19 +++--
 .../spark/mllib/clustering/KMeansModel.scala    |  6 +-
 .../mllib/optimization/GradientDescent.scala    |  6 +-
 .../apache/spark/mllib/optimization/LBFGS.scala |  7 +-
 .../regression/GeneralizedLinearAlgorithm.scala |  7 +-
 .../JavaLogisticRegressionSuite.java            |  2 -
 .../LogisticRegressionSuite.scala               | 18 ++++-
 .../mllib/classification/NaiveBayesSuite.scala  | 20 +++++-
 .../spark/mllib/classification/SVMSuite.scala   | 25 +++++--
 .../spark/mllib/clustering/KMeansSuite.scala    | 75 +++++++++++++-------
 .../linalg/distributed/RowMatrixSuite.scala     | 29 +++++++-
 .../optimization/GradientDescentSuite.scala     | 34 +++++++--
 .../spark/mllib/optimization/LBFGSSuite.scala   | 30 ++++++--
 .../spark/mllib/regression/LassoSuite.scala     | 21 +++++-
 .../regression/LinearRegressionSuite.scala      | 21 +++++-
 .../mllib/regression/RidgeRegressionSuite.scala | 23 +++++-
 .../mllib/util/LocalClusterSparkContext.scala   | 42 +++++++++++
 .../spark/mllib/util/LocalSparkContext.scala    |  7 +-
 19 files changed, 330 insertions(+), 70 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b6e0c4a..6c7be0a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] (
     }
   }
 
-  override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
+  override def predict(testData: RDD[Vector]): RDD[Double] = {
+    val bcModel = testData.context.broadcast(this)
+    testData.mapPartitions { iter =>
+      val model = bcModel.value
+      iter.map(model.predict)
+    }
+  }
 
   override def predict(testData: Vector): Double = {
     labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index de22fbb..db425d8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -165,18 +165,21 @@ class KMeans private (
       val activeCenters = activeRuns.map(r => centers(r)).toArray
       val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
 
+      val bcActiveCenters = sc.broadcast(activeCenters)
+
       // Find the sum and count of points mapping to each center
       val totalContribs = data.mapPartitions { points =>
-        val runs = activeCenters.length
-        val k = activeCenters(0).length
-        val dims = activeCenters(0)(0).vector.length
+        val thisActiveCenters = bcActiveCenters.value
+        val runs = thisActiveCenters.length
+        val k = thisActiveCenters(0).length
+        val dims = thisActiveCenters(0)(0).vector.length
 
         val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
         val counts = Array.fill(runs, k)(0L)
 
         points.foreach { point =>
           (0 until runs).foreach { i =>
-            val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point)
+            val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
             costAccums(i) += cost
             sums(i)(bestCenter) += point.vector
             counts(i)(bestCenter) += 1
@@ -264,16 +267,17 @@ class KMeans private (
     // to their squared distance from that run's current centers
     var step = 0
     while (step < initializationSteps) {
+      val bcCenters = data.context.broadcast(centers)
       val sumCosts = data.flatMap { point =>
         (0 until runs).map { r =>
-          (r, KMeans.pointCost(centers(r), point))
+          (r, KMeans.pointCost(bcCenters.value(r), point))
         }
       }.reduceByKey(_ + _).collectAsMap()
       val chosen = data.mapPartitionsWithIndex { (index, points) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
         points.flatMap { p =>
           (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+            rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
           }.map((_, p))
         }
       }.collect()
@@ -286,9 +290,10 @@ class KMeans private (
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
+    val bcCenters = data.context.broadcast(centers)
     val weightMap = data.flatMap { p =>
       (0 until runs).map { r =>
-        ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+        ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
       }
     }.reduceByKey(_ + _).collectAsMap()
     val finalCenters = (0 until runs).map { r =>

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index fba21ae..5823cb6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -38,7 +38,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
   /** Maps given points to their cluster indices. */
   def predict(points: RDD[Vector]): RDD[Int] = {
     val centersWithNorm = clusterCentersWithNorm
-    points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
+    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
   }
 
   /** Maps given points to their cluster indices. */
@@ -51,7 +52,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
    */
   def computeCost(data: RDD[Vector]): Double = {
     val centersWithNorm = clusterCentersWithNorm
-    data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
+    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
   }
 
   private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 7030eea..9fd760b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -163,6 +163,7 @@ object GradientDescent extends Logging {
 
     // Initialize weights as a column vector
     var weights = Vectors.dense(initialWeights.toArray)
+    val n = weights.size
 
     /**
      * For the first iteration, the regVal will be initialized as sum of weight squares
@@ -172,12 +173,13 @@ object GradientDescent extends Logging {
       weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
 
     for (i <- 1 to numIterations) {
+      val bcWeights = data.context.broadcast(weights)
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+        .aggregate((BDV.zeros[Double](n), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+            val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
             (grad, loss + l)
           },
           combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 7bbed9c..179cd4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -195,13 +195,14 @@ object LBFGS extends Logging {
 
     override def calculate(weights: BDV[Double]) = {
       // Have a local copy to avoid the serialization of CostFun object which is not serializable.
-      val localData = data
       val localGradient = gradient
+      val n = weights.length
+      val bcWeights = data.context.broadcast(weights)
 
-      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+      val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = localGradient.compute(
-              features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+              features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
             (grad, loss + l)
           },
           combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index fe41863..5485425 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -56,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weights
+    val bcWeights = testData.context.broadcast(localWeights)
     val localIntercept = intercept
-
-    testData.map(v => predictPoint(v, localWeights, localIntercept))
+    testData.mapPartitions { iter =>
+      val w = bcWeights.value
+      iter.map(v => predictPoint(v, w, localIntercept))
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index faa675b..862221d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -92,8 +92,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
         testRDD.rdd(), 100, 1.0, 1.0);
 
     int numAccurate = validatePrediction(validationData, model);
-      System.out.println(numAccurate);
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 44b757b..3f6ff85 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object LogisticRegressionSuite {
 
@@ -126,3 +126,19 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
 }
+
+class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = LogisticRegressionWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 516895d..06cdd04 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object NaiveBayesSuite {
 
@@ -96,3 +96,21 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
 }
+
+class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 10
+    val n = 200000
+    val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map { i =>
+        LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
+      }
+    }
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = NaiveBayes.train(examples)
+    val predictions = model.predict(examples.map(_.features))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 886c71d..65e5df5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -17,17 +17,16 @@
 
 package org.apache.spark.mllib.classification
 
-import scala.util.Random
 import scala.collection.JavaConversions._
-
-import org.scalatest.FunSuite
+import scala.util.Random
 
 import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object SVMSuite {
 
@@ -193,3 +192,19 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
   }
 }
+
+class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = SVMWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 76a3bdf..34bc453 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.mllib.clustering
 
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 class KMeansSuite extends FunSuite with LocalSparkContext {
 
-  import KMeans.{RANDOM, K_MEANS_PARALLEL}
+  import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
 
   test("single cluster") {
     val data = sc.parallelize(Array(
@@ -38,26 +40,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     // No matter how many runs or iterations we use, we should get one cluster,
     // centered at the mean of the points
 
-    var model = KMeans.train(data, k=1, maxIterations=1)
+    var model = KMeans.train(data, k = 1, maxIterations = 1)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=2)
+    model = KMeans.train(data, k = 1, maxIterations = 2)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=5)
+    model = KMeans.train(data, k = 1, maxIterations = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
     assert(model.clusterCenters.head === center)
 
     model = KMeans.train(
-      data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+      data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
     assert(model.clusterCenters.head === center)
   }
 
@@ -100,26 +102,27 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
 
     val center = Vectors.dense(1.0, 3.0, 4.0)
 
-    var model = KMeans.train(data, k=1, maxIterations=1)
+    var model = KMeans.train(data, k = 1, maxIterations = 1)
     assert(model.clusterCenters.size === 1)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=2)
+    model = KMeans.train(data, k = 1, maxIterations = 2)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=5)
+    model = KMeans.train(data, k = 1, maxIterations = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+      initializationMode = K_MEANS_PARALLEL)
     assert(model.clusterCenters.head === center)
   }
 
@@ -145,25 +148,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
 
     val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
 
-    var model = KMeans.train(data, k=1, maxIterations=1)
+    var model = KMeans.train(data, k = 1, maxIterations = 1)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=2)
+    model = KMeans.train(data, k = 1, maxIterations = 2)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=5)
+    model = KMeans.train(data, k = 1, maxIterations = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
     assert(model.clusterCenters.head === center)
 
-    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+    model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+      initializationMode = K_MEANS_PARALLEL)
     assert(model.clusterCenters.head === center)
 
     data.unpersist()
@@ -183,15 +187,15 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     // it will make at least five passes, and it will give non-zero probability to each
     // unselected point as long as it hasn't yet selected all of them
 
-    var model = KMeans.train(rdd, k=5, maxIterations=1)
+    var model = KMeans.train(rdd, k = 5, maxIterations = 1)
     assert(Set(model.clusterCenters: _*) === Set(points: _*))
 
     // Iterations of Lloyd's should not change the answer either
-    model = KMeans.train(rdd, k=5, maxIterations=10)
+    model = KMeans.train(rdd, k = 5, maxIterations = 10)
     assert(Set(model.clusterCenters: _*) === Set(points: _*))
 
     // Neither should more runs
-    model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
+    model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
     assert(Set(model.clusterCenters: _*) === Set(points: _*))
   }
 
@@ -220,3 +224,22 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     }
   }
 }
+
+class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble)))
+    }.cache()
+    for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) {
+      // If we serialize data directly in the task closure, the size of the serialized task would be
+      // greater than 1MB and hence Spark would throw an error.
+      val model = KMeans.train(points, 2, 2, 1, initMode)
+      val predictions = model.predict(points).collect()
+      val cost = model.computeCost(points)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index a961f89..325b817 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.mllib.linalg.distributed
 
-import org.scalatest.FunSuite
+import scala.util.Random
 
 import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
+import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 class RowMatrixSuite extends FunSuite with LocalSparkContext {
 
@@ -193,3 +194,27 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
     }
   }
 }
+
+class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  var mat: RowMatrix = _
+
+  override def beforeAll() {
+    super.beforeAll()
+    val m = 4
+    val n = 200000
+    val rows = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble())))
+    }
+    mat = new RowMatrix(rows)
+  }
+
+  test("task size should be small in svd") {
+    val svd = mat.computeSVD(1, computeU = true)
+  }
+
+  test("task size should be small in summarize") {
+    val summary = mat.computeColumnSummaryStatistics()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 951b4f7..dfb2eb7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.mllib.optimization
 
-import scala.util.Random
 import scala.collection.JavaConversions._
+import scala.util.Random
 
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import org.scalatest.{FunSuite, Matchers}
 
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object GradientDescentSuite {
 
@@ -46,7 +45,7 @@ object GradientDescentSuite {
     val rnd = new Random(seed)
     val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
 
-    val unifRand = new scala.util.Random(45)
+    val unifRand = new Random(45)
     val rLogis = (0 until nPoints).map { i =>
       val u = unifRand.nextDouble()
       math.log(u) - math.log(1.0-u)
@@ -144,3 +143,26 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
         "should be initialWeightsWithIntercept.")
   }
 }
+
+class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val (weights, loss) = GradientDescent.runMiniBatchSGD(
+      points,
+      new LogisticGradient,
+      new SquaredL2Updater,
+      0.1,
+      2,
+      1.0,
+      1.0,
+      Vectors.dense(new Array[Double](n)))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index fe7a903..ff41474 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.mllib.optimization
 
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import scala.util.Random
+
+import org.scalatest.{FunSuite, Matchers}
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
 
@@ -230,3 +231,24 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
       "The weight differences between LBFGS and GD should be within 2%.")
   }
 }
+
+class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small") {
+    val m = 10
+    val n = 200000
+    val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble))))
+    }.cache()
+    val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
+      .setNumCorrections(1)
+      .setConvergenceTol(1e-12)
+      .setMaxNumIterations(1)
+      .setRegParam(1.0)
+    val random = new Random(0)
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val weights = lbfgs.optimize(examples, Vectors.dense(Array.fill(n)(random.nextDouble)))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index bfa4295..7aa9642 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.mllib.regression
 
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+  LocalSparkContext}
 
 class LassoSuite extends FunSuite with LocalSparkContext {
 
@@ -113,3 +116,19 @@ class LassoSuite extends FunSuite with LocalSparkContext {
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
 }
+
+class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = LassoWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 7aaad7d..4f89112 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.mllib.regression
 
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+  LocalSparkContext}
 
 class LinearRegressionSuite extends FunSuite with LocalSparkContext {
 
@@ -122,3 +125,19 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
       sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
   }
 }
+
+class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = LinearRegressionWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67768e1..727bbd0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.mllib.regression
 
-import org.scalatest.FunSuite
+import scala.util.Random
 
 import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+  LocalSparkContext}
 
 class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
 
@@ -73,3 +76,19 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
       "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
   }
 }
+
+class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = RidgeRegressionWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
new file mode 100644
index 0000000..5e9101c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.{Suite, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext}
+
+trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    val conf = new SparkConf()
+      .setMaster("local-cluster[2, 1, 512]")
+      .setAppName("test-cluster")
+      .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
+    sc = new SparkContext(conf)
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    if (sc != null) {
+      sc.stop()
+    }
+    super.afterAll()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf2b735/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 0d4868f..7857d9e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -20,13 +20,16 @@ package org.apache.spark.mllib.util
 import org.scalatest.Suite
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
 
 trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
   @transient var sc: SparkContext = _
 
   override def beforeAll() {
-    sc = new SparkContext("local", "test")
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+    sc = new SparkContext(conf)
     super.beforeAll()
   }