You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/30 17:38:02 UTC

spark git commit: [SPARK-15030][ML][SPARKR] Support formula in spark.kmeans in SparkR

Repository: spark
Updated Branches:
  refs/heads/master e5fb78baf -> 19a6d192d


[SPARK-15030][ML][SPARKR] Support formula in spark.kmeans in SparkR

## What changes were proposed in this pull request?
* ```RFormula``` supports an empty response variable, as in ```~ x + y```.
* Support formula in ```spark.kmeans``` in SparkR.
* Fix some outdated docs for SparkR.

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #12813 from yanboliang/spark-15030.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/19a6d192
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/19a6d192
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/19a6d192

Branch: refs/heads/master
Commit: 19a6d192d53ce6dffe998ce110adab1f2efcb23e
Parents: e5fb78b
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sat Apr 30 08:37:56 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sat Apr 30 08:37:56 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/generics.R                              |  2 +-
 R/pkg/R/mllib.R                                 | 53 +++++++++++---------
 R/pkg/inst/tests/testthat/test_mllib.R          | 12 ++---
 .../org/apache/spark/ml/feature/RFormula.scala  |  4 +-
 .../spark/ml/feature/RFormulaParser.scala       |  9 +++-
 .../org/apache/spark/ml/r/KMeansWrapper.scala   | 32 ++++++------
 .../scala/org/apache/spark/ml/r/RWrappers.scala |  4 +-
 .../spark/mllib/stat/JavaStatisticsSuite.java   |  2 +-
 .../apache/spark/ml/feature/RFormulaSuite.scala | 19 +++++++
 9 files changed, 87 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ab6995b..f936ea6 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1199,7 +1199,7 @@ setGeneric("rbind", signature = "...")
 
 #' @rdname spark.kmeans
 #' @export
-setGeneric("spark.kmeans", function(data, k, ...) { standardGeneric("spark.kmeans") })
+setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") })
 
 #' @rdname fitted
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index aee74a9..f466811 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -125,7 +125,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
 
 #' Get the summary of a generalized linear model
 #'
-#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#' Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().
 #'
 #' @param object A fitted generalized linear model
 #' @return coefficients the model's coefficients, intercept
@@ -199,7 +199,8 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
 
 #' Make predictions from a generalized linear model
 #'
-#' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict().
+#' Makes predictions from a generalized linear model produced by glm() or spark.glm(),
+#' similarly to R's predict().
 #'
 #' @param object A fitted generalized linear model
 #' @param newData SparkDataFrame for testing
@@ -219,7 +220,8 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
 
 #' Make predictions from a naive Bayes model
 #'
-#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
+#' Makes predictions from a model produced by spark.naiveBayes(),
+#' similarly to R package e1071's predict.
 #'
 #' @param object A fitted naive Bayes model
 #' @param newData SparkDataFrame for testing
@@ -239,7 +241,8 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
 
 #' Get the summary of a naive Bayes model
 #'
-#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
+#' Returns the summary of a naive Bayes model produced by spark.naiveBayes(),
+#' similarly to R's summary().
 #'
 #' @param object A fitted MLlib model
 #' @return a list containing 'apriori', the label distribution, and 'tables', conditional
@@ -271,22 +274,25 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #' Fit a k-means model, similarly to R's kmeans().
 #'
 #' @param data SparkDataFrame for training
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#'                Note that the response variable of the formula is empty in spark.kmeans.
 #' @param k Number of centers
 #' @param maxIter Maximum iteration number
-#' @param initializationMode Algorithm choosen to fit the model
+#' @param initMode The initialization algorithm chosen to fit the model
 #' @return A fitted k-means model
 #' @rdname spark.kmeans
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- spark.kmeans(data, k = 2, initializationMode="random")
+#' model <- spark.kmeans(data, ~ ., k=2, initMode="random")
 #' }
-setMethod("spark.kmeans", signature(data = "SparkDataFrame"),
-          function(data, k, maxIter = 10, initializationMode = c("random", "k-means||")) {
-            columnNames <- as.array(colnames(data))
-            initializationMode <- match.arg(initializationMode)
-            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf,
-                                k, maxIter, initializationMode, columnNames)
+setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, k, maxIter = 10, initMode = c("random", "k-means||")) {
+            formula <- paste(deparse(formula), collapse = "")
+            initMode <- match.arg(initMode)
+            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
+                                as.integer(k), as.integer(maxIter), initMode)
             return(new("KMeansModel", jobj = jobj))
          })
 
@@ -301,7 +307,7 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- spark.kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, ~ ., 2)
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
 #'}
@@ -319,7 +325,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
 
 #' Get the summary of a k-means model
 #'
