Posted to commits@spark.apache.org by me...@apache.org on 2015/02/05 07:46:52 UTC

spark git commit: [SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes

Repository: spark
Updated Branches:
  refs/heads/master c23ac03c8 -> 975bcef46


[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes

This is a PR for Parquet-based model import/export.  Please see the design doc on [the JIRA](https://issues.apache.org/jira/browse/SPARK-4587).

Note: This includes only a subset of regression and classification models:
* NaiveBayes, SVM, LogisticRegression
* LinearRegression, RidgeRegression, Lasso

Follow-up PRs will cover other models.

Sketch of current contents:
* New traits: Saveable, Loader (see the usage sketch below)
* Implementations for some algorithms
* Also: Added a LogisticRegressionModel.getThreshold method (so that unit tests can check the threshold)
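
For reference, a minimal usage sketch of the new API (the model values and save path below are made up for illustration; the `save`/`load` calls and the metadata/data layout match the diff in this commit):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors

object SaveLoadExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("SaveLoadExample").setMaster("local"))

    // A toy binary model: 3 features, default threshold 0.5.
    val model = new LogisticRegressionModel(Vectors.dense(0.1, 0.2, 0.3), 0.5)

    // save() writes JSON metadata to <path>/metadata/ and Parquet data to <path>/data/.
    val path = "target/tmp/myLogRegModel"  // hypothetical path
    model.save(sc, path)

    // load() reads the metadata, checks (class, version), and rebuilds the model.
    val sameModel = LogisticRegressionModel.load(sc, path)
    assert(sameModel.weights == model.weights && sameModel.intercept == model.intercept)

    sc.stop()
  }
}
```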

CC: mengxr  selvinsource

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #4233 from jkbradley/ml-import-export and squashes the following commits:

87c4eb8 [Joseph K. Bradley] small cleanups
12d9059 [Joseph K. Bradley] Many cleanups after code review.  Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
b4ee064 [Joseph K. Bradley] Reorganized save/load for regression and classification.  Renamed concepts to Saveable, Loader
a34aef5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into ml-import-export
ee99228 [Joseph K. Bradley] scala style fix
79675d5 [Joseph K. Bradley] cleanups in LogisticRegression after rebasing after multinomial PR
d1e5882 [Joseph K. Bradley] organized imports
2935963 [Joseph K. Bradley] Added save/load and tests for most classification and regression models
c495dba [Joseph K. Bradley] made version for model import/export local to each model
1496852 [Joseph K. Bradley] Added save/load for NaiveBayes
8d46386 [Joseph K. Bradley] Added save/load to NaiveBayes
1577d70 [Joseph K. Bradley] fixed issues after rebasing on master (DataFrame patch)
64914a3 [Joseph K. Bradley] added getThreshold to SVMModel
b1fc5ec [Joseph K. Bradley] small cleanups
418ba1b [Joseph K. Bradley] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/975bcef4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/975bcef4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/975bcef4

Branch: refs/heads/master
Commit: 975bcef467b35586e5224171071355409f451d2d
Parents: c23ac03
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Wed Feb 4 22:46:48 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Feb 4 22:46:48 2015 -0800

----------------------------------------------------------------------
 .../classification/ClassificationModel.scala    |  20 +++
 .../classification/LogisticRegression.scala     |  67 ++++++++-
 .../spark/mllib/classification/NaiveBayes.scala |  87 +++++++++++-
 .../apache/spark/mllib/classification/SVM.scala |  51 ++++++-
 .../impl/GLMClassificationModel.scala           |  95 +++++++++++++
 .../apache/spark/mllib/regression/Lasso.scala   |  33 ++++-
 .../mllib/regression/LinearRegression.scala     |  35 ++++-
 .../mllib/regression/RegressionModel.scala      |  22 ++-
 .../mllib/regression/RidgeRegression.scala      |  38 ++++-
 .../regression/impl/GLMRegressionModel.scala    |  86 ++++++++++++
 .../mllib/tree/model/DecisionTreeModel.scala    |   1 -
 .../apache/spark/mllib/util/modelSaveLoad.scala | 139 +++++++++++++++++++
 .../LogisticRegressionSuite.scala               |  70 +++++++++-
 .../mllib/classification/NaiveBayesSuite.scala  |  40 +++++-
 .../spark/mllib/classification/SVMSuite.scala   |  36 +++++
 .../spark/mllib/regression/LassoSuite.scala     |  24 ++++
 .../regression/LinearRegressionSuite.scala      |  24 ++++
 .../mllib/regression/RidgeRegressionSuite.scala |  24 ++++
 18 files changed, 863 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index b7a1d90..348c1e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * :: Experimental ::
@@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable {
   def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
     predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
+
+private[mllib] object ClassificationModel {
+
+  /**
+   * Helper method for loading GLM classification model metadata.
+   *
+   * @param modelClass  String name for model class (used for error messages)
+   * @return (numFeatures, numClasses)
+   */
+  def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
+    metadata.select("numFeatures", "numClasses").take(1)(0) match {
+      case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
+      case _ => throw new Exception(s"$modelClass unable to load" +
+        s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index a469315..5c9feb6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.mllib.classification
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.linalg.{DenseVector, Vector}
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
 import org.apache.spark.rdd.RDD
 
+
 /**
  * Classification model trained using Multinomial/Binary Logistic Regression.
  *
@@ -42,7 +45,22 @@ class LogisticRegressionModel (
     override val intercept: Double,
     val numFeatures: Int,
     val numClasses: Int)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {
+
+  if (numClasses == 2) {
+    require(weights.size == numFeatures,
+      s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
+      s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
+  } else {
+    val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
+    val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
+    require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
+      s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
+      s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
+      s" or $weightsSizeWithIntercept (with intercept)," +
+      s" but was given weights of length ${weights.size}")
+  }
 
   def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
 
@@ -62,6 +80,13 @@ class LogisticRegressionModel (
 
   /**
    * :: Experimental ::
+   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+   */
+  @Experimental
+  def getThreshold: Option[Double] = threshold
+
+  /**
+   * :: Experimental ::
    * Clears the threshold so that `predict` will output raw prediction scores.
    */
   @Experimental
@@ -70,7 +95,9 @@ class LogisticRegressionModel (
     this
   }
 
-  override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+  override protected def predictPoint(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
       intercept: Double) = {
     require(dataMatrix.size == numFeatures)
 
@@ -126,6 +153,40 @@ class LogisticRegressionModel (
       bestClass.toDouble
     }
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
+      numFeatures, numClasses, weights, intercept, threshold)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
+
+  override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    // Hard-code class name string in case it changes in the future
+    val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val (numFeatures, numClasses) =
+          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+        // numFeatures, numClasses, weights are checked in model initialization
+        val model =
+          new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
+        data.threshold match {
+          case Some(t) => model.setThreshold(t)
+          case None => model.clearThreshold()
+        }
+        model
+      case _ => throw new Exception(
+        s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index a967df8..4bafd49 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 
-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, Logging}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
 
 /**
  * Model for Naive Bayes Classifiers.
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD
 class NaiveBayesModel private[mllib] (
     val labels: Array[Double],
     val pi: Array[Double],
-    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
 
   private val brzPi = new BDV[Double](pi)
   private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -65,6 +67,85 @@ class NaiveBayesModel private[mllib] (
   override def predict(testData: Vector): Double = {
     labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
+    NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object NaiveBayesModel extends Loader[NaiveBayesModel] {
+
+  import Loader._
+
+  private object SaveLoadV1_0 {
+
+    def thisFormatVersion = "1.0"
+
+    /** Hard-code class name string in case it changes in the future */
+    def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
+
+    /** Model data for model import/export */
+    case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
+
+    def save(sc: SparkContext, path: String, data: Data): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext._
+
+      // Create JSON metadata.
+      val metadataRDD =
+        sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
+          .toDataFrame("class", "version", "numFeatures", "numClasses")
+      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+
+      // Create Parquet data.
+      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+      dataRDD.saveAsParquetFile(dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): NaiveBayesModel = {
+      val sqlContext = new SQLContext(sc)
+      // Load Parquet data.
+      val dataRDD = sqlContext.parquetFile(dataPath(path))
+      // Check schema explicitly since erasure makes it hard to use match-case for checking.
+      checkSchema[Data](dataRDD.schema)
+      val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
+      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+      val data = dataArray(0)
+      val labels = data.getAs[Seq[Double]](0).toArray
+      val pi = data.getAs[Seq[Double]](1).toArray
+      val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+      new NaiveBayesModel(labels, pi, theta)
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): NaiveBayesModel = {
+    val (loadedClassName, version, metadata) = loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val (numFeatures, numClasses) =
+          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val model = SaveLoadV1_0.load(sc, path)
+        assert(model.pi.size == numClasses,
+          s"NaiveBayesModel.load expected $numClasses classes," +
+          s" but class priors vector pi had ${model.pi.size} elements")
+        assert(model.theta.size == numClasses,
+          s"NaiveBayesModel.load expected $numClasses classes," +
+            s" but class conditionals array theta had ${model.theta.size} elements")
+        assert(model.theta.forall(_.size == numFeatures),
+          s"NaiveBayesModel.load expected $numFeatures features," +
+          s" but class conditionals array theta had elements of size:" +
+          s" ${model.theta.map(_.size).mkString(",")}")
+        model
+      case _ => throw new Exception(
+        s"NaiveBayesModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index dd514ff..24d31e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,13 +17,16 @@
 
 package org.apache.spark.mllib.classification
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
 import org.apache.spark.rdd.RDD
 
+
 /**
  * Model for Support Vector Machines (SVMs).
  *
@@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD
 class SVMModel (
     override val weights: Vector,
     override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {
 
   private var threshold: Option[Double] = Some(0.0)
 
@@ -51,6 +55,13 @@ class SVMModel (
 
   /**
    * :: Experimental ::
+   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+   */
+  @Experimental
+  def getThreshold: Option[Double] = threshold
+
+  /**
+   * :: Experimental ::
    * Clears the threshold so that `predict` will output raw prediction scores.
    */
   @Experimental
@@ -69,6 +80,42 @@ class SVMModel (
       case None => margin
     }
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
+      numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object SVMModel extends Loader[SVMModel] {
+
+  override def load(sc: SparkContext, path: String): SVMModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    // Hard-code class name string in case it changes in the future
+    val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val (numFeatures, numClasses) =
+          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+        val model = new SVMModel(data.weights, data.intercept)
+        assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
+          s" was given non-matching weights vector of size ${model.weights.size}")
+        assert(numClasses == 2,
+          s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes")
+        data.threshold match {
+          case Some(t) => model.setThreshold(t)
+          case None => model.clearThreshold()
+        }
+        model
+      case _ => throw new Exception(
+        s"SVMModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
new file mode 100644
index 0000000..b60c0cd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.classification.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Helper class for import/export of GLM classification models.
+ */
+private[classification] object GLMClassificationModel {
+
+  object SaveLoadV1_0 {
+
+    def thisFormatVersion = "1.0"
+
+    /** Model data for import/export */
+    case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
+
+    /**
+     * Helper method for saving GLM classification model metadata and data.
+     * @param modelClass  String name for model class, to be saved with metadata
+     * @param numClasses  Number of classes label can take, to be saved with metadata
+     */
+    def save(
+        sc: SparkContext,
+        path: String,
+        modelClass: String,
+        numFeatures: Int,
+        numClasses: Int,
+        weights: Vector,
+        intercept: Double,
+        threshold: Option[Double]): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext._
+
+      // Create JSON metadata.
+      val metadataRDD =
+        sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
+          .toDataFrame("class", "version", "numFeatures", "numClasses")
+      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val data = Data(weights, intercept, threshold)
+      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+      // TODO: repartition with 1 partition after SPARK-5532 gets fixed
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    /**
+     * Helper method for loading GLM classification model data.
+     *
+     * NOTE: Callers of this method should check numClasses, numFeatures on their own.
+     *
+     * @param modelClass  String name for model class (used for error messages)
+     */
+    def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
+      val datapath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      val dataRDD = sqlContext.parquetFile(datapath)
+      val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
+      assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
+      val data = dataArray(0)
+      assert(data.size == 3, s"Unable to load $modelClass data from: $datapath")
+      val (weights, intercept) = data match {
+        case Row(weights: Vector, intercept: Double, _) =>
+          (weights, intercept)
+      }
+      val threshold = if (data.isNullAt(2)) {
+        None
+      } else {
+        Some(data.getDouble(2))
+      }
+      Data(weights, intercept, threshold)
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 8ecd5c6..1159e59 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Saveable, Loader}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -32,7 +34,7 @@ class LassoModel (
     override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable {
+  with RegressionModel with Serializable with Saveable {
 
   override protected def predictPoint(
       dataMatrix: Vector,
@@ -40,12 +42,37 @@ class LassoModel (
       intercept: Double): Double = {
     weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object LassoModel extends Loader[LassoModel] {
+
+  override def load(sc: SparkContext, path: String): LassoModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    // Hard-code class name string in case it changes in the future
+    val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+        new LassoModel(data.weights, data.intercept)
+      case _ => throw new Exception(
+        s"LassoModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
 
 /**
  * Train a regression model with L1-regularization using Stochastic Gradient Descent.
  * This solves the l1-regularized least squares regression formulation
- *          f(weights) = 1/2n ||A weights-y||^2  + regParam ||weights||_1
+ *          f(weights) = 1/2n ||A weights-y||^2^  + regParam ||weights||_1
  * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
  * its corresponding right hand side label y.
  * See also the documentation for the precise formulation.

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 81b6598..0136dcf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.rdd.RDD
 
 /**
  * Regression model trained using LinearRegression.
@@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._
 class LinearRegressionModel (
     override val weights: Vector,
     override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
+  with Saveable {
 
   override protected def predictPoint(
       dataMatrix: Vector,
@@ -38,12 +42,37 @@ class LinearRegressionModel (
       intercept: Double): Double = {
     weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object LinearRegressionModel extends Loader[LinearRegressionModel] {
+
+  override def load(sc: SparkContext, path: String): LinearRegressionModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    // Hard-code class name string in case it changes in the future
+    val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+        new LinearRegressionModel(data.weights, data.intercept)
+      case _ => throw new Exception(
+        s"LinearRegressionModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
 
 /**
  * Train a linear regression model with no regularization using Stochastic Gradient Descent.
  * This solves the least squares regression formulation
- *              f(weights) = 1/n ||A weights-y||^2
+ *              f(weights) = 1/n ||A weights-y||^2^
  * (which is the mean squared error).
  * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
  * its corresponding right hand side label y.

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 64b02f7..843e59b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -19,8 +19,10 @@ package org.apache.spark.mllib.regression
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
 
 @Experimental
 trait RegressionModel extends Serializable {
@@ -48,3 +50,21 @@ trait RegressionModel extends Serializable {
   def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
     predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
+
+private[mllib] object RegressionModel {
+
+  /**
+   * Helper method for loading GLM regression model metadata.
+   *
+   * @param modelClass  String name for model class (used for error messages)
+   * @return numFeatures
+   */
+  def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
+    metadata.select("numFeatures").take(1)(0) match {
+      case Row(nFeatures: Int) => nFeatures
+      case _ => throw new Exception(s"$modelClass unable to load" +
+        s" numFeatures from metadata: ${Loader.metadataPath(path)}")
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 076ba35..f2a5f1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.optimization._
+import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+
 
 /**
  * Regression model trained using RidgeRegression.
@@ -32,7 +35,7 @@ class RidgeRegressionModel (
     override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable {
+  with RegressionModel with Serializable with Saveable {
 
   override protected def predictPoint(
       dataMatrix: Vector,
@@ -40,12 +43,37 @@ class RidgeRegressionModel (
       intercept: Double): Double = {
     weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
+
+  override def load(sc: SparkContext, path: String): RidgeRegressionModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    // Hard-code class name string in case it changes in the future
+    val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+        new RidgeRegressionModel(data.weights, data.intercept)
+      case _ => throw new Exception(
+        s"RidgeRegressionModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
 
 /**
  * Train a regression model with L2-regularization using Stochastic Gradient Descent.
  * This solves the l1-regularized least squares regression formulation
- *          f(weights) = 1/2n ||A weights-y||^2  + regParam/2 ||weights||^2
+ *          f(weights) = 1/2n ||A weights-y||^2^  + regParam/2 ||weights||^2^
  * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
  * its corresponding right hand side label y.
  * See also the documentation for the precise formulation.

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
new file mode 100644
index 0000000..00f25a8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Helper methods for import/export of GLM regression models.
+ */
+private[regression] object GLMRegressionModel {
+
+  object SaveLoadV1_0 {
+
+    def thisFormatVersion = "1.0"
+
+    /** Model data for model import/export */
+    case class Data(weights: Vector, intercept: Double)
+
+    /**
+     * Helper method for saving GLM regression model metadata and data.
+     * @param modelClass  String name for model class, to be saved with metadata
+     */
+    def save(
+        sc: SparkContext,
+        path: String,
+        modelClass: String,
+        weights: Vector,
+        intercept: Double): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext._
+
+      // Create JSON metadata.
+      val metadataRDD =
+        sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
+          .toDataFrame("class", "version", "numFeatures")
+      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val data = Data(weights, intercept)
+      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+      // TODO: repartition with 1 partition after SPARK-5532 gets fixed
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    /**
+     * Helper method for loading GLM regression model data.
+     * @param modelClass  String name for model class (used for error messages)
+     * @param numFeatures  Number of features, to be checked against loaded data.
+     *                     The length of the weights vector should equal numFeatures.
+     */
+    def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
+      val datapath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      val dataRDD = sqlContext.parquetFile(datapath)
+      val dataArray = dataRDD.select("weights", "intercept").take(1)
+      assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
+      val data = dataArray(0)
+      assert(data.size == 2, s"Unable to load $modelClass data from: $datapath")
+      data match {
+        case Row(weights: Vector, intercept: Double) =>
+          assert(weights.size == numFeatures, s"Expected $numFeatures features, but" +
+            s" found ${weights.size} features when loading $modelClass weights from $datapath")
+          Data(weights, intercept)
+      }
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a576096..a25e625 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -53,7 +53,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
     features.map(x => predict(x))
   }
 
-
   /**
    * Predict values for the given data set using the model trained.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
new file mode 100644
index 0000000..56b77a7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for models and transformers which may be saved as files.
+ * This should be inherited by the class which implements model instances.
+ */
+@DeveloperApi
+trait Saveable {
+
+  /**
+   * Save this model to the given path.
+   *
+   * This saves:
+   *  - human-readable (JSON) model metadata to path/metadata/
+   *  - Parquet formatted data to path/data/
+   *
+   * The model may be loaded using [[Loader.load]].
+   *
+   * @param sc  Spark context used to save model data.
+   * @param path  Path specifying the directory in which to save this model.
+   *              This directory and any intermediate directory will be created if needed.
+   */
+  def save(sc: SparkContext, path: String): Unit
+
+  /** Current version of model save/load format. */
+  protected def formatVersion: String
+
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for classes which can load models and transformers from files.
+ * This should be inherited by an object paired with the model class.
+ */
+@DeveloperApi
+trait Loader[M <: Saveable] {
+
+  /**
+   * Load a model from the given path.
+   *
+   * The model should have been saved by [[Saveable.save]].
+   *
+   * @param sc  Spark context used for loading model files.
+   * @param path  Path specifying the directory to which the model was saved.
+   * @return  Model instance
+   */
+  def load(sc: SparkContext, path: String): M
+
+}
+
+/**
+ * Helper methods for loading models from files.
+ */
+private[mllib] object Loader {
+
+  /** Returns URI for path/data using the Hadoop filesystem */
+  def dataPath(path: String): String = new Path(path, "data").toUri.toString
+
+  /** Returns URI for path/metadata using the Hadoop filesystem */
+  def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString
+
+  /**
+   * Check the schema of loaded model data.
+   *
+   * This checks every field in the expected schema to make sure that a field with the same
+   * name and DataType appears in the loaded schema.  Note that this does NOT check metadata
+   * or containsNull.
+   *
+   * @param loadedSchema  Schema for model data loaded from file.
+   * @tparam Data  Expected data type from which an expected schema can be derived.
+   */
+  def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    val expectedFields: Array[StructField] =
+      ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+    val loadedFields: Map[String, DataType] =
+      loadedSchema.map(field => field.name -> field.dataType).toMap
+    expectedFields.foreach { field =>
+      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+        s"  Expected field with name ${field.name} was missing in loaded schema:" +
+        s" ${loadedFields.mkString(", ")}")
+      assert(loadedFields(field.name) == field.dataType,
+        s"Unable to parse model data.  Expected field $field but found field" +
+          s" with different type: ${loadedFields(field.name)}")
+    }
+  }
+
+  /**
+   * Load metadata from the given path.
+   * @return (class name, version, metadata)
+   */
+  def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = {
+    val sqlContext = new SQLContext(sc)
+    val metadata = sqlContext.jsonFile(metadataPath(path))
+    val (clazz, version) = try {
+      val metadataArray = metadata.select("class", "version").take(1)
+      assert(metadataArray.size == 1)
+      metadataArray(0) match {
+        case Row(clazz: String, version: String) => (clazz, version)
+      }
+    } catch {
+      case e: Exception =>
+        throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}")
+    }
+    (clazz, version, metadata)
+  }
+
+}
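
To illustrate the contract the new Saveable and Loader traits define, here is a hedged sketch that is not part of this commit: MyModel and its single-text-record storage are hypothetical, whereas the concrete models in this patch pair JSON metadata with Parquet data as shown above.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.{Loader, Saveable}

// A model class mixes in Saveable and decides how its parameters are persisted.
class MyModel(val factor: Double) extends Saveable {
  override def save(sc: SparkContext, path: String): Unit = {
    // Toy persistence: one text record under `path`.
    sc.parallelize(Seq(factor.toString), 1).saveAsTextFile(path)
  }
  override protected def formatVersion: String = "1.0"
}

// Its companion object extends Loader[Model] so callers can round-trip the model.
object MyModel extends Loader[MyModel] {
  override def load(sc: SparkContext, path: String): MyModel =
    new MyModel(sc.textFile(path).first().toDouble)
}
```

Keeping `load` on a companion object rather than on the instance is what lets a loader inspect the saved class name and format version before constructing anything, which is how the GLM and NaiveBayes loaders above dispatch on (className, version).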

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 3fb4593..d2b40f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.mllib.classification
 
-import scala.util.control.Breaks._
-import scala.util.Random
 import scala.collection.JavaConversions._
+import scala.util.Random
+import scala.util.control.Breaks._
 
 import org.scalatest.FunSuite
 import org.scalatest.Matchers
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
 
 object LogisticRegressionSuite {
 
@@ -147,8 +149,25 @@ object LogisticRegressionSuite {
     val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i)))
     testData
   }
+
+  /** Binary labels, 3 features */
+  private val binaryModel = new LogisticRegressionModel(
+    weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2)
+
+  /** 3 classes, 2 features */
+  private val multiclassModel = new LogisticRegressionModel(
+    weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3)
+
+  private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = {
+    assert(a.weights == b.weights)
+    assert(a.intercept == b.intercept)
+    assert(a.numClasses == b.numClasses)
+    assert(a.numFeatures == b.numFeatures)
+    assert(a.getThreshold == b.getThreshold)
+  }
 }
 
+
 class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
   def validatePrediction(
       predictions: Seq[Double],
@@ -462,6 +481,53 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
 
   }
 
+  test("model save/load: binary classification") {
+    // NOTE: This will need to be generalized once there are multiple model format versions.
+    val model = LogisticRegressionSuite.binaryModel
+
+    model.clearThreshold()
+    assert(model.getThreshold.isEmpty)
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = LogisticRegressionModel.load(sc, path)
+      LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+
+    // Save model with threshold.
+    try {
+      model.setThreshold(0.7)
+      model.save(sc, path)
+      val sameModel = LogisticRegressionModel.load(sc, path)
+      LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  test("model save/load: multiclass classification") {
+    // NOTE: This will need to be generalized once there are multiple model format versions.
+    val model = LogisticRegressionSuite.multiclassModel
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = LogisticRegressionModel.load(sc, path)
+      LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
 }
 
 class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index e68fe89..64dcc0f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -25,6 +25,8 @@ import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
 
 object NaiveBayesSuite {
 
@@ -58,6 +60,18 @@ object NaiveBayesSuite {
       LabeledPoint(y, Vectors.dense(xi))
     }
   }
+
+  private val smallPi = Array(0.5, 0.3, 0.2).map(math.log)
+
+  private val smallTheta = Array(
+    Array(0.91, 0.03, 0.03, 0.03), // label 0
+    Array(0.03, 0.91, 0.03, 0.03), // label 1
+    Array(0.03, 0.03, 0.91, 0.03)  // label 2
+  ).map(_.map(math.log))
+
+  /** Binary labels, 3 features */
+  private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
+    theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)))
 }
 
 class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -74,12 +88,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
   test("Naive Bayes") {
     val nPoints = 10000
 
-    val pi = Array(0.5, 0.3, 0.2).map(math.log)
-    val theta = Array(
-      Array(0.91, 0.03, 0.03, 0.03), // label 0
-      Array(0.03, 0.91, 0.03, 0.03), // label 1
-      Array(0.03, 0.03, 0.91, 0.03)  // label 2
-    ).map(_.map(math.log))
+    val pi = NaiveBayesSuite.smallPi
+    val theta = NaiveBayesSuite.smallTheta
 
     val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
     val testRDD = sc.parallelize(testData, 2)
@@ -123,6 +133,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       NaiveBayes.train(sc.makeRDD(nan, 2))
     }
   }
+
+  test("model save/load") {
+    val model = NaiveBayesSuite.binaryModel
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = NaiveBayesModel.load(sc, path)
+      assert(model.labels === sameModel.labels)
+      assert(model.pi === sameModel.pi)
+      assert(model.theta === sameModel.theta)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
 
 class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index a2de7fb..6de098b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.util.Utils
 
 object SVMSuite {
 
@@ -56,6 +57,9 @@ object SVMSuite {
     y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
   }
 
+  /** Binary labels, 3 features */
+  private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+
 }
 
 class SVMSuite extends FunSuite with MLlibTestSparkContext {
@@ -191,6 +195,38 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
     // Turning off data validation should not throw an exception
     new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
   }
+
+  test("model save/load") {
+    // NOTE: This will need to be generalized once there are multiple model format versions.
+    val model = SVMSuite.binaryModel
+
+    model.clearThreshold()
+    assert(model.getThreshold.isEmpty)
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = SVMModel.load(sc, path)
+      assert(model.weights == sameModel.weights)
+      assert(model.intercept == sameModel.intercept)
+      assert(sameModel.getThreshold.isEmpty)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+
+    // Save model with threshold.
+    try {
+      model.setThreshold(0.7)
+      model.save(sc, path)
+      val sameModel2 = SVMModel.load(sc, path)
+      assert(model.getThreshold.get == sameModel2.getThreshold.get)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
 
 class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 2668dcc..c9f5dc0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -24,6 +24,13 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
   MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object LassoSuite {
+
+  /** 3 features */
+  val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
 
 class LassoSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -115,6 +122,23 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("model save/load") {
+    val model = LassoSuite.model
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = LassoModel.load(sc, path)
+      assert(model.weights == sameModel.weights)
+      assert(model.intercept == sameModel.intercept)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
 
 class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 864622a..3781931 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -24,6 +24,13 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
   MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object LinearRegressionSuite {
+
+  /** 3 features */
+  val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
 
 class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -124,6 +131,23 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
     validatePrediction(
       sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
   }
+
+  test("model save/load") {
+    val model = LinearRegressionSuite.model
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = LinearRegressionModel.load(sc, path)
+      assert(model.weights == sameModel.weights)
+      assert(model.intercept == sameModel.intercept)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
 
 class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {

http://git-wip-us.apache.org/repos/asf/spark/blob/975bcef4/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 18d3bf5..43d6115 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -25,6 +25,13 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
   MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object RidgeRegressionSuite {
+
+  /** 3 features */
+  val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
 
 class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -75,6 +82,23 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(ridgeErr < linearErr,
       "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
   }
+
+  test("model save/load") {
+    val model = RidgeRegressionSuite.model
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    // Save model, load it back, and compare.
+    try {
+      model.save(sc, path)
+      val sameModel = RidgeRegressionModel.load(sc, path)
+      assert(model.weights == sameModel.weights)
+      assert(model.intercept == sameModel.intercept)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
 
 class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {

