Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/20 14:53:03 UTC

[spark] branch master updated: [SPARK-22798][PYTHON][ML] Add multiple column support to PySpark StringIndexer

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 74e9e1c  [SPARK-22798][PYTHON][ML] Add multiple column support to PySpark StringIndexer
74e9e1c is described below

commit 74e9e1c192f00920b69aa47813e2ac2f4e9b4325
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Wed Feb 20 08:52:46 2019 -0600

    [SPARK-22798][PYTHON][ML] Add multiple column support to PySpark StringIndexer
    
    ## What changes were proposed in this pull request?
    
    Add multiple column support to PySpark StringIndexer
    
    ## How was this patch tested?
    
    Add doctest
    
    Closes #23741 from huaxingao/spark-22798.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 python/pyspark/ml/feature.py            | 60 ++++++++++++++++++++++++++++-----
 python/pyspark/ml/tests/test_wrapper.py |  8 ++++-
 python/pyspark/ml/wrapper.py            | 22 ++++++++++--
 3 files changed, 77 insertions(+), 13 deletions(-)
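
For readers who want to try the new API as a standalone script, here is a
minimal sketch; it assumes an active SparkSession bound to a variable named
`spark` (the variable name and the sample data are illustrative, not part of
this patch):

    from pyspark.ml.feature import StringIndexer

    df = spark.createDataFrame(
        [(0, "a", "e"), (1, "b", "f"), (2, "c", "e")],
        ["id", "label1", "label2"])

    # One estimator now indexes several string columns in a single pass.
    indexer = StringIndexer(inputCols=["label1", "label2"],
                            outputCols=["index1", "index2"])
    model = indexer.fit(df)
    model.transform(df).show()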

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 23d56c8..0d1e9bd 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2290,7 +2290,8 @@ class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
         return self._call_java("mean")
 
 
-class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol):
+class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol,
+                           HasInputCols, HasOutputCols):
     """
     Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
     """
@@ -2371,16 +2372,37 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
     >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
     ...     key=lambda x: x[0])
     [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
+    >>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
+    ...                            Row(id=1, label1="b", label2="f"),
+    ...                            Row(id=2, label1="c", label2="e"),
+    ...                            Row(id=3, label1="a", label2="f"),
+    ...                            Row(id=4, label1="a", label2="f"),
+    ...                            Row(id=5, label1="c", label2="f")], 3)
+    >>> multiRowDf = spark.createDataFrame(testData)
+    >>> inputs = ["label1", "label2"]
+    >>> outputs = ["index1", "index2"]
+    >>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
+    >>> model = stringIndexer.fit(multiRowDf)
+    >>> result = model.transform(multiRowDf)
+    >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
+    ...     result.index2).collect()]), key=lambda x: x[0])
+    [(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
+    >>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
+    ...     inputCols=inputs, outputCols=outputs)
+    >>> result = fromlabelsModel.transform(multiRowDf)
+    >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
+    ...     result.index2).collect()]), key=lambda x: x[0])
+    [(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
 
     .. versionadded:: 1.4.0
     """
 
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
-                 stringOrderType="frequencyDesc"):
+    def __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
+                 handleInvalid="error", stringOrderType="frequencyDesc"):
         """
-        __init__(self, inputCol=None, outputCol=None, handleInvalid="error", \
-                 stringOrderType="frequencyDesc")
+        __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
+                 handleInvalid="error", stringOrderType="frequencyDesc")
         """
         super(StringIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
@@ -2389,11 +2411,11 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, inputCol=None, outputCol=None, handleInvalid="error",
-                  stringOrderType="frequencyDesc"):
+    def setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
+                  handleInvalid="error", stringOrderType="frequencyDesc"):
         """
-        setParams(self, inputCol=None, outputCol=None, handleInvalid="error", \
-                  stringOrderType="frequencyDesc")
+        setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
+                  handleInvalid="error", stringOrderType="frequencyDesc")
         Sets params for this StringIndexer.
         """
         kwargs = self._input_kwargs
@@ -2436,6 +2458,26 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
             model.setHandleInvalid(handleInvalid)
         return model
 
+    @classmethod
+    @since("3.0.0")
+    def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None,
+                              handleInvalid=None):
+        """
+        Construct the model directly from an array of arrays of label strings;
+        requires an active SparkContext.
+        """
+        sc = SparkContext._active_spark_context
+        java_class = sc._gateway.jvm.java.lang.String
+        jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
+        model = StringIndexerModel._create_from_java_class(
+            "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+        model.setInputCols(inputCols)
+        if outputCols is not None:
+            model.setOutputCols(outputCols)
+        if handleInvalid is not None:
+            model.setHandleInvalid(handleInvalid)
+        return model
+
     @property
     @since("1.5.0")
     def labels(self):
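
As a quick orientation on the new classmethod above: a minimal sketch of
constructing the model directly from known labels, assuming an active
SparkContext (the column names are illustrative):

    from pyspark.ml.feature import StringIndexerModel

    # Multi-column: one list of labels per input column.
    model = StringIndexerModel.from_arrays_of_labels(
        [["a", "b", "c"], ["e", "f"]],
        inputCols=["label1", "label2"],
        outputCols=["index1", "index2"])

    # Single-column counterpart that already existed: from_labels.
    single = StringIndexerModel.from_labels(
        ["a", "b", "c"], inputCol="label1", outputCol="index1")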
diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py
index ae672a0..c01521e 100644
--- a/python/pyspark/ml/tests/test_wrapper.py
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -99,7 +99,13 @@ class WrapperTests(MLlibTestCase):
         java_class = self.sc._gateway.jvm.java.lang.Integer
         java_array = JavaWrapper._new_java_array([], java_class)
         self.assertEqual(_java2py(self.sc, java_array), [])
-
+        # test array of array of strings
+        str_list = [["a", "b", "c"], ["d", "e"], ["f", "g", "h", "i"], []]
+        expected_str_list = [("a", "b", "c", None), ("d", "e", None, None), ("f", "g", "h", "i"),
+                             (None, None, None, None)]
+        java_class = self.sc._gateway.jvm.java.lang.String
+        java_array = JavaWrapper._new_java_array(str_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), expected_str_list)
 
 if __name__ == "__main__":
     from pyspark.ml.tests.test_wrapper import *
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index d325633..9bb1262a 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -71,6 +71,10 @@ class JavaWrapper(object):
         """
         Create a Java array of given java_class type. Useful for
         calling a method with a Scala Array from Python with Py4J.
+        If pylist is a 2D list, a 2D Java array is returned. The returned
+        array is rectangular (non-jagged) and wide enough for the longest
+        inner list; unused slots in the inner Java arrays are filled with
+        null.
 
         :param pylist:
           Python list to convert to a Java Array.
@@ -87,9 +91,21 @@ class JavaWrapper(object):
           - bool -> sc._gateway.jvm.java.lang.Boolean
         """
         sc = SparkContext._active_spark_context
-        java_array = sc._gateway.new_array(java_class, len(pylist))
-        for i in xrange(len(pylist)):
-            java_array[i] = pylist[i]
+        java_array = None
+        if len(pylist) > 0 and isinstance(pylist[0], list):
+            # If pylist is a 2D list, create a 2D Java array: rectangular
+            # (non-jagged), wide enough for the longest inner list.
+            inner_array_length = 0
+            for i in xrange(len(pylist)):
+                inner_array_length = max(inner_array_length, len(pylist[i]))
+            java_array = sc._gateway.new_array(java_class, len(pylist), inner_array_length)
+            for i in xrange(len(pylist)):
+                for j in xrange(len(pylist[i])):
+                    java_array[i][j] = pylist[i][j]
+        else:
+            java_array = sc._gateway.new_array(java_class, len(pylist))
+            for i in xrange(len(pylist)):
+                java_array[i] = pylist[i]
         return java_array
 
 
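To make the padding behavior concrete, here is a pure-Python sketch of the
same logic, mirroring the loops in _new_java_array but using None where the
Java side stores null (an illustration only, not part of the patch):

    def pad_2d(pylist):
        # The rectangle is as wide as the longest inner list.
        width = 0
        for row in pylist:
            width = max(width, len(row))
        # Unused slots are filled with None (null on the Java side).
        return [row + [None] * (width - len(row)) for row in pylist]

    pad_2d([["a", "b", "c"], ["d", "e"], [], ["f", "g", "h", "i"]])
    # [['a', 'b', 'c', None], ['d', 'e', None, None],
    #  [None, None, None, None], ['f', 'g', 'h', 'i']]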

