You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/29 18:44:38 UTC

spark git commit: [SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)

Repository: spark
Updated Branches:
  refs/heads/master a7d0fedc9 -> 87ac84d43


[SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)

Add model persistence (save/load) to SparkR `glm` and `kmeans` models.

Tested with new unit tests.

Author: Yanbo Liang <yb...@gmail.com>
Author: Gayathri Murali <ga...@gmail.com>

Closes #12778 from yanboliang/spark-14311.
Closes #12680
Closes #12683


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/87ac84d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/87ac84d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/87ac84d4

Branch: refs/heads/master
Commit: 87ac84d43729c54be100bb9ad7dc6e8fa14b8805
Parents: a7d0fed
Author: Yanbo Liang <yb...@gmail.com>
Authored: Fri Apr 29 09:42:54 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Apr 29 09:43:04 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 |  98 ++++++++--
 R/pkg/inst/tests/testthat/test_mllib.R          |  41 +++++
 .../ml/r/AFTSurvivalRegressionWrapper.scala     |   1 -
 .../r/GeneralizedLinearRegressionWrapper.scala  | 181 +++++++++++++------
 .../org/apache/spark/ml/r/KMeansWrapper.scala   |  65 ++++++-
 .../apache/spark/ml/r/NaiveBayesWrapper.scala   |   1 -
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   4 +
 7 files changed, 315 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 4803011..c2326ea 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -99,9 +99,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
           function(object, ...) {
             jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "rFeatures")
             coefficients <- callJMethod(jobj, "rCoefficients")
-            deviance.resid <- callJMethod(jobj, "rDevianceResiduals")
             dispersion <- callJMethod(jobj, "rDispersion")
             null.deviance <- callJMethod(jobj, "rNullDeviance")
             deviance <- callJMethod(jobj, "rDeviance")
@@ -110,15 +110,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
             aic <- callJMethod(jobj, "rAic")
             iter <- callJMethod(jobj, "rNumIterations")
             family <- callJMethod(jobj, "rFamily")
-
-            deviance.resid <- dataFrame(deviance.resid)
+            deviance.resid <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "rDevianceResiduals"))
+            }
             coefficients <- matrix(coefficients, ncol = 4)
             colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
             rownames(coefficients) <- unlist(features)
             ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
                         dispersion = dispersion, null.deviance = null.deviance,
                         deviance = deviance, df.null = df.null, df.residual = df.residual,
-                        aic = aic, iter = iter, family = family)
+                        aic = aic, iter = iter, family = family, is.loaded = is.loaded)
             class(ans) <- "summary.GeneralizedLinearRegressionModel"
             return(ans)
           })
@@ -129,12 +132,16 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
 #' @name print.summary.GeneralizedLinearRegressionModel
 #' @export
 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
-  x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
+  if (x$is.loaded) {
+    cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n")
+  } else {
+    x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
     c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
-  x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
-  cat("\nDeviance Residuals: \n")
-  cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
-  print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+    x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
+    cat("\nDeviance Residuals: \n")
+    cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
+    print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+  }
 
   cat("\nCoefficients:\n")
   print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)
@@ -246,6 +253,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 #' Get fitted result from a k-means model
 #'
 #' Get fitted result from a k-means model, similarly to R's fitted().
+#' Note: A saved-loaded model does not support this method.
 #'
 #' @param object A fitted k-means model
 #' @return SparkDataFrame containing fitted values
@@ -260,7 +268,13 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 setMethod("fitted", signature(object = "KMeansModel"),
           function(object, method = c("centers", "classes"), ...) {
             method <- match.arg(method)
-            return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+            if (is.loaded) {
+              # A saved-loaded k-means model does not support fitted(); fail fast.
+              stop("Saved-loaded k-means model does not support 'fitted' method")
+            }
+            return(dataFrame(callJMethod(jobj, "fitted", method)))
           })
 
 #' Get the summary of a k-means model
