Posted to commits@spark.apache.org by yl...@apache.org on 2017/07/15 12:56:45 UTC

spark git commit: [SPARK-20307][ML][SPARKR][FOLLOW-UP] RFormula should handle invalid values for both the features and label columns.

Repository: spark
Updated Branches:
  refs/heads/master 74ac1fb08 -> 69e5282d3


[SPARK-20307][ML][SPARKR][FOLLOW-UP] RFormula should handle invalid values for both the features and label columns.

## What changes were proposed in this pull request?
```RFormula``` should handle invalid values for both the features and label columns.
#18496 only handled invalid values in the features column; this PR adds handling of invalid values in the label column, along with test cases.
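
For a concrete picture of the new behavior, here is a minimal sketch mirroring the added test cases (the data, column names, and local master are illustrative; `setHandleInvalid` on ```RFormula``` is the 2.3.0 API shown in this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.RFormula

val spark = SparkSession.builder().master("local[2]").appName("rformula-sketch").getOrCreate()
import spark.implicits._

val train = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
val test = Seq((4, "bar", "zy")).toDF("id", "a", "b") // label value "zy" unseen at fit time

val formula = new RFormula().setFormula("b ~ a + id")

// The default "error" would throw a SparkException on the unseen label "zy";
// "keep" instead indexes it into the extra bucket at index numLabels (2.0 here).
formula.setHandleInvalid("keep").fit(train).transform(test).select("label").show()
```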

## How was this patch tested?
Add test cases.

Author: Yanbo Liang <yb...@gmail.com>

Closes #18613 from yanboliang/spark-20307.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69e5282d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69e5282d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69e5282d

Branch: refs/heads/master
Commit: 69e5282d3c2998611680d3e10f2830d4e9c5f750
Parents: 74ac1fb
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sat Jul 15 20:56:38 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sat Jul 15 20:56:38 2017 +0800

----------------------------------------------------------------------
 R/pkg/tests/fulltests/test_mllib_tree.R         |  2 +-
 .../org/apache/spark/ml/feature/RFormula.scala  |  9 ++--
 .../apache/spark/ml/feature/RFormulaSuite.scala | 49 +++++++++++++++++++-
 python/pyspark/ml/feature.py                    |  5 +-
 4 files changed, 57 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/R/pkg/tests/fulltests/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 66a0693..e31a65f 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -225,7 +225,7 @@ test_that("spark.randomForest", {
   expect_error(collect(predictions))
   model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                              maxDepth = 10, maxBins = 10, numTrees = 10,
-                             handleInvalid = "skip")
+                             handleInvalid = "keep")
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "character")
 

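A hedged note on why this test flips from "skip" to "keep": now that handleInvalid also governs the label column, "skip" filters out rows whose label value was unseen during fit, which can leave nothing to assert on, while "keep" retains such rows in the extra bucket. A minimal Scala sketch of the contrast (the data and column names are illustrative, echoing the R test's `clicked` label):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.RFormula

val spark = SparkSession.builder().master("local[2]").appName("skip-vs-keep").getOrCreate()
import spark.implicits._

val train = Seq(("yes", 1.0), ("no", 2.0)).toDF("clicked", "x")
val test = Seq(("maybe", 3.0)).toDF("clicked", "x") // label "maybe" unseen at fit time

val rf = new RFormula().setFormula("clicked ~ x")
rf.setHandleInvalid("skip").fit(train).transform(test).count() // 0: the row is filtered out
rf.setHandleInvalid("keep").fit(train).transform(test).count() // 1: kept in the extra bucket
```
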
http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index bb7acaf..c224454 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -134,16 +134,16 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   def getFormula: String = $(formula)
 
   /**
-   * Param for how to handle invalid data (unseen labels or NULL values).
-   * Options are 'skip' (filter out rows with invalid data),
+   * Param for how to handle invalid data (unseen or NULL values) in features and label column
+   * of string type. Options are 'skip' (filter out rows with invalid data),
    * 'error' (throw an error), or 'keep' (put invalid data in a special additional
    * bucket, at index numLabels).
    * Default: "error"
    * @group param
    */
   @Since("2.3.0")
-  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
-    "How to handle invalid data (unseen labels or NULL values). " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to " +
+    "handle invalid data (unseen or NULL values) in features and label column of string type. " +
     "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
     "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -265,6 +265,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
       encoderStages += new StringIndexer()
         .setInputCol(resolvedFormula.label)
         .setOutputCol($(labelCol))
+        .setHandleInvalid($(handleInvalid))
     }
 
     val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
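
In effect, the one-line addition above makes the label path symmetric with the feature path. A rough sketch of the equivalent hand-built stage, with illustrative placeholder values standing in for the resolved formula and params:

```scala
import org.apache.spark.ml.feature.StringIndexer

// What the patched RFormula now assembles internally for a string-typed label:
// the user's handleInvalid setting is forwarded instead of leaving
// StringIndexer at its default of "error".
val labelIndexer = new StringIndexer()
  .setInputCol("b")         // stands in for resolvedFormula.label
  .setOutputCol("label")    // stands in for $(labelCol)
  .setHandleInvalid("keep") // stands in for $(handleInvalid)
```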

http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 23570d6..5d09c90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
@@ -501,4 +501,51 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
     }
   }
+
+  test("handle unseen features or labels") {
+    val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
+    val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")
+
+    // Handle unseen features.
+    val formula1 = new RFormula().setFormula("id ~ a + b")
+    intercept[SparkException] {
+      formula1.fit(df1).transform(df2).collect()
+    }
+    val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected1 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 1.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected2 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0, 0.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 1.0, 0.0), 2.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result1.collect() === expected1.collect())
+    assert(result2.collect() === expected2.collect())
+
+    // Handle unseen labels.
+    val formula2 = new RFormula().setFormula("b ~ a + id")
+    intercept[SparkException] {
+      formula2.fit(df1).transform(df2).collect()
+    }
+    val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected3 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected4 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result3.collect() === expected3.collect())
+    assert(result4.collect() === expected4.collect())
+  }
 }
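
A hedged reading of the expected label values in the new test: in df1 the label column b holds "zq" twice and "zz" once, so StringIndexer assigns zq → 0.0 and zz → 1.0 by descending frequency; under "keep" the unseen "zy" maps to numLabels = 2.0 (the last row of expected4), while "skip" simply drops that row, which is why expected3 has only two rows.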

http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 7eb1b9f..54b4026 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2107,8 +2107,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                             typeConverter=TypeConverters.toString)
 
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
-                          "labels or NULL values). Options are 'skip' (filter out rows with " +
-                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "or NULL values) in features and label column of string type. " +
+                          "Options are 'skip' (filter out rows with invalid data), " +
+                          "error (throw an error), or 'keep' (put invalid data " +
                           "in a special additional bucket, at index numLabels).",
                           typeConverter=TypeConverters.toString)
 

