You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/09/09 07:13:10 UTC

spark git commit: [SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark

Repository: spark
Updated Branches:
  refs/heads/master 0e2f21633 -> 2f6fd5256


[SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark

Adds IndexToString to PySpark.

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #7976 from holdenk/SPARK-9654-add-string-indexer-inverse-in-pyspark.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f6fd525
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f6fd525
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f6fd525

Branch: refs/heads/master
Commit: 2f6fd5256c6650868916a3eefaa0beb091187cbb
Parents: 0e2f216
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Tue Sep 8 22:13:05 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Sep 8 22:13:05 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala |  2 +-
 python/pyspark/ml/feature.py                    | 74 ++++++++++++++++++--
 python/pyspark/ml/wrapper.py                    |  3 +-
 3 files changed, 73 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f6fd525/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 77aeed0..b6482ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -102,7 +102,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
  * [[StringIndexerModel.transform]] would return the input dataset unmodified.
  * This is a temporary fix for the case when target labels do not exist during prediction.
  *
- * @param labels  Ordered list of labels, corresponding to indices to be assigned
+ * @param labels  Ordered list of labels, corresponding to indices to be assigned.
  */
 @Experimental
 class StringIndexerModel (

http://git-wip-us.apache.org/repos/asf/spark/blob/2f6fd525/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index a7c5b2b..8c26cfb 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -27,10 +27,11 @@ from pyspark.mllib.common import inherit_doc
 from pyspark.mllib.linalg import _convert_to_vector
 
 __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel',
-           'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
-           'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
-           'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
-           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
+           'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion',
+           'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
+           'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
+           'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel',
+           'StopWordsRemover']
 
 
 @inherit_doc
@@ -934,6 +935,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
     >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
     ...     key=lambda x: x[0])
     [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
+    >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels())
+    >>> itd = inverter.transform(td)
+    >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
+    ...     key=lambda x: x[0])
+    [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
     """
 
     @keyword_only
@@ -965,6 +971,66 @@ class StringIndexerModel(JavaModel):
 
     Model fitted by StringIndexer.
     """
+    @property
+    def labels(self):
+        """
+        Ordered list of labels, corresponding to indices to be assigned.
+        """
+        return self._java_obj.labels
+
+
+@inherit_doc
+class IndexToString(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    .. note:: Experimental
+
+    A :py:class:`Transformer` that maps a column of string indices back to a new column of
+    corresponding string values, using the labels supplied by the user if provided, or
+    otherwise the ML attributes of the input column.
+    All original columns are kept during transformation.
+    See :py:class:`StringIndexer` for converting strings into indices.
+    """
+
+    # a placeholder to make the labels show up in generated doc
+    labels = Param(Params._dummy(), "labels",
+                   "Optional array of labels to be provided by the user, if not supplied or " +
+                   "empty, column metadata is read for labels")
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, labels=None):
+        """
+        __init__(self, inputCol=None, outputCol=None, labels=None)
+        """
+        super(IndexToString, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
+                                            self.uid)
+        self.labels = Param(self, "labels",
+                            "Optional array of labels to be provided by the user, if not " +
+                            "supplied or empty, column metadata is read for labels")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, labels=None):
+        """
+        setParams(self, inputCol=None, outputCol=None, labels=None)
+        Sets params for this IndexToString.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setLabels(self, value):
+        """
+        Sets the value of :py:attr:`labels`.
+        """
+        self._paramMap[self.labels] = value
+        return self
+
+    def getLabels(self):
+        """
+        Gets the value of :py:attr:`labels` or its default value.
+        """
+        return self.getOrDefault(self.labels)
 
 
 class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):

http://git-wip-us.apache.org/repos/asf/spark/blob/2f6fd525/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 253705b..8218c7c 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -136,7 +136,8 @@ class JavaEstimator(Estimator, JavaWrapper):
 class JavaTransformer(Transformer, JavaWrapper):
     """
     Base class for :py:class:`Transformer`s that wrap Java/Scala
-    implementations.
+    implementations. Subclasses should ensure they have the transformer Java object
+    available as _java_obj.
     """
 
     __metaclass__ = ABCMeta


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org