You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by fe...@apache.org on 2016/10/26 23:12:59 UTC

spark git commit: [SPARK-17157][SPARKR] Add multiclass logistic regression SparkR Wrapper

Repository: spark
Updated Branches:
  refs/heads/master 5b7d403c1 -> 29cea8f33


[SPARK-17157][SPARKR] Add multiclass logistic regression SparkR Wrapper

## What changes were proposed in this pull request?

As we discussed in #14818, I added a separate R wrapper spark.logit for logistic regression.

This single interface supports both binary and multinomial logistic regression. It also provides "predict" for both, and "summary" for binary logistic regression only.

## How was this patch tested?

New unit tests are added.

Author: wm624@hotmail.com <wm...@hotmail.com>

Closes #15365 from wangmiao1981/glm.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29cea8f3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29cea8f3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29cea8f3

Branch: refs/heads/master
Commit: 29cea8f332aa3750f8ff7c3b9e705d107278da4b
Parents: 5b7d403
Author: wm624@hotmail.com <wm...@hotmail.com>
Authored: Wed Oct 26 16:12:55 2016 -0700
Committer: Felix Cheung <fe...@apache.org>
Committed: Wed Oct 26 16:12:55 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   3 +-
 R/pkg/R/generics.R                              |   4 +
 R/pkg/R/mllib.R                                 | 192 ++++++++++++++++++-
 R/pkg/inst/tests/testthat/test_mllib.R          |  55 ++++++
 .../spark/ml/r/LogisticRegressionWrapper.scala  | 157 +++++++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   2 +
 6 files changed, 410 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/29cea8f3/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index eb314f4..7a89c01 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -43,7 +43,8 @@ exportMethods("glm",
               "spark.isoreg",
               "spark.gaussianMixture",
               "spark.als",
-              "spark.kstest")
+              "spark.kstest",
+              "spark.logit")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/29cea8f3/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 4569fe4..107e1c6 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1375,6 +1375,10 @@ setGeneric("spark.gaussianMixture",
              standardGeneric("spark.gaussianMixture")
            })
 
+#' @rdname spark.logit
+#' @export
+setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") })
+
 #' @param object a fitted ML model object.
 #' @param path the directory where the model is saved.
 #' @param ... additional argument(s) passed to the method.

http://git-wip-us.apache.org/repos/asf/spark/blob/29cea8f3/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index bf182be..e441db9 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -95,6 +95,13 @@ setClass("ALSModel", representation(jobj = "jobj"))
 #' @note KSTest since 2.1.0
 setClass("KSTest", representation(jobj = "jobj"))
 
+#' S4 class that represents a LogisticRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel
+#' @export
+#' @note LogisticRegressionModel since 2.1.0
+setClass("LogisticRegressionModel", representation(jobj = "jobj"))
+
 #' Saves the MLlib model to the input path
 #'
 #' Saves the MLlib model to the input path. For more information, see the specific
@@ -105,7 +112,7 @@ setClass("KSTest", representation(jobj = "jobj"))
 #' @seealso \link{spark.glm}, \link{glm},
 #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
 #' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
-#' @seealso \link{read.ml}
+#' @seealso \link{spark.logit}, \link{read.ml}
 NULL
 
 #' Makes predictions from a MLlib model
@@ -117,7 +124,7 @@ NULL
 #' @export
 #' @seealso \link{spark.glm}, \link{glm},
 #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
-#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.logit}
 NULL
 
 write_internal <- function(object, path, overwrite = FALSE) {
@@ -647,6 +654,170 @@ setMethod("predict", signature(object = "KMeansModel"),
             predict_internal(object, newData)
           })
 
