You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/07/15 12:56:45 UTC
spark git commit: [SPARK-20307][ML][SPARKR][FOLLOW-UP] RFormula
should handle invalid for both features and label column.
Repository: spark
Updated Branches:
refs/heads/master 74ac1fb08 -> 69e5282d3
[SPARK-20307][ML][SPARKR][FOLLOW-UP] RFormula should handle invalid for both features and label column.
## What changes were proposed in this pull request?
`RFormula` should handle invalid values in both the features and label columns.
#18496 only handled invalid values in the features column. This PR adds handling of invalid values in the label column, along with test cases.
## How was this patch tested?
Added test cases to `RFormulaSuite`.
Author: Yanbo Liang <yb...@gmail.com>
Closes #18613 from yanboliang/spark-20307.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69e5282d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69e5282d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69e5282d
Branch: refs/heads/master
Commit: 69e5282d3c2998611680d3e10f2830d4e9c5f750
Parents: 74ac1fb
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sat Jul 15 20:56:38 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sat Jul 15 20:56:38 2017 +0800
----------------------------------------------------------------------
R/pkg/tests/fulltests/test_mllib_tree.R | 2 +-
.../org/apache/spark/ml/feature/RFormula.scala | 9 ++--
.../apache/spark/ml/feature/RFormulaSuite.scala | 49 +++++++++++++++++++-
python/pyspark/ml/feature.py | 5 +-
4 files changed, 57 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/R/pkg/tests/fulltests/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 66a0693..e31a65f 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -225,7 +225,7 @@ test_that("spark.randomForest", {
expect_error(collect(predictions))
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
maxDepth = 10, maxBins = 10, numTrees = 10,
- handleInvalid = "skip")
+ handleInvalid = "keep")
predictions <- predict(model, testdf)
expect_equal(class(collect(predictions)$clicked[1]), "character")
http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index bb7acaf..c224454 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -134,16 +134,16 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
def getFormula: String = $(formula)
/**
- * Param for how to handle invalid data (unseen labels or NULL values).
- * Options are 'skip' (filter out rows with invalid data),
+ * Param for how to handle invalid data (unseen or NULL values) in features and label column
+ * of string type. Options are 'skip' (filter out rows with invalid data),
* 'error' (throw an error), or 'keep' (put invalid data in a special additional
* bucket, at index numLabels).
* Default: "error"
* @group param
*/
@Since("2.3.0")
- override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
- "How to handle invalid data (unseen labels or NULL values). " +
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to " +
+ "handle invalid data (unseen or NULL values) in features and label column of string type. " +
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -265,6 +265,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
+ .setHandleInvalid($(handleInvalid))
}
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 23570d6..5d09c90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
@@ -501,4 +501,51 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
}
}
+
+ test("handle unseen features or labels") {
+ val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
+ val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")
+
+ // Handle unseen features.
+ val formula1 = new RFormula().setFormula("id ~ a + b")
+ intercept[SparkException] {
+ formula1.fit(df1).transform(df2).collect()
+ }
+ val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
+ val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
+
+ val expected1 = Seq(
+ (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
+ (2, "bar", "zq", Vectors.dense(1.0, 1.0), 2.0)
+ ).toDF("id", "a", "b", "features", "label")
+ val expected2 = Seq(
+ (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0, 0.0), 1.0),
+ (2, "bar", "zq", Vectors.dense(1.0, 0.0, 1.0, 0.0), 2.0),
+ (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
+ ).toDF("id", "a", "b", "features", "label")
+
+ assert(result1.collect() === expected1.collect())
+ assert(result2.collect() === expected2.collect())
+
+ // Handle unseen labels.
+ val formula2 = new RFormula().setFormula("b ~ a + id")
+ intercept[SparkException] {
+ formula2.fit(df1).transform(df2).collect()
+ }
+ val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
+ val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
+
+ val expected3 = Seq(
+ (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
+ (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
+ ).toDF("id", "a", "b", "features", "label")
+ val expected4 = Seq(
+ (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
+ (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
+ (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
+ ).toDF("id", "a", "b", "features", "label")
+
+ assert(result3.collect() === expected3.collect())
+ assert(result4.collect() === expected4.collect())
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/69e5282d/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 7eb1b9f..54b4026 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2107,8 +2107,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
typeConverter=TypeConverters.toString)
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
- "labels or NULL values). Options are 'skip' (filter out rows with " +
- "invalid data), error (throw an error), or 'keep' (put invalid data " +
+ "or NULL values) in features and label column of string type. " +
+ "Options are 'skip' (filter out rows with invalid data), " +
+ "error (throw an error), or 'keep' (put invalid data " +
"in a special additional bucket, at index numLabels).",
typeConverter=TypeConverters.toString)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org