You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/25 23:08:44 UTC

spark git commit: [SPARK-14312][ML][SPARKR] NaiveBayes model persistence in SparkR

Repository: spark
Updated Branches:
  refs/heads/master 0c47e274a -> 9cb3ba101


[SPARK-14312][ML][SPARKR] NaiveBayes model persistence in SparkR

## What changes were proposed in this pull request?
SparkR `NaiveBayesModel` now supports `save`/`load` through the following API:
```
df <- createDataFrame(sqlContext, infert)
model <- naiveBayes(education ~ ., df, laplace = 0)
ml.save(model, path)
model2 <- ml.load(path)
```

## How was this patch tested?
Added unit tests to `R/pkg/inst/tests/testthat/test_mllib.R`.

cc mengxr

Author: Yanbo Liang <yb...@gmail.com>

Closes #12573 from yanboliang/spark-14312.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9cb3ba10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9cb3ba10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9cb3ba10

Branch: refs/heads/master
Commit: 9cb3ba1013a7eae11be8a00fa4a9c5308bb20195
Parents: 0c47e27
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon Apr 25 14:08:41 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Apr 25 14:08:41 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |  6 ++-
 R/pkg/R/generics.R                              |  4 ++
 R/pkg/R/mllib.R                                 | 48 ++++++++++++++++++
 R/pkg/inst/tests/testthat/test_mllib.R          | 12 +++++
 .../apache/spark/ml/r/NaiveBayesWrapper.scala   | 52 ++++++++++++++++++--
 .../scala/org/apache/spark/ml/r/RWrappers.scala | 45 +++++++++++++++++
 6 files changed, 162 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9cb3ba10/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 0f92b5e..c0a63d6 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -107,7 +107,8 @@ exportMethods("arrange",
               "write.jdbc",
               "write.json",
               "write.parquet",
-              "write.text")
+              "write.text",
+              "ml.save")
 
 exportClasses("Column")
 
@@ -299,7 +300,8 @@ export("as.DataFrame",
        "tableNames",
        "tables",
        "uncacheTable",
-       "print.summary.GeneralizedLinearRegressionModel")
+       "print.summary.GeneralizedLinearRegressionModel",
+       "ml.load")
 
 export("structField",
        "structField.jobj",

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb3ba10/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 04274a1..f654d83 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1200,3 +1200,7 @@ setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBa
 #' @rdname survreg
 #' @export
 setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+
+#' @rdname ml.save
+#' @export
+setGeneric("ml.save", function(object, path, ...) { standardGeneric("ml.save") })

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb3ba10/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 7dd8296..cda6100 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
             return(new("NaiveBayesModel", jobj = jobj))
           })
 
+#' Save the Bernoulli naive Bayes model to the input path.
+#'
+#' @param object A fitted Bernoulli naive Bayes model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#'                  which means an exception is thrown if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path Path of the model to read.
+#' @return a fitted MLlib model
+#' @rdname ml.load
+#' @name ml.load
+#' @export
+#' @examples
+#' \dontrun{
+#' path <- "path/to/model"
+#' model <- ml.load(path)
+#' }
+ml.load <- function(path) {
+  path <- suppressWarnings(normalizePath(path))
+  jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
+  if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
+    return(new("NaiveBayesModel", jobj = jobj))
+  } else {
+    stop(paste("Unsupported model: ", jobj))
+  }
+}
+
 #' Fit an accelerated failure time (AFT) survival regression model.
 #'
 #' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb3ba10/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 1597306..63ec84e 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -204,6 +204,18 @@ test_that("naiveBayes", {
                                "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
                                "Yes", "Yes", "No", "No"))
 
+  # Test model save/load
+  modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+  expect_equal(s$apriori, s2$apriori)
+  expect_equal(s$tables, s2$tables)
+
+  unlink(modelPath)
+
   # Test e1071::naiveBayes
   if (requireNamespace("e1071", quietly = TRUE)) {
     expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb3ba10/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index b17207e..27c7e72 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -17,16 +17,23 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class NaiveBayesWrapper private (
-    pipeline: PipelineModel,
+    val pipeline: PipelineModel,
     val labels: Array[String],
-    val features: Array[String]) {
+    val features: Array[String]) extends MLWritable {
 
   import NaiveBayesWrapper._
 
@@ -41,9 +48,11 @@ private[r] class NaiveBayesWrapper private (
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(naiveBayesModel.getFeaturesCol)
   }
+
+  override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this)
 }
 
-private[r] object NaiveBayesWrapper {
+private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
 
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"
@@ -74,4 +83,41 @@ private[r] object NaiveBayesWrapper {
       .fit(data)
     new NaiveBayesWrapper(pipeline, labels, features)
   }
+
+  override def read: MLReader[NaiveBayesWrapper] = new NaiveBayesWrapperReader
+
+  override def load(path: String): NaiveBayesWrapper = super.load(path)
+
+  class NaiveBayesWrapperWriter(instance: NaiveBayesWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("labels" -> instance.labels.toSeq) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class NaiveBayesWrapperReader extends MLReader[NaiveBayesWrapper] {
+
+    override def load(path: String): NaiveBayesWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val labels = (rMetadata \ "labels").extract[Array[String]]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new NaiveBayesWrapper(pipeline, labels, features)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb3ba10/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
new file mode 100644
index 0000000..7f6f147
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.util.MLReader
+
+/**
+ * This is the Scala stub of SparkR ml.load. It dispatches the call to the corresponding model
+ * wrapper's load function according to the class name read from the rMetadata under the path.
+ */
+private[r] object RWrappers extends MLReader[Object] {
+
+  override def load(path: String): Object = {
+    implicit val format = DefaultFormats
+    val rMetadataPath = new Path(path, "rMetadata").toString
+    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+    val rMetadata = parse(rMetadataStr)
+    val className = (rMetadata \ "class").extract[String]
+    className match {
+      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
+      case _ =>
+        throw new SparkException(s"SparkR ml.load does not support load $className")
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org