Posted to commits@spark.apache.org by me...@apache.org on 2014/11/22 00:02:38 UTC

spark git commit: [SPARK-4531] [MLlib] cache serialized java object

Repository: spark
Updated Branches:
  refs/heads/master a81918c5a -> ce95bd8e1


[SPARK-4531] [MLlib] cache serialized java object

Pyrolite is pretty slow compared to the ad hoc serializer in 1.1, which causes a significant performance regression in 1.2, because we cache the serialized Python objects in the JVM and deserialize them into Java objects at each step.

This PR changes the caching to store the deserialized JavaRDD instead of the PythonRDD, avoiding the Pyrolite deserialization on every iteration. It should have similar memory usage as before, but be much faster.
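
For illustration, a minimal sketch of the JVM-side caching pattern this change applies (trainWithCachedInput and its train argument are hypothetical names; the real call sites are in PythonMLLibAPI.scala below). The deserialized JavaRDD is persisted before the iterative algorithm runs and unpersisted in a finally block, so the cache is released even when training throws:

    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // Hypothetical helper showing the pattern: cache the deserialized Java
    // objects (not the pickled bytes) across iterations, then always release
    // the cache when training finishes.
    def trainWithCachedInput[M](data: JavaRDD[LabeledPoint])(train: RDD[LabeledPoint] => M): M = {
      try {
        // MEMORY_AND_DISK stores deserialized objects, spilling to disk if needed.
        train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
      } finally {
        // Non-blocking unpersist: don't wait for block removal before returning.
        data.rdd.unpersist(blocking = false)
      }
    }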

Author: Davies Liu <da...@databricks.com>

Closes #3397 from davies/cache and squashes the following commits:

7f6e6ce [Davies Liu] Update -> Updater
4b52edd [Davies Liu] using named argument
63b984e [Davies Liu] fix
7da0332 [Davies Liu] add unpersist()
dff33e1 [Davies Liu] address comments
c2bdfc2 [Davies Liu] refactor
d572f00 [Davies Liu] Merge branch 'master' into cache
f1063e1 [Davies Liu] cache serialized java object


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce95bd8e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce95bd8e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce95bd8e

Branch: refs/heads/master
Commit: ce95bd8e130b2c7688b94be40683bdd90d86012d
Parents: a81918c
Author: Davies Liu <da...@databricks.com>
Authored: Fri Nov 21 15:02:31 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Nov 21 15:02:31 2014 -0800

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 110 +++++++++----------
 .../apache/spark/mllib/clustering/KMeans.scala  |  13 +--
 .../regression/GeneralizedLinearAlgorithm.scala |  13 +--
 python/pyspark/mllib/clustering.py              |   8 +-
 python/pyspark/mllib/common.py                  |   4 +-
 python/pyspark/mllib/recommendation.py          |   4 +-
 python/pyspark/mllib/regression.py              |   5 +-
 7 files changed, 64 insertions(+), 93 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index b6f7618..f04df1c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable {
       learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
       data: JavaRDD[LabeledPoint],
       initialWeights: Vector): JList[Object] = {
-    // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-    learner.disableUncachedWarning()
-    val model = learner.run(data.rdd, initialWeights)
-    List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+    try {
+      val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
+      List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
+  }
+
+  /**
+   * Return the Updater from string
+   */
+  def getUpdaterFromString(regType: String): Updater = {
+    if (regType == "l2") {
+      new SquaredL2Updater
+    } else if (regType == "l1") {
+      new L1Updater
+    } else if (regType == null || regType == "none") {
+      new SimpleUpdater
+    } else {
+      throw new IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+    }
   }
 
   /**
@@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setStepSize(stepSize)
       .setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      lrAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      lrAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      lrAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-        throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-          + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       lrAlg,
       data,
@@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setStepSize(stepSize)
       .setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      SVMAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      SVMAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       SVMAlg,
       data,
@@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setStepSize(stepSize)
       .setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      LogRegAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       LogRegAlg,
       data,
@@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setNumCorrections(corrections)
       .setConvergenceTol(tolerance)
-    if (regType == "l2") {
-      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      LogRegAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       LogRegAlg,
       data,
@@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable {
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
-      // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-      .disableUncachedWarning()
-    kMeansAlg.run(data.rdd)
+    try {
+      kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
   }
 
   /**
@@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable {
       numPartitions: Int,
       numIterations: Int,
       seed: Long): Word2VecModelWrapper = {
-    val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
     val word2vec = new Word2Vec()
       .setVectorSize(vectorSize)
       .setLearningRate(learningRate)
       .setNumPartitions(numPartitions)
       .setNumIterations(numIterations)
       .setSeed(seed)
-    val model = word2vec.fit(data)
-    data.unpersist()
-    new Word2VecModelWrapper(model)
+    try {
+      val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
+      new Word2VecModelWrapper(model)
+    } finally {
+      dataJRDD.rdd.unpersist(blocking = false)
+    }
   }
 
   private[python] class Word2VecModelWrapper(model: Word2VecModel) {
@@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable {
       categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
       minInstancesPerNode = minInstancesPerNode,
       minInfoGain = minInfoGain)
-
-    DecisionTree.train(data.rdd, strategy)
+    try {
+      DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy)
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
   }
 
   /**
@@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable {
       numClassesForClassification = numClasses,
       maxBins = maxBins,
       categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
-    if (algo == Algo.Classification) {
-      RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
-    } else {
-      RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+    val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+    try {
+      if (algo == Algo.Classification) {
+        RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
+      } else {
+        RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
+      }
+    } finally {
+      cached.unpersist(blocking = false)
     }
   }
 
@@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable {
     def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
       if (obj == this) {
         out.write(Opcodes.GLOBAL)
-        out.write((module + "\n" + name + "\n").getBytes())
+        out.write((module + "\n" + name + "\n").getBytes)
       } else {
         pickler.save(this)  // it will be memorized by Pickler
         saveState(obj, out, pickler)

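As context for the refactor above: the four duplicated regType if/else chains collapse into the new getUpdaterFromString helper, so each trainer configures its optimizer in one line. A hedged usage sketch (alg stands in for any of the trainers):

    // "l2" -> SquaredL2Updater, "l1" -> L1Updater, null or "none" -> SimpleUpdater;
    // any other value throws IllegalArgumentException.
    alg.optimizer.setUpdater(getUpdaterFromString(regType))
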
http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 7443f23..34ea0de 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
     this
   }
 
-  /** Whether a warning should be logged if the input RDD is uncached. */
-  private var warnOnUncachedInput = true
-
-  /** Disable warnings about uncached input. */
-  private[spark] def disableUncachedWarning(): this.type = {
-    warnOnUncachedInput = false
-    this
-  }  
-
   /**
    * Train a K-means model on the given set of points; `data` should be cached for high
    * performance, because this is an iterative algorithm.
    */
   def run(data: RDD[Vector]): KMeansModel = {
 
-    if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }
@@ -143,7 +134,7 @@ class KMeans private (
     norms.unpersist()
 
     // Warn at the end of the run as well, for increased visibility.
-    if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data was not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 00dfc86..0287f04 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     this
   }
 
-  /** Whether a warning should be logged if the input RDD is uncached. */
-  private var warnOnUncachedInput = true
-
-  /** Disable warnings about uncached input. */
-  private[spark] def disableUncachedWarning(): this.type = {
-    warnOnUncachedInput = false
-    this
-  }
-
   /**
    * Run the algorithm with the configured parameters on an input
    * RDD of LabeledPoint entries.
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    */
   def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
 
-    if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+    if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }
@@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     }
 
     // Warn at the end of the run as well, for increased visibility.
-    if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+    if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data was not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index fe4c4cc..e2492ee 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -16,7 +16,7 @@
 #
 
 from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 
 __all__ = ['KMeansModel', 'KMeans']
@@ -80,10 +80,8 @@ class KMeans(object):
     @classmethod
     def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
         """Train a k-means clustering model."""
-        # cache serialized data to avoid objects over head in JVM
-        jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True)
-        model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs,
-                              initializationMode)
+        model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
+                              runs, initializationMode)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/python/pyspark/mllib/common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index c6149fe..33c49e2 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -54,15 +54,13 @@ _picklable_classes = [
 
 
 # this will call the MLlib version of pythonToJava()
-def _to_java_object_rdd(rdd, cache=False):
+def _to_java_object_rdd(rdd):
     """ Return an JavaRDD of Object by unpickling
 
     It will convert each Python object into Java object by Pyrolite, whenever the
     RDD is serialized in batch or not.
     """
     rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
-    if cache:
-        rdd.cache()
     return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/python/pyspark/mllib/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 2bcbf2a..97ec74e 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,7 @@ from collections import namedtuple
 
 from pyspark import SparkContext
 from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
 
@@ -110,7 +110,7 @@ class ALS(object):
                 ratings = ratings.map(lambda x: Rating(*x))
             else:
                 raise ValueError("rating should be RDD of Rating or tuple/list")
-        return _to_java_object_rdd(ratings, True)
+        return ratings
 
     @classmethod
     def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,

http://git-wip-us.apache.org/repos/asf/spark/blob/ce95bd8e/python/pyspark/mllib/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index f4f5e61..2100601 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,7 +18,7 @@
 import numpy as np
 from numpy import array
 
-from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 
 __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
@@ -129,8 +129,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
     if not isinstance(first, LabeledPoint):
         raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
     initial_weights = initial_weights or [0.0] * len(data.first().features)
-    weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
-                                    _convert_to_vector(initial_weights))
+    weights, intercept = train_func(data, _convert_to_vector(initial_weights))
     return modelClass(weights, intercept)
 
 