+#' Logistic Regression Model
+#'
+#' Fits a logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression
+#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet.
+#' Users can get a summary of a binary model, make predictions with it and save it to the input path.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param regParam the regularization parameter. Default is 0.0.
+#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty.
+#'                        For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination
+#'                        of L1 and L2. Default is 0.0 which is an L2 penalty.
+#' @param maxIter maximum number of iterations. Default is 100.
+#' @param tol convergence tolerance of iterations. Default is 1E-6.
+#' @param fitIntercept whether to fit an intercept term. Default is TRUE.
+#' @param family the name of family which is a description of the label distribution to be used in the model.
+#'               Supported options (default is "auto"):
+#'                 \itemize{
+#'                   \item{"auto": Automatically select the family based on the number of classes:
+#'                           If number of classes == 1 || number of classes == 2, set to "binomial".
+#'                           Else, set to "multinomial".}
+#'                   \item{"binomial": Binary logistic regression with pivoting.}
+#'                   \item{"multinomial": Multinomial logistic (softmax) regression without
+#'                           pivoting.}
+#'                 }
+#' @param standardization whether to standardize the training features before fitting the model. The coefficients
+#'                        of models will be always returned on the original scale, so it will be transparent for
+#'                        users. Note that with/without standardization, the models should be always converged
+#'                        to the same solution when no regularization is applied. Default is TRUE, same as glmnet.
+#' @param thresholds a single number or a numeric vector of per-class thresholds. Default is 0.5.
+#'                   A single value p, in range [0, 1], is the binary classification threshold: if the
+#'                   estimated probability of class label 1 is > p, then predict 1, else 0. A high
+#'                   threshold encourages the model to predict 0 more often; a low threshold encourages
+#'                   the model to predict 1 more often.
+#'                   Setting a single value p is equivalent to setting the vector c(1-p, p).
+#'                   A vector of length > 1 adjusts the probability of predicting each class in
+#'                   multiclass (or binary) classification. It must have length equal to the number
+#'                   of classes, with values > 0, excepting that at most one value may be 0. The
+#'                   class with largest value p/t is predicted, where p is the original probability
+#'                   of that class and t is the class's threshold.
+#' @param weightCol the weight column name. Default is NULL, meaning no weight column is used.
+#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions
+#'                         are large, this param could be adjusted to a larger size. Default is 2.
+#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability".
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.logit} returns a fitted logistic regression model.
+#' @rdname spark.logit
+#' @aliases spark.logit,SparkDataFrame,formula-method
+#' @name spark.logit
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' # binary logistic regression
+#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0)
+#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
+#' binary_data <- as.data.frame(cbind(label, feature))
+#' binary_df <- createDataFrame(binary_data)
+#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)
+#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction"))
+#'
+#' # summary of binary logistic regression
+#' blr_summary <- summary(blr_model)
+#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure"))
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(blr_model, path)
+#'
+#' # can also read back the saved model and predict
+#' # Note that summary does not work on a loaded model
+#' savedModel <- read.ml(path)
+#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction"))
+#'
+#' # multinomial logistic regression
+#'
+#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0)
+#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667)
+#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987)
+#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130)
+#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842)
+#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4))
+#' df <- createDataFrame(data)
+#'
+#' # Note that summary of multinomial logistic regression is not implemented yet
+#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1))
+#' predict1 <- collect(select(predict(model, df), "prediction"))
+#' }
+#' @note spark.logit since 2.1.0
+setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
+                   tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE,
+                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
+                   probabilityCol = "probability") {
+            formula <- paste0(deparse(formula), collapse = "")
+
+            if (is.null(weightCol)) {
+              weightCol <- ""
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
+                                data@sdf, formula, as.numeric(regParam),
+                                as.numeric(elasticNetParam), as.integer(maxIter),
+                                as.numeric(tol), as.logical(fitIntercept),
+                                as.character(family), as.logical(standardization),
+                                as.array(thresholds), as.character(weightCol),
+                                as.integer(aggregationDepth), as.character(probabilityCol))
+            new("LogisticRegressionModel", jobj = jobj)
+          })
+
+#  Predicted values based on a LogisticRegressionModel
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a LogisticRegressionModel.
+#' @rdname spark.logit
+#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method
+#' @export
+#' @note predict(LogisticRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "LogisticRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#  Get the summary of a LogisticRegressionModel
+
+#' @param object a LogisticRegressionModel fitted by \code{spark.logit}.
+#' @return \code{summary} returns the binary logistic regression results of a given model as a list.
+#'         It is not available for multinomial models or for models loaded with \code{read.ml}.
+#' @rdname spark.logit
+#' @aliases summary,LogisticRegressionModel-method
+#' @export
+#' @note summary(LogisticRegressionModel) since 2.1.0
+setMethod("summary", signature(object = "LogisticRegressionModel"),
+          function(object) {
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+
+            if (is.loaded) {
+              stop("Loaded model doesn't have training summary.")
+            }
+            # every field below is read from the training summary held by the JVM-side wrapper
+            roc <- dataFrame(callJMethod(jobj, "roc"))
+
+            areaUnderROC <- callJMethod(jobj, "areaUnderROC")
+
+            pr <- dataFrame(callJMethod(jobj, "pr"))
+
+            fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold"))
+
+            precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold"))
+
+            recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold"))
+
+            totalIterations <- callJMethod(jobj, "totalIterations")
+
+            objectiveHistory <- callJMethod(jobj, "objectiveHistory")
+
+            list(roc = roc, areaUnderROC = areaUnderROC, pr = pr,
+                 fMeasureByThreshold = fMeasureByThreshold,
+                 precisionByThreshold = precisionByThreshold,
+                 recallByThreshold = recallByThreshold,
+                 totalIterations = totalIterations, objectiveHistory = objectiveHistory)
+          })
+
 #' Multilayer Perceptron Classification Model
 #'
 #' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame.
