Posted to commits@spark.apache.org by me...@apache.org on 2015/09/01 19:49:01 UTC

spark git commit: [SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover

Repository: spark
Updated Branches:
  refs/heads/master 391e6be0a -> e6e483cc4


[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover

Add a Python API for the Stop Words Remover.

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover.
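
For reference, a minimal usage sketch of the new Python API (it mirrors the test added below; `sc` is assumed to be an existing SparkContext):

    from pyspark.sql import SQLContext, Row
    from pyspark.ml.feature import StopWordsRemover

    sqlContext = SQLContext(sc)
    dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])

    remover = StopWordsRemover(inputCol="input", outputCol="output")
    # "a" is in the default English stop word list, so only "panda" survives
    remover.transform(dataset).head().output  # ['panda']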


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e6e483cc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e6e483cc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e6e483cc

Branch: refs/heads/master
Commit: e6e483cc4de740c46398385b03ffe0e662edae39
Parents: 391e6be
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Tue Sep 1 10:48:57 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Sep 1 10:48:57 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/StopWordsRemover.scala     |  6 +-
 .../ml/feature/StopWordsRemoverSuite.scala      |  2 +-
 python/pyspark/ml/feature.py                    | 73 +++++++++++++++++++-
 python/pyspark/ml/tests.py                      | 20 +++++-
 4 files changed, 93 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e6e483cc/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 5d77ea0..7da430c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp
 /**
  * stop words list
  */
-private object StopWords {
+private[spark] object StopWords {
 
   /**
    * Use the same default stopwords list as scikit-learn.
    * The original list can be found from "Glasgow Information Retrieval Group"
    * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
    */
-  val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+  val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
     "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
     "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
     "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+  setDefault(stopWords -> StopWords.English, caseSensitive -> false)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)

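Widening `StopWords` from `private` to `private[spark]` is what lets the Python wrapper read the default list off the JVM object through the py4j gateway, as the feature.py hunk below does. A sketch of that access pattern:

    from pyspark.ml.wrapper import _jvm

    # Reaches the (now private[spark]) Scala object via the py4j gateway;
    # English() is the renamed default English stop word list.
    defaultStopWords = _jvm().org.apache.spark.ml.feature.StopWords.English()
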
http://git-wip-us.apache.org/repos/asf/spark/blob/e6e483cc/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index f01306f..e0d433f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+    val stopWords = StopWords.English ++ Array("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")

http://git-wip-us.apache.org/repos/asf/spark/blob/e6e483cc/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 0626281..d955307 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@ if sys.version > '3':
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
 from pyspark.mllib.common import inherit_doc
 from pyspark.mllib.linalg import _convert_to_vector
 
@@ -30,7 +30,7 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF',
            'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
            'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
            'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
-           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
 
 
 @inherit_doc
@@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel):
     """
 
 
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    .. note:: Experimental
+
+    A feature transformer that filters out stop words from input.
+    Note: null values from the input array are preserved unless null is explicitly added to stopWords.
+    """
+    # a placeholder to make the stopwords show up in generated doc
+    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
+                          "comparison over the stop words")
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+                 caseSensitive=False):
+        """
+        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
+                 caseSensitive=False)
+        """
+        super(StopWordsRemover, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
+                                            self.uid)
+        self.stopWords = Param(self, "stopWords", "The words to be filtered out")
+        self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
+                                   "sensitive comparison over the stop words")
+        stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
+        defaultStopWords = stopWordsObj.English()
+        self._setDefault(stopWords=defaultStopWords)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+                  caseSensitive=False):
+        """
+        setParams(self, inputCol="input", outputCol="output", stopWords=None,\
+                  caseSensitive=false)
+        Sets params for this StopWordRemover.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setStopWords(self, value):
+        """
+        Specify the stopwords to be filtered.
+        """
+        self._paramMap[self.stopWords] = value
+        return self
+
+    def getStopWords(self):
+        """
+        Get the stopwords.
+        """
+        return self.getOrDefault(self.stopWords)
+
+    def setCaseSensitive(self, value):
+        """
+        Set whether to do a case sensitive comparison over the stop words.
+        """
+        self._paramMap[self.caseSensitive] = value
+        return self
+
+    def getCaseSensitive(self):
+        """
+        Get whether to do a case sensitive comparison over the stop words.
+        """
+        return self.getOrDefault(self.caseSensitive)
+
+
 @inherit_doc
 @ignore_unicode_prefix
 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):

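Since every setter returns `self`, configuration calls chain in the usual Params style (a sketch, assuming a DataFrame with an array<string> column named "raw"):

    remover = (StopWordsRemover()
               .setInputCol("raw")
               .setOutputCol("filtered")
               .setStopWords(["panda"])
               .setCaseSensitive(True))
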
http://git-wip-us.apache.org/repos/asf/spark/blob/e6e483cc/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 60e4237..b892318 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -31,7 +31,7 @@ else:
     import unittest
 
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
@@ -258,7 +258,7 @@ class FeatureTests(PySparkTestCase):
     def test_ngram(self):
         sqlContext = SQLContext(self.sc)
         dataset = sqlContext.createDataFrame([
-            ([["a", "b", "c", "d", "e"]])], ["input"])
+            Row(input=["a", "b", "c", "d", "e"])])
         ngram0 = NGram(n=4, inputCol="input", outputCol="output")
         self.assertEqual(ngram0.getN(), 4)
         self.assertEqual(ngram0.getInputCol(), "input")
@@ -266,6 +266,22 @@ class FeatureTests(PySparkTestCase):
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
 
+    def test_stopwordsremover(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+        # Default
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["panda"])
+        # Custom
+        stopwords = ["panda"]
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["a"])
+
 
 class HasInducedError(Params):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org