Posted to commits@spark.apache.org by fe...@apache.org on 2017/08/01 03:37:10 UTC

spark git commit: [SPARK-21381][SPARKR] SparkR: pass on setHandleInvalid for classification algorithms

Repository: spark
Updated Branches:
  refs/heads/master 6b186c9d6 -> 9570e81aa


[SPARK-21381][SPARKR] SparkR: pass on setHandleInvalid for classification algorithms

## What changes were proposed in this pull request?

SPARK-20307 added the handleInvalid option to RFormula for the tree-based classification algorithms. This PR passes the parameter through for the other classification algorithms in SparkR (spark.svmLinear, spark.logit, spark.mlp, spark.naiveBayes, spark.gbt, and spark.decisionTree).

This is a follow-up PR to SPARK-20307.
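For illustration, a minimal SparkR sketch of how the new parameter is used (hypothetical toy data; assumes an active SparkR session):

```r
library(SparkR)
sparkR.session()  # assumes a local Spark installation

# Hypothetical toy data: a numeric label and a string feature column.
train <- createDataFrame(data.frame(clicked = c(0, 1, 0, 1),
                                    someString = c("this", "that", "this", "that"),
                                    stringsAsFactors = FALSE))
test <- createDataFrame(data.frame(clicked = c(0),
                                   someString = c("other"),  # value unseen at training time
                                   stringsAsFactors = FALSE))

# With the default handleInvalid = "error", collecting predictions for the
# unseen string value fails; "keep" routes it to an extra index instead,
# and "skip" filters such rows out.
model <- spark.logit(train, clicked ~ ., regParam = 0.5, handleInvalid = "keep")
head(predict(model, test))
```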

## How was this patch tested?

New unit tests were added.

Author: wangmiao1981 <wm...@hotmail.com>

Closes #18605 from wangmiao1981/class.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9570e81a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9570e81a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9570e81a

Branch: refs/heads/master
Commit: 9570e81aa949cddb30a0e94c92093cd16e34326a
Parents: 6b186c9
Author: wangmiao1981 <wm...@hotmail.com>
Authored: Mon Jul 31 20:37:06 2017 -0700
Committer: Felix Cheung <fe...@apache.org>
Committed: Mon Jul 31 20:37:06 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib_classification.R                  | 49 ++++++++++++++---
 R/pkg/R/mllib_tree.R                            | 33 ++++++++---
 .../tests/fulltests/test_mllib_classification.R | 58 ++++++++++++++++++++
 R/pkg/tests/fulltests/test_mllib_tree.R         | 30 ++++++++++
 .../r/DecisionTreeClassificationWrapper.scala   |  4 +-
 .../spark/ml/r/GBTClassificationWrapper.scala   |  4 +-
 .../apache/spark/ml/r/LinearSVCWrapper.scala    |  4 +-
 .../spark/ml/r/LogisticRegressionWrapper.scala  |  4 +-
 .../MultilayerPerceptronClassifierWrapper.scala |  6 +-
 .../apache/spark/ml/r/NaiveBayesWrapper.scala   |  7 ++-
 10 files changed, 175 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/R/mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 82d2428..15af829 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -69,6 +69,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
 #'                         or the number of partitions are large, this param could be adjusted to a larger size.
 #'                         This is an expert parameter. Default value should be good for most cases.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.svmLinear} returns a fitted linear SVM model.
 #' @rdname spark.svmLinear
@@ -98,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @note spark.svmLinear since 2.2.0
 setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
-                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
+                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
 
             if (!is.null(weightCol) && weightCol == "") {
@@ -107,10 +113,12 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
               weightCol <- as.character(weightCol)
             }
 
+            handleInvalid <- match.arg(handleInvalid)
+
             jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
                                 as.numeric(tol), as.logical(standardization), as.numeric(threshold),
-                                weightCol, as.integer(aggregationDepth))
+                                weightCol, as.integer(aggregationDepth), handleInvalid)
             new("LinearSVCModel", jobj = jobj)
           })
 
@@ -218,6 +226,11 @@ function(object, path, overwrite = FALSE) {
 #' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization.
 #'                                The bound vector size must be equal to 1 for binomial regression, or the number
 #'                                of classes for multinomial regression.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -257,7 +270,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                    tol = 1E-6, family = "auto", standardization = TRUE,
                    thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
                    lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL,
-                   lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) {
+                   lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             row <- 0
             col <- 0
@@ -304,6 +318,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
               upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients))
             }
 
+            handleInvalid <- match.arg(handleInvalid)
+
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam),
                                 as.numeric(elasticNetParam), as.integer(maxIter),
@@ -312,7 +328,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 weightCol, as.integer(aggregationDepth),
                                 as.integer(row), as.integer(col),
                                 lowerBoundsOnCoefficients, upperBoundsOnCoefficients,
-                                lowerBoundsOnIntercepts, upperBoundsOnIntercepts)
+                                lowerBoundsOnIntercepts, upperBoundsOnIntercepts,
+                                handleInvalid)
             new("LogisticRegressionModel", jobj = jobj)
           })
 
@@ -394,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
 #' @param stepSize stepSize parameter.
 #' @param seed seed parameter for weights initialization.
 #' @param initialWeights initialWeights parameter for weights initialization, it should be a
-#' numeric vector.
+#'        numeric vector.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
 #' @rdname spark.mlp
@@ -426,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
 #' @note spark.mlp since 2.1.0
 setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
-                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
+                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             if (is.null(layers)) {
               stop ("layers must be a integer vector with length > 1.")
@@ -441,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
             if (!is.null(initialWeights)) {
               initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
             }
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
                                 "fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
                                 as.character(solver), as.integer(maxIter), as.numeric(tol),
-                                as.numeric(stepSize), seed, initialWeights)
+                                as.numeric(stepSize), seed, initialWeights, handleInvalid)
             new("MultilayerPerceptronClassificationModel", jobj = jobj)
           })
 
@@ -514,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
 #' @param formula a symbolic description of the model to be fitted. Currently only a few formula
 #'               operators are supported, including '~', '.', ':', '+', and '-'.
 #' @param smoothing smoothing parameter.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
 #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
 #' @rdname spark.naiveBayes
@@ -543,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
 #' }
 #' @note spark.naiveBayes since 2.0.0
 setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, smoothing = 1.0) {
+          function(data, formula, smoothing = 1.0,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
-            formula, data@sdf, smoothing)
+                                formula, data@sdf, smoothing, handleInvalid)
             new("NaiveBayesModel", jobj = jobj)
           })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 75b1a74..33c4653 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -164,6 +164,11 @@ print.summary.decisionTree <- function(x) {
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.gbt,SparkDataFrame,formula-method
 #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
@@ -205,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, type = c("regression", "classification"),
                    maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
                    seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
-                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -225,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
                      new("GBTRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(lossType)) lossType <- "logistic"
                      lossType <- match.arg(lossType, "logistic")
                      jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
@@ -233,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
                                          as.numeric(stepSize), as.integer(minInstancesPerNode),
                                          as.numeric(minInfoGain), as.integer(checkpointInterval),
                                          lossType, seed, as.numeric(subsamplingRate),
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("GBTClassificationModel", jobj = jobj)
                    }
             )
@@ -374,10 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
-#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
-#'        Supported options: "skip" (filter out rows with invalid data),
-#'                           "error" (throw an error), "keep" (put invalid data in a special additional
-#'                           bucket, at index numLabels). Default is "error".
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -583,6 +592,11 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in a special additional
+#'                                         bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.decisionTree,SparkDataFrame,formula-method
 #' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
@@ -617,7 +631,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
           function(data, formula, type = c("regression", "classification"),
                    maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -636,6 +651,7 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
                      new("DecisionTreeRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(impurity)) impurity <- "gini"
                      impurity <- match.arg(impurity, c("gini", "entropy"))
                      jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
@@ -643,7 +659,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
                                          as.integer(maxBins), impurity,
                                          as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                                          as.integer(checkpointInterval), seed,
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("DecisionTreeClassificationModel", jobj = jobj)
                    }
             )

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/tests/fulltests/test_mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R
index 3d75f4c..a4d0397 100644
--- a/R/pkg/tests/fulltests/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -70,6 +70,20 @@ test_that("spark.svmLinear", {
   prediction <- collect(select(predict(model, df), "prediction"))
   expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+  someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                            stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "list")
+
 })
 
 test_that("spark.logit", {
@@ -263,6 +277,21 @@ test_that("spark.logit", {
   virginicaCoefs <- summary$coefficients[, "virginica"]
   expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
   expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+  someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                            stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.logit(traindf, clicked ~ ., regParam = 0.5)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
 })
 
 test_that("spark.mlp", {
@@ -344,6 +373,21 @@ test_that("spark.mlp", {
   expect_equal(summary$numOfOutputs, 3)
   expect_equal(summary$layers, c(4, 3))
   expect_equal(length(summary$weights), 15)
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+  someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                            stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3))
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "list")
+
 })
 
 test_that("spark.naiveBayes", {
@@ -427,6 +471,20 @@ test_that("spark.naiveBayes", {
   expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
   expect_equal(sum(s$apriori), 1)
   expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+  someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                            stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/tests/fulltests/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index e31a65f..799f944 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -109,6 +109,20 @@ test_that("spark.gbt", {
     model <- spark.gbt(data, label ~ features, "classification")
     expect_equal(summary(model)$numFeatures, 692)
   }
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+  someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                            stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification")
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 test_that("spark.randomForest", {
@@ -328,6 +342,22 @@ test_that("spark.decisionTree", {
     model <- spark.decisionTree(data, label ~ features, "classification")
     expect_equal(summary(model)$numFeatures, 4)
   }
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+  someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                            stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 5, maxBins = 16)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 5, maxBins = 16, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
index 7f59825..a90cae5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
@@ -73,11 +73,13 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
       checkpointInterval: Int,
       seed: String,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): DecisionTreeClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index c07eadb..ecaeec5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): GBTClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): GBTClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
index 0dd1f11..7a22a71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -79,12 +79,14 @@ private[r] object LinearSVCWrapper
       standardization: Boolean,
       threshold: Double,
       weightCol: String,
-      aggregationDepth: Int
+      aggregationDepth: Int,
+      handleInvalid: String
       ): LinearSVCWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index b96481a..18acf7d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -103,12 +103,14 @@ private[r] object LogisticRegressionWrapper
       lowerBoundsOnCoefficients: Array[Double],
       upperBoundsOnCoefficients: Array[Double],
       lowerBoundsOnIntercepts: Array[Double],
-      upperBoundsOnIntercepts: Array[Double]
+      upperBoundsOnIntercepts: Array[Double],
+      handleInvalid: String
       ): LogisticRegressionWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
index 48c8774..62f6421 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -62,7 +62,7 @@ private[r] object MultilayerPerceptronClassifierWrapper
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"
 
-  def fit(
+  def fit(  // scalastyle:ignore
       data: DataFrame,
       formula: String,
       blockSize: Int,
@@ -72,11 +72,13 @@ private[r] object MultilayerPerceptronClassifierWrapper
       tol: Double,
       stepSize: Double,
       seed: String,
-      initialWeights: Array[Double]
+      initialWeights: Array[Double],
+      handleInvalid: String
      ): MultilayerPerceptronClassifierWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema

http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 0afea4b..fbf9f46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -57,10 +57,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"
 
-  def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
+  def fit(
+      formula: String,
+      data: DataFrame,
+      smoothing: Double,
+      handleInvalid: String): NaiveBayesWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema

