You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/07/02 08:17:09 UTC

spark git commit: [SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to handle invalid data

Repository: spark
Updated Branches:
  refs/heads/master c605fee01 -> c19680be1


[SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to handle invalid data

## What changes were proposed in this pull request?
This PR maintains API parity with the changes made in SPARK-17498 by adding PySpark support for the
new 'keep' option of StringIndexer, which handles unseen labels or NULL values.

Note: This is an updated version of #17237; the primary author of this PR is VinceShieh.
## How was this patch tested?
Unit tests.

Author: VinceShieh <vi...@intel.com>
Author: Yanbo Liang <yb...@gmail.com>

Closes #18453 from yanboliang/spark-19852.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c19680be
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c19680be
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c19680be

Branch: refs/heads/master
Commit: c19680be1c532dded1e70edce7a981ba28af09ad
Parents: c605fee
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sun Jul 2 16:17:03 2017 +0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sun Jul 2 16:17:03 2017 +0800

----------------------------------------------------------------------
 python/pyspark/ml/feature.py |  6 ++++++
 python/pyspark/ml/tests.py   | 21 +++++++++++++++++++++
 2 files changed, 27 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c19680be/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 77de1cc..25ad06f 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                             "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
                             typeConverter=TypeConverters.toString)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+                          "labels or NULL values). Options are 'skip' (filter out rows with " +
+                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "in a special additional bucket, at index numLabels).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
                  stringOrderType="frequencyDesc"):

http://git-wip-us.apache.org/repos/asf/spark/blob/c19680be/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 17a3947..ffb8b0a 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -551,6 +551,27 @@ class FeatureTests(SparkSessionTestCase):
         for i in range(0, len(expected)):
             self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
 
+    def test_string_indexer_handle_invalid(self):
+        df = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "d"),
+            (2, None)], ["id", "label"])
+
+        si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
+                            stringOrderType="alphabetAsc")
+        model1 = si1.fit(df)
+        td1 = model1.transform(df)
+        actual1 = td1.select("id", "indexed").collect()
+        expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
+        self.assertEqual(actual1, expected1)
+
+        si2 = si1.setHandleInvalid("skip")
+        model2 = si2.fit(df)
+        td2 = model2.transform(df)
+        actual2 = td2.select("id", "indexed").collect()
+        expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
+        self.assertEqual(actual2, expected2)
+
 
 class HasInducedError(Params):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org