-#' Returns the summary of a k-means model produced by kmeans(),
+#' Returns the summary of a k-means model produced by spark.kmeans(),
 #' similarly to R's summary().
 #'
 #' @param object a fitted k-means model
@@ -328,7 +334,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- spark.kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, ~ ., 2)
 #' summary(model)
 #' }
 setMethod("summary", signature(object = "KMeansModel"),
@@ -353,7 +359,7 @@ setMethod("summary", signature(object = "KMeansModel"),
 
 #' Make predictions from a k-means model
 #'
-#' Make predictions from a model produced by kmeans().
+#' Make predictions from a model produced by spark.kmeans().
 #'
 #' @param object A fitted k-means model
 #' @param newData SparkDataFrame for testing
@@ -362,7 +368,7 @@ setMethod("summary", signature(object = "KMeansModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- spark.kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, ~ ., 2)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #' }
@@ -376,7 +382,7 @@ setMethod("predict", signature(object = "KMeansModel"),
 #' Fit a Bernoulli naive Bayes model on a Spark DataFrame (only categorical data is supported).
 #'
 #' @param data SparkDataFrame for training
-#' @param object A symbolic description of the model to be fitted. Currently only a few formula
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'               operators are supported, including '~', '.', ':', '+', and '-'.
 #' @param laplace Smoothing parameter
 #' @return a fitted naive Bayes model
@@ -409,7 +415,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, infert)
-#' model <- spark.naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
 #' path <- "path/to/model"
 #' write.ml(model, path)
 #' }
@@ -484,7 +490,7 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- spark.kmeans(x, k = 2, initializationMode="random")
+#' model <- spark.kmeans(trainingData, ~ ., k = 2)
 #' path <- "path/to/model"
 #' write.ml(model, path)
 #' }
@@ -540,7 +546,7 @@ read.ml <- function(path) {
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, ovarian)
-#' model <- spark.survreg(Surv(df, futime, fustat) ~ ecog_ps + rx)
+#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)
 #' }
 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, ...) {
@@ -553,7 +559,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
 
 #' Get the summary of an AFT survival regression model
 #'
-#' Returns the summary of an AFT survival regression model produced by survreg(),
+#' Returns the summary of an AFT survival regression model produced by spark.survreg(),
 #' similarly to R's summary().
 #'
 #' @param object a fitted AFT survival regression model
@@ -578,7 +584,8 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
 
 #' Make predictions from an AFT survival regression model
 #'
-#' Make predictions from a model produced by survreg(), similarly to R package survival's predict.
+#' Make predictions from a model produced by spark.survreg(),
+#' similarly to R package survival's predict.
 #'
 #' @param object A fitted AFT survival regression model
 #' @param newData SparkDataFrame for testing

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index dcd0296..37d87aa 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -132,7 +132,7 @@ test_that("spark.glm save/load", {
   m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
   s <- summary(m)
 
-  modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+  modelPath <- tempfile(pattern = "spark-glm", fileext = ".tmp")
   write.ml(m, modelPath)
   expect_error(write.ml(m, modelPath))
   write.ml(m, modelPath, overwrite = TRUE)
@@ -291,7 +291,7 @@ test_that("spark.kmeans", {
 
   take(training, 1)
 
-  model <- spark.kmeans(data = training, k = 2)
+  model <- spark.kmeans(data = training, ~ ., k = 2)
   sample <- take(select(predict(model, training), "prediction"), 1)
   expect_equal(typeof(sample$prediction), "integer")
   expect_equal(sample$prediction, 1)
@@ -310,7 +310,7 @@ test_that("spark.kmeans", {
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
 
   # Test model save/load
-  modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+  modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp")
   write.ml(model, modelPath)
   expect_error(write.ml(model, modelPath))
   write.ml(model, modelPath, overwrite = TRUE)
@@ -324,7 +324,7 @@ test_that("spark.kmeans", {
   unlink(modelPath)
 })
 
-test_that("naiveBayes", {
+test_that("spark.naiveBayes", {
   # R code to reproduce the result.
   # We do not support instance weights yet. So we ignore the frequencies.
   #
@@ -377,7 +377,7 @@ test_that("naiveBayes", {
                                "Yes", "Yes", "No", "No"))
 
   # Test model save/load
-  modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp")
+  modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp")
   write.ml(m, modelPath)
   expect_error(write.ml(m, modelPath))
   write.ml(m, modelPath, overwrite = TRUE)
@@ -434,7 +434,7 @@ test_that("spark.survreg", {
                2.390146, 2.891269, 2.891269), tolerance = 1e-4)
 
   # Test model save/load
-  modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
+  modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp")
   write.ml(model, modelPath)
   expect_error(write.ml(model, modelPath))
   write.ml(model, modelPath, overwrite = TRUE)

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 3ac6c77..5219680 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -214,7 +214,7 @@ class RFormulaModel private[feature](
   override def transformSchema(schema: StructType): StructType = {
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
-    if (hasLabelCol(withFeatures)) {
+    if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
       withFeatures
     } else if (schema.exists(_.name == resolvedFormula.label)) {
       val nullable = schema(resolvedFormula.label).dataType match {
@@ -236,7 +236,7 @@ class RFormulaModel private[feature](
 
   private def transformLabel(dataset: Dataset[_]): DataFrame = {
     val labelName = resolvedFormula.label
-    if (hasLabelCol(dataset.schema)) {
+    if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
       dataset.toDF
     } else if (dataset.schema.exists(_.name == labelName)) {
       dataset.schema(labelName).dataType match {

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 4079b38..cf52710 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -63,6 +63,9 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
     ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
   }
 
+  /** Whether this formula specifies fitting with a response variable. */
+  def hasLabel: Boolean = label.value.nonEmpty
+
   /** Whether this formula specifies fitting with an intercept term. */
   def hasIntercept: Boolean = {
     var intercept = true
@@ -159,6 +162,10 @@ private[ml] object RFormulaParser extends RegexParsers {
   private val columnRef: Parser[ColumnRef] =
     "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
 
+  private val empty: Parser[ColumnRef] = "" ^^ { case a => ColumnRef("") }
+
+  private val label: Parser[ColumnRef] = columnRef | empty
+
   private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }
 
   private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":")
@@ -174,7 +181,7 @@ private[ml] object RFormulaParser extends RegexParsers {
   }
 
   private val formula: Parser[ParsedRFormula] =
-    (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+    (label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
 
   def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
     case Success(result, _) => result

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index f67760d..4d4c303 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
@@ -65,28 +65,32 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
 
   def fit(
       data: DataFrame,
-      k: Double,
-      maxIter: Double,
-      initMode: String,
-      columns: Array[String]): KMeansWrapper = {
+      formula: String,
+      k: Int,
+      maxIter: Int,
+      initMode: String): KMeansWrapper = {
+
+    val rFormulaModel = new RFormula()
+      .setFormula(formula)
+      .setFeaturesCol("features")
+      .fit(data)
 
-    val assembler = new VectorAssembler()
-      .setInputCols(columns)
-      .setOutputCol("features")
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
 
     val kMeans = new KMeans()
-      .setK(k.toInt)
-      .setMaxIter(maxIter.toInt)
+      .setK(k)
+      .setMaxIter(maxIter)
       .setInitMode(initMode)
 
     val pipeline = new Pipeline()
-      .setStages(Array(assembler, kMeans))
+      .setStages(Array(rFormulaModel, kMeans))
       .fit(data)
 
     val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
-    val attrs = AttributeGroup.fromStructField(
-      kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
-    val features: Array[String] = attrs.attributes.get.map(_.name.get)
     val size: Array[Long] = kMeansModel.summary.clusterSizes
 
     new KMeansWrapper(pipeline, features, size)

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 9c07579..568c160 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.ml.util.MLReader
 
 /**
- * This is the Scala stub of SparkR ml.load. It will dispatch the call to corresponding
+ * This is the Scala stub of SparkR read.ml. It will dispatch the call to corresponding
  * model wrapper loading function according the class name extracted from rMetadata of the path.
  */
 private[r] object RWrappers extends MLReader[Object] {
@@ -45,7 +45,7 @@ private[r] object RWrappers extends MLReader[Object] {
       case "org.apache.spark.ml.r.KMeansWrapper" =>
         KMeansWrapper.load(path)
       case _ =>
-        throw new SparkException(s"SparkR ml.load does not support load $className")
+        throw new SparkException(s"SparkR read.ml does not support load $className")
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 66b2cea..5f1d598 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -72,7 +72,7 @@ public class JavaStatisticsSuite implements Serializable {
     Double corr1 = Statistics.corr(x, y);
     Double corr2 = Statistics.corr(x, y, "pearson");
     // Check default method
-    assertEquals(corr1, corr2);
+    assertEquals(corr1, corr2, 1e-5);
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/19a6d192/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index e1b269b..f847695 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
 
 class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
@@ -89,6 +90,24 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(resultSchema.toString == model.transform(original).schema.toString)
   }
 
+  test("allow empty label") {
+    val original = sqlContext.createDataFrame(
+      Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
+    ).toDF("id", "a", "b")
+    val formula = new RFormula().setFormula("~ a + b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val resultSchema = model.transformSchema(original.schema)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
+        (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
+        (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)))
+      ).toDF("id", "a", "b", "features")
+    assert(result.schema.toString == resultSchema.toString)
+    assert(result.collect() === expected.collect())
+  }
+
   test("encodes string terms") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original = sqlContext.createDataFrame(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org