You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by fe...@apache.org on 2017/07/08 06:51:35 UTC

spark git commit: [SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer

Repository: spark
Updated Branches:
  refs/heads/master d0bfc6733 -> a7b46c627


[SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer

## What changes were proposed in this pull request?

For the randomForest classifier, if the test data contains unseen labels, prediction throws an error. The StringIndexer already has the handleInvalid logic. This patch adds a new method to set the underlying StringIndexer's handleInvalid logic.

This patch should also apply to other classifiers. This PR focuses on the main logic and the randomForest classifier. I will do a follow-up PR for the other classifiers.

## How was this patch tested?

Add a new unit test based on the error case in the JIRA.

Author: wangmiao1981 <wm...@hotmail.com>

Closes #18496 from wangmiao1981/handle.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7b46c62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7b46c62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7b46c62

Branch: refs/heads/master
Commit: a7b46c627b5d2461257f337139a29f23350e0c77
Parents: d0bfc67
Author: wangmiao1981 <wm...@hotmail.com>
Authored: Fri Jul 7 23:51:32 2017 -0700
Committer: Felix Cheung <fe...@apache.org>
Committed: Fri Jul 7 23:51:32 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib_tree.R                            | 11 +++++++--
 R/pkg/tests/fulltests/test_mllib_tree.R         | 17 +++++++++++++
 .../org/apache/spark/ml/feature/RFormula.scala  | 25 ++++++++++++++++++++
 .../r/RandomForestClassificationWrapper.scala   |  4 +++-
 .../spark/ml/feature/StringIndexerSuite.scala   |  2 +-
 5 files changed, 55 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 2f1220a..75b1a74 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
+#'        Supported options: "skip" (filter out rows with invalid data),
+#'                           "error" (throw an error), "keep" (put invalid data in a special additional
+#'                           bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                    maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                    featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                      new("RandomForestRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(impurity)) impurity <- "gini"
                      impurity <- match.arg(impurity, c("gini", "entropy"))
                      jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                          as.numeric(minInfoGain), as.integer(checkpointInterval),
                                          as.character(featureSubsetStrategy), seed,
                                          as.numeric(subsamplingRate),
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("RandomForestClassificationModel", jobj = jobj)
                    }
             )

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/R/pkg/tests/fulltests/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 9b3fc8d..66a0693 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                          maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                             maxDepth = 10, maxBins = 10, numTrees = 10,
+                             handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 4b44878..61aa646 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("1.5.0")
   def getFormula: String = $(formula)
 
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+   * bucket, at index numLabels).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
             .setInputCol(term)
             .setOutputCol(indexCol)
             .setStringOrderType($(stringIndexerOrderType))
+            .setHandleInvalid($(handleInvalid))
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 8a83d4e..132345f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): RandomForestClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 806a927..027b1fb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org