@@ -280,15 +294,21 @@ setMethod("fitted", signature(object = "KMeansModel"),
 setMethod("summary", signature(object = "KMeansModel"),
           function(object, ...) {
             jobj <- object@jobj
+            loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "features")
             coefficients <- callJMethod(jobj, "coefficients")
-            cluster <- callJMethod(jobj, "cluster")
             k <- callJMethod(jobj, "k")
             size <- callJMethod(jobj, "size")
             coefficients <- t(matrix(coefficients, ncol = k))
             colnames(coefficients) <- unlist(features)
             rownames(coefficients) <- 1:k
-            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+            # cluster stays NULL for a saved-loaded model.
+            cluster <- NULL
+            if (!loaded) {
+              cluster <- dataFrame(callJMethod(jobj, "cluster"))
+            }
+            return(list(coefficients = coefficients, size = size,
+                        cluster = cluster, is.loaded = loaded))
           })
 
 #' Make predictions from a k-means model
@@ -389,6 +409,56 @@ setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "ch
             invisible(callJMethod(writer, "save", path))
           })
 
+#' Save a fitted generalized linear model to the input path.
+#'
+#' @param object A fitted generalized linear model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is
+#'                  FALSE, in which case an error is raised if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            jwriter <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              jwriter <- callJMethod(jwriter, "overwrite")
+            }
+            invisible(callJMethod(jwriter, "save", path))
+          })
+
+#' Save a fitted k-means model to the input path.
+#'
+#' @param object A fitted k-means model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is
+#'                  FALSE, in which case an error is raised if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(x, centers = 2, algorithm = "random")
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "KMeansModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            jwriter <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              jwriter <- callJMethod(jwriter, "overwrite")
+            }
+            invisible(callJMethod(jwriter, "save", path))
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path Path of the model to read.
@@ -408,6 +478,10 @@ ml.load <- function(path) {
     return(new("NaiveBayesModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
     return(new("AFTSurvivalRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
+      return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
+      return(new("KMeansModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 954abb0..6a822be 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -126,6 +126,33 @@ test_that("glm summary", {
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
 
+test_that("glm save/load", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+  s <- summary(m)
+
+  modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+  # Persisted summary statistics must survive the save/load round trip unchanged.
+  expect_equal(s$coefficients, s2$coefficients)
+  expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+  expect_equal(s$dispersion, s2$dispersion)
+  expect_equal(s$null.deviance, s2$null.deviance)
+  expect_equal(s$deviance, s2$deviance)
+  expect_equal(s$df.null, s2$df.null)
+  expect_equal(s$df.residual, s2$df.residual)
+  expect_equal(s$aic, s2$aic)
+  expect_equal(s$iter, s2$iter)
+  expect_false(s$is.loaded)
+  expect_true(s2$is.loaded)
+  expect_null(s2$deviance.resid)
+  unlink(modelPath, recursive = TRUE)
+})
+
 test_that("kmeans", {
   newIris <- iris
   newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
   summary.model <- summary(model)
   cluster <- summary.model$cluster
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+  ml.save(model, modelPath)
+  expect_error(ml.save(model, modelPath))
+  ml.save(model, modelPath, overwrite = TRUE)
+  model2 <- ml.load(modelPath)
+  summary2 <- summary(model2)
+  expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+  expect_equal(summary.model$coefficients, summary2$coefficients)
+  expect_true(!summary.model$is.loaded)
+  expect_true(summary2$is.loaded)
+
+  unlink(modelPath)
 })
 
 test_that("naiveBayes", {

http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index a442469..5462f80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index f66323e..9618a34 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -17,65 +17,34 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression._
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 
 private[r] class GeneralizedLinearRegressionWrapper private (
-    pipeline: PipelineModel,
-    val features: Array[String]) {
+    val pipeline: PipelineModel,
+    val rFeatures: Array[String],
+    val rCoefficients: Array[Double],
+    val rDispersion: Double,
+    val rNullDeviance: Double,
+    val rDeviance: Double,
+    val rResidualDegreeOfFreedomNull: Long,
+    val rResidualDegreeOfFreedom: Long,
+    val rAic: Double,
+    val rNumIterations: Int,
+    val isLoaded: Boolean = false) extends MLWritable {
 
   private val glm: GeneralizedLinearRegressionModel =
     pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
 
-  lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
-    Array("(Intercept)") ++ features
-  } else {
-    features
-  }
-
-  lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
-    Array(glm.intercept) ++ glm.coefficients.toArray ++
-      rCoefficientStandardErrors ++ rTValues ++ rPValues
-  } else {
-    glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
-  }
-
-  private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) {
-    Array(glm.summary.coefficientStandardErrors.last) ++
-      glm.summary.coefficientStandardErrors.dropRight(1)
-  } else {
-    glm.summary.coefficientStandardErrors
-  }
-
-  private lazy val rTValues = if (glm.getFitIntercept) {
-    Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1)
-  } else {
-    glm.summary.tValues
-  }
-
-  private lazy val rPValues = if (glm.getFitIntercept) {
-    Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1)
-  } else {
-    glm.summary.pValues
-  }
-
-  lazy val rDispersion: Double = glm.summary.dispersion
-
-  lazy val rNullDeviance: Double = glm.summary.nullDeviance
-
-  lazy val rDeviance: Double = glm.summary.deviance
-
-  lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull
-
-  lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom
-
-  lazy val rAic: Double = glm.summary.aic
-
-  lazy val rNumIterations: Int = glm.summary.numIterations
-
   lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()
 
   lazy val rFamily: String = glm.getFamily
@@ -85,9 +54,13 @@ private[r] class GeneralizedLinearRegressionWrapper private (
   def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(glm.getFeaturesCol)
   }
+
+  override def write: MLWriter =
+    new GeneralizedLinearRegressionWrapper.GeneralizedLinearRegressionWrapperWriter(this)
 }
 
-private[r] object GeneralizedLinearRegressionWrapper {
+private[r] object GeneralizedLinearRegressionWrapper
+  extends MLReadable[GeneralizedLinearRegressionWrapper] {
 
   def fit(
       formula: String,
@@ -105,15 +78,119 @@ private[r] object GeneralizedLinearRegressionWrapper {
       .attributes.get
     val features = featureAttrs.map(_.name.get)
     // assemble and fit the pipeline
-    val glm = new GeneralizedLinearRegression()
+    val glr = new GeneralizedLinearRegression()
       .setFamily(family)
       .setLink(link)
       .setFitIntercept(rFormula.hasIntercept)
       .setTol(epsilon)
       .setMaxIter(maxit)
     val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, glm))
+      .setStages(Array(rFormulaModel, glr))
       .fit(data)
-    new GeneralizedLinearRegressionWrapper(pipeline, features)
+
+    val glm: GeneralizedLinearRegressionModel =
+      pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+    val summary = glm.summary
+
+    val rFeatures: Array[String] = if (glm.getFitIntercept) {
+      Array("(Intercept)") ++ features
+    } else {
+      features
+    }
+
+    val rCoefficientStandardErrors = if (glm.getFitIntercept) {
+      Array(summary.coefficientStandardErrors.last) ++
+        summary.coefficientStandardErrors.dropRight(1)
+    } else {
+      summary.coefficientStandardErrors
+    }
+
+    val rTValues = if (glm.getFitIntercept) {
+      Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
+    } else {
+      summary.tValues
+    }
+
+    val rPValues = if (glm.getFitIntercept) {
+      Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
+    } else {
+      summary.pValues
+    }
+
+    val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+      Array(glm.intercept) ++ glm.coefficients.toArray ++
+        rCoefficientStandardErrors ++ rTValues ++ rPValues
+    } else {
+      glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
+    }
+
+    val rDispersion: Double = summary.dispersion
+    val rNullDeviance: Double = summary.nullDeviance
+    val rDeviance: Double = summary.deviance
+    val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
+    val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
+    val rAic: Double = summary.aic
+    val rNumIterations: Int = summary.numIterations
+
+    new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
+      rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
+      rAic, rNumIterations)
+  }
+
+  override def read: MLReader[GeneralizedLinearRegressionWrapper] =
+    new GeneralizedLinearRegressionWrapperReader
+
+  override def load(path: String): GeneralizedLinearRegressionWrapper = super.load(path)
+
+  class GeneralizedLinearRegressionWrapperWriter(instance: GeneralizedLinearRegressionWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("rFeatures" -> instance.rFeatures.toSeq) ~
+        ("rCoefficients" -> instance.rCoefficients.toSeq) ~
+        ("rDispersion" -> instance.rDispersion) ~
+        ("rNullDeviance" -> instance.rNullDeviance) ~
+        ("rDeviance" -> instance.rDeviance) ~
+        ("rResidualDegreeOfFreedomNull" -> instance.rResidualDegreeOfFreedomNull) ~
+        ("rResidualDegreeOfFreedom" -> instance.rResidualDegreeOfFreedom) ~
+        ("rAic" -> instance.rAic) ~
+        ("rNumIterations" -> instance.rNumIterations)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GeneralizedLinearRegressionWrapperReader
+    extends MLReader[GeneralizedLinearRegressionWrapper] {
+
+    override def load(path: String): GeneralizedLinearRegressionWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
+      val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
+      val rDispersion = (rMetadata \ "rDispersion").extract[Double]
+      val rNullDeviance = (rMetadata \ "rNullDeviance").extract[Double]
+      val rDeviance = (rMetadata \ "rDeviance").extract[Double]
+      val rResidualDegreeOfFreedomNull = (rMetadata \ "rResidualDegreeOfFreedomNull").extract[Long]
+      val rResidualDegreeOfFreedom = (rMetadata \ "rResidualDegreeOfFreedom").extract[Long]
+      val rAic = (rMetadata \ "rAic").extract[Double]
+      val rNumIterations = (rMetadata \ "rNumIterations").extract[Int]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
+        rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
+        rAic, rNumIterations, isLoaded = true)
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index 9e2b81e..f67760d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -17,28 +17,30 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
 import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class KMeansWrapper private (
-    pipeline: PipelineModel) {
+    val pipeline: PipelineModel,
+    val features: Array[String],
+    val size: Array[Long],
+    val isLoaded: Boolean = false) extends MLWritable {
 
   private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
 
   lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
 
-  private lazy val attrs = AttributeGroup.fromStructField(
-    kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
-
-  lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
-
   lazy val k: Int = kMeansModel.getK
 
-  lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
-
   lazy val cluster: DataFrame = kMeansModel.summary.cluster
 
   def fitted(method: String): DataFrame = {
@@ -56,9 +58,10 @@ private[r] class KMeansWrapper private (
     pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
   }
 
+  override def write: MLWriter = new KMeansWrapper.KMeansWrapperWriter(this)
 }
 
-private[r] object KMeansWrapper {
+private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
 
   def fit(
       data: DataFrame,
@@ -80,6 +83,48 @@ private[r] object KMeansWrapper {
       .setStages(Array(assembler, kMeans))
       .fit(data)
 
-    new KMeansWrapper(pipeline)
+    val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+    val attrs = AttributeGroup.fromStructField(
+      kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+    val features: Array[String] = attrs.attributes.get.map(_.name.get)
+    val size: Array[Long] = kMeansModel.summary.clusterSizes
+
+    new KMeansWrapper(pipeline, features, size)
+  }
+
+  override def read: MLReader[KMeansWrapper] = new KMeansWrapperReader
+
+  override def load(path: String): KMeansWrapper = super.load(path)
+
+  class KMeansWrapperWriter(instance: KMeansWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq) ~
+        ("size" -> instance.size.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class KMeansWrapperReader extends MLReader[KMeansWrapper] {
+
+    override def load(path: String): KMeansWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+      val size = (rMetadata \ "size").extract[Array[Long]]
+      new KMeansWrapper(pipeline, features, size, isLoaded = true)
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 27c7e72..28925c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/87ac84d4/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 06baedf..9c07579 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -40,6 +40,10 @@ private[r] object RWrappers extends MLReader[Object] {
       case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
       case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
         AFTSurvivalRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
+        GeneralizedLinearRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.KMeansWrapper" =>
+        KMeansWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR ml.load does not support load $className")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org