@@ -888,6 +1059,21 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
             write_internal(object, path, overwrite)
           })
 
+#  Save fitted LogisticRegressionModel to the input path
+
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#'                  which means an exception is thrown if the output path exists.
+#'
+#' @rdname spark.logit
+#' @aliases write.ml,LogisticRegressionModel,character-method
+#' @export
+#' @note write.ml(LogisticRegressionModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
 #  Save fitted MLlib model to the input path
 
 #' @param path the directory where the model is saved.
@@ -938,6 +1124,8 @@ read.ml <- function(path) {
     new("GaussianMixtureModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
     new("ALSModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) {
+    new("LogisticRegressionModel", jobj = jobj)
   } else {
     stop("Unsupported model: ", jobj)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/29cea8f3/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 33cc069..6d1fccc 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -602,6 +602,61 @@ test_that("spark.isotonicRegression", {
   unlink(modelPath)
 })
 
+test_that("spark.logit", {
+  # test binary logistic regression; thresholds of 1.0 / 0.0 force all-0 / all-1 predictions
+  label <- c(1.0, 1.0, 1.0, 0.0, 0.0)
+  feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
+  binary_data <- as.data.frame(cbind(label, feature))
+  binary_df <- createDataFrame(binary_data)
+
+  blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)
+  blr_predict <- collect(select(predict(blr_model, binary_df), "prediction"))
+  expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0))
+  blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0)
+  blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction"))
+  expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1))
+
+  # test summary of binary logistic regression
+  blr_summary <- summary(blr_model)
+  blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure"))
+  expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487),
+               tolerance = 1e-4)
+  expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000),
+               tolerance = 1e-4)
+  blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision"))
+  expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000),
+               tolerance = 1e-4)
+  blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall"))
+  expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000),
+               tolerance = 1e-4)
+
+  # test model save and read; a model loaded via read.ml has no training summary
+  modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp")
+  write.ml(blr_model, modelPath)
+  expect_error(write.ml(blr_model, modelPath))
+  write.ml(blr_model, modelPath, overwrite = TRUE)
+  blr_model2 <- read.ml(modelPath)
+  blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction"))
+  expect_equal(blr_predict$prediction, blr_predict2$prediction)
+  expect_error(summary(blr_model2))
+  unlink(modelPath)
+
+  # test multinomial logistic regression
+  label <- c(0.0, 1.0, 2.0, 0.0, 0.0)
+  feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667)
+  feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987)
+  feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130)
+  feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842)
+  data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4))
+  df <- createDataFrame(data)
+
+  model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1))
+  predict1 <- collect(select(predict(model, df), "prediction"))
+  expect_equal(predict1$prediction, c(0, 0, 0, 0, 0))
+  # Summary of multinomial logistic regression is not implemented yet
+  expect_error(summary(model))
+})
+
 test_that("spark.gaussianMixture", {
   # R code to reproduce the result.
   # nolint start

http://git-wip-us.apache.org/repos/asf/spark/blob/29cea8f3/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
new file mode 100644
index 0000000..9b352c9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class LogisticRegressionWrapper private (
+    val pipeline: PipelineModel,
+    val features: Array[String],
+    val isLoaded: Boolean = false) extends MLWritable {
+
+  private val logisticRegressionModel: LogisticRegressionModel =
+    pipeline.stages(1).asInstanceOf[LogisticRegressionModel]
+
+  // Members backed by the training summary are unusable on a model read back via `load`;
+  // the SparkR side checks `isLoaded` before calling any of them.
+  lazy val totalIterations: Int = logisticRegressionModel.summary.totalIterations
+
+  lazy val objectiveHistory: Array[Double] = logisticRegressionModel.summary.objectiveHistory
+
+  // Assumes a binary model; the SparkR tests expect summary() on a multinomial model
+  // to fail (TODO confirm whether via this cast or via the summary lookup itself).
+  lazy val blrSummary: BinaryLogisticRegressionSummary =
+    logisticRegressionModel.summary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+  lazy val roc: DataFrame = blrSummary.roc
+  lazy val areaUnderROC: Double = blrSummary.areaUnderROC
+  lazy val pr: DataFrame = blrSummary.pr
+  lazy val fMeasureByThreshold: DataFrame = blrSummary.fMeasureByThreshold
+  lazy val precisionByThreshold: DataFrame = blrSummary.precisionByThreshold
+  lazy val recallByThreshold: DataFrame = blrSummary.recallByThreshold
+
+  /** Runs the fitted pipeline and drops the intermediate features column from the output. */
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(logisticRegressionModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this)
+}
+
+private[r] object LogisticRegressionWrapper
+    extends MLReadable[LogisticRegressionWrapper] {
+
+  def fit( // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      regParam: Double,
+      elasticNetParam: Double,
+      maxIter: Int,
+      tol: Double,
+      fitIntercept: Boolean,
+      family: String,
+      standardization: Boolean,
+      thresholds: Array[Double],
+      weightCol: String,
+      aggregationDepth: Int,
+      probability: String
+      ): LogisticRegressionWrapper = {
+
+    val rFormula = new RFormula().setFormula(formula)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // assemble and fit the pipeline
+    val logisticRegression = new LogisticRegression()
+      .setRegParam(regParam)
+      .setElasticNetParam(elasticNetParam)
+      .setMaxIter(maxIter)
+      .setTol(tol)
+      .setFitIntercept(fitIntercept)
+      .setFamily(family)
+      .setStandardization(standardization)
+      .setWeightCol(weightCol)
+      .setAggregationDepth(aggregationDepth)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+      .setProbabilityCol(probability)
+
+    require(thresholds.nonEmpty, "thresholds must contain at least one value")
+    if (thresholds.length > 1) {
+      logisticRegression.setThresholds(thresholds)
+    } else {
+      logisticRegression.setThreshold(thresholds(0))
+    }
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, logisticRegression))
+      .fit(data)
+
+    new LogisticRegressionWrapper(pipeline, features)
+  }
+
+  override def read: MLReader[LogisticRegressionWrapper] = new LogisticRegressionWrapperReader
+
+  override def load(path: String): LogisticRegressionWrapper = super.load(path)
+
+  class LogisticRegressionWrapperWriter(instance: LogisticRegressionWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class LogisticRegressionWrapperReader extends MLReader[LogisticRegressionWrapper] {
+
+    override def load(path: String): LogisticRegressionWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new LogisticRegressionWrapper(pipeline, features, isLoaded = true)
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/29cea8f3/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index d64de1b..1df3662 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -54,6 +54,8 @@ private[r] object RWrappers extends MLReader[Object] {
         GaussianMixtureWrapper.load(path)
       case "org.apache.spark.ml.r.ALSWrapper" =>
         ALSWrapper.load(path)
+      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
+        LogisticRegressionWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load $className")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org