You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/26 19:30:28 UTC

spark git commit: [SPARK-14313][ML][SPARKR] AFTSurvivalRegression model persistence in SparkR

Repository: spark
Updated Branches:
  refs/heads/master 162cf02ef -> 92f66331b


[SPARK-14313][ML][SPARKR] AFTSurvivalRegression model persistence in SparkR

## What changes were proposed in this pull request?
`AFTSurvivalRegressionModel` now supports model persistence (`ml.save`/`ml.load`) in SparkR.

## How was this patch tested?
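Unit tests: `test_mllib.R` adds a save / overwrite / load round-trip for a `survreg` model and checks that the reloaded model's coefficients match the original.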

Author: Yanbo Liang <yb...@gmail.com>

Closes #12685 from yanboliang/spark-14313.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92f66331
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92f66331
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92f66331

Branch: refs/heads/master
Commit: 92f66331b4ba3634f54f57ddb5e7962b14aa4ca1
Parents: 162cf02
Author: Yanbo Liang <yb...@gmail.com>
Authored: Tue Apr 26 10:30:24 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Apr 26 10:30:24 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 | 27 ++++++++++
 R/pkg/inst/tests/testthat/test_mllib.R          | 13 +++++
 .../ml/r/AFTSurvivalRegressionWrapper.scala     | 52 ++++++++++++++++++--
 .../scala/org/apache/spark/ml/r/RWrappers.scala |  2 +
 4 files changed, 91 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92f66331/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cda6100..4803011 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -364,6 +364,31 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
             invisible(callJMethod(writer, "save", path))
           })
 
+#' Save the AFT survival regression model to the input path.
+#'
+#' @param object A fitted AFT survival regression model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path Path of the model to read.
@@ -381,6 +406,8 @@ ml.load <- function(path) {
   jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
   if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
     return(new("NaiveBayesModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
+    return(new("AFTSurvivalRegressionModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92f66331/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 63ec84e..954abb0 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -261,6 +261,19 @@ test_that("survreg", {
   expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
                2.390146, 2.891269, 2.891269), tolerance = 1e-4)
 
+  # Test model save/load
+  modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
+  ml.save(model, modelPath)
+  expect_error(ml.save(model, modelPath))
+  ml.save(model, modelPath, overwrite = TRUE)
+  model2 <- ml.load(modelPath)
+  stats2 <- summary(model2)
+  coefs2 <- as.vector(stats2$coefficients[, 1])
+  expect_equal(coefs, coefs2)
+  expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients))
+
+  unlink(modelPath)
+
   # Test survival::survreg
   if (requireNamespace("survival", quietly = TRUE)) {
     rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),

http://git-wip-us.apache.org/repos/asf/spark/blob/92f66331/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 7835468..a442469 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -17,16 +17,23 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkException
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class AFTSurvivalRegressionWrapper private (
-    pipeline: PipelineModel,
-    features: Array[String]) {
+    val pipeline: PipelineModel,
+    val features: Array[String]) extends MLWritable {
 
   private val aftModel: AFTSurvivalRegressionModel =
     pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
@@ -46,9 +53,12 @@ private[r] class AFTSurvivalRegressionWrapper private (
   def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(aftModel.getFeaturesCol)
   }
+
+  override def write: MLWriter =
+    new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperWriter(this)
 }
 
-private[r] object AFTSurvivalRegressionWrapper {
+private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] {
 
   private def formulaRewrite(formula: String): (String, String) = {
     var rewritedFormula: String = null
@@ -96,4 +106,40 @@ private[r] object AFTSurvivalRegressionWrapper {
 
     new AFTSurvivalRegressionWrapper(pipeline, features)
   }
+
+  override def read: MLReader[AFTSurvivalRegressionWrapper] = new AFTSurvivalRegressionWrapperReader
+
+  override def load(path: String): AFTSurvivalRegressionWrapper = super.load(path)
+
+  class AFTSurvivalRegressionWrapperWriter(instance: AFTSurvivalRegressionWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] {
+
+    override def load(path: String): AFTSurvivalRegressionWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new AFTSurvivalRegressionWrapper(pipeline, features)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/92f66331/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 7f6f147..06baedf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -38,6 +38,8 @@ private[r] object RWrappers extends MLReader[Object] {
     val className = (rMetadata \ "class").extract[String]
     className match {
       case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
+      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
+        AFTSurvivalRegressionWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR ml.load does not support load $className")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org