Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/20 14:53:03 UTC
[spark] branch master updated: [SPARK-22798][PYTHON][ML] Add multiple column support to PySpark StringIndexer
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 74e9e1c [SPARK-22798][PYTHON][ML] Add multiple column support to PySpark StringIndexer
74e9e1c is described below
commit 74e9e1c192f00920b69aa47813e2ac2f4e9b4325
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Wed Feb 20 08:52:46 2019 -0600
[SPARK-22798][PYTHON][ML] Add multiple column support to PySpark StringIndexer
## What changes were proposed in this pull request?
Add multiple column support to PySpark StringIndexer
## How was this patch tested?
Add doctest
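For context, a minimal sketch of the new multi-column usage (mirroring the doctest added in this patch; the DataFrame and column names are illustrative, and an active SparkSession bound to `spark` is assumed):
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame(
...     [(0, "a", "e"), (1, "b", "f"), (2, "c", "e")],
...     ["id", "label1", "label2"])
>>> indexer = StringIndexer(inputCols=["label1", "label2"],
...                         outputCols=["index1", "index2"])
>>> model = indexer.fit(df)
>>> model.transform(df).columns
['id', 'label1', 'label2', 'index1', 'index2']
Each output column holds the frequency-ordered numeric index of the corresponding input column's string value, just as in the single-column case.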
Closes #23741 from huaxingao/spark-22798.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Sean Owen <se...@databricks.com>
---
python/pyspark/ml/feature.py | 60 ++++++++++++++++++++++++++++-----
python/pyspark/ml/tests/test_wrapper.py | 8 ++++-
python/pyspark/ml/wrapper.py | 22 ++++++++++--
3 files changed, 77 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 23d56c8..0d1e9bd 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2290,7 +2290,8 @@ class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
return self._call_java("mean")
-class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol):
+class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol,
+ HasInputCols, HasOutputCols):
"""
Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
"""
@@ -2371,16 +2372,37 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
>>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
... key=lambda x: x[0])
[(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
+ >>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
+ ... Row(id=1, label1="b", label2="f"),
+ ... Row(id=2, label1="c", label2="e"),
+ ... Row(id=3, label1="a", label2="f"),
+ ... Row(id=4, label1="a", label2="f"),
+ ... Row(id=5, label1="c", label2="f")], 3)
+ >>> multiRowDf = spark.createDataFrame(testData)
+ >>> inputs = ["label1", "label2"]
+ >>> outputs = ["index1", "index2"]
+ >>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
+ >>> model = stringIndexer.fit(multiRowDf)
+ >>> result = model.transform(multiRowDf)
+ >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
+ ... result.index2).collect()]), key=lambda x: x[0])
+ [(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
+ >>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
+ ... inputCols=inputs, outputCols=outputs)
+ >>> result = fromlabelsModel.transform(multiRowDf)
+ >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
+ ... result.index2).collect()]), key=lambda x: x[0])
+ [(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
.. versionadded:: 1.4.0
"""
@keyword_only
- def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
- stringOrderType="frequencyDesc"):
+ def __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
+ handleInvalid="error", stringOrderType="frequencyDesc"):
"""
- __init__(self, inputCol=None, outputCol=None, handleInvalid="error", \
- stringOrderType="frequencyDesc")
+ __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
+ handleInvalid="error", stringOrderType="frequencyDesc")
"""
super(StringIndexer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
@@ -2389,11 +2411,11 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
@keyword_only
@since("1.4.0")
- def setParams(self, inputCol=None, outputCol=None, handleInvalid="error",
- stringOrderType="frequencyDesc"):
+ def setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
+ handleInvalid="error", stringOrderType="frequencyDesc"):
"""
- setParams(self, inputCol=None, outputCol=None, handleInvalid="error", \
- stringOrderType="frequencyDesc")
+ setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
+ handleInvalid="error", stringOrderType="frequencyDesc")
Sets params for this StringIndexer.
"""
kwargs = self._input_kwargs
@@ -2436,6 +2458,26 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
model.setHandleInvalid(handleInvalid)
return model
+ @classmethod
+ @since("3.0.0")
+ def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None,
+ handleInvalid=None):
+ """
+ Construct the model directly from an array of arrays of label strings;
+ requires an active SparkContext.
+ """
+ sc = SparkContext._active_spark_context
+ java_class = sc._gateway.jvm.java.lang.String
+ jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
+ model = StringIndexerModel._create_from_java_class(
+ "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+ model.setInputCols(inputCols)
+ if outputCols is not None:
+ model.setOutputCols(outputCols)
+ if handleInvalid is not None:
+ model.setHandleInvalid(handleInvalid)
+ return model
+
@property
@since("1.5.0")
def labels(self):
diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py
index ae672a0..c01521e 100644
--- a/python/pyspark/ml/tests/test_wrapper.py
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -99,7 +99,13 @@ class WrapperTests(MLlibTestCase):
java_class = self.sc._gateway.jvm.java.lang.Integer
java_array = JavaWrapper._new_java_array([], java_class)
self.assertEqual(_java2py(self.sc, java_array), [])
-
+ # test array of array of strings
+ str_list = [["a", "b", "c"], ["d", "e"], ["f", "g", "h", "i"], []]
+ expected_str_list = [("a", "b", "c", None), ("d", "e", None, None), ("f", "g", "h", "i"),
+ (None, None, None, None)]
+ java_class = self.sc._gateway.jvm.java.lang.String
+ java_array = JavaWrapper._new_java_array(str_list, java_class)
+ self.assertEqual(_java2py(self.sc, java_array), expected_str_list)
if __name__ == "__main__":
from pyspark.ml.tests.test_wrapper import *
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index d325633..9bb1262a 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -71,6 +71,10 @@ class JavaWrapper(object):
"""
Create a Java array of given java_class type. Useful for
calling a method with a Scala Array from Python with Py4J.
+ If the param pylist is a 2D array, then a 2D java array will be returned.
+ The returned 2D java array is a rectangular, non-jagged 2D array that is
+ big enough for all elements. The empty slots in the inner Java arrays
+ will be filled with null to make the array non-jagged.
:param pylist:
Python list to convert to a Java Array.
@@ -87,9 +91,21 @@ class JavaWrapper(object):
- bool -> sc._gateway.jvm.java.lang.Boolean
"""
sc = SparkContext._active_spark_context
- java_array = sc._gateway.new_array(java_class, len(pylist))
- for i in xrange(len(pylist)):
- java_array[i] = pylist[i]
+ java_array = None
+ if len(pylist) > 0 and isinstance(pylist[0], list):
+ # If pylist is a 2D array, then a 2D java array will be created.
+ # The 2D array is a square, non-jagged 2D array that is big enough for all elements.
+ inner_array_length = 0
+ for i in xrange(len(pylist)):
+ inner_array_length = max(inner_array_length, len(pylist[i]))
+ java_array = sc._gateway.new_array(java_class, len(pylist), inner_array_length)
+ for i in xrange(len(pylist)):
+ for j in xrange(len(pylist[i])):
+ java_array[i][j] = pylist[i][j]
+ else:
+ java_array = sc._gateway.new_array(java_class, len(pylist))
+ for i in xrange(len(pylist)):
+ java_array[i] = pylist[i]
return java_array
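A note on the wrapper change above: the nested-list branch sizes the outer Java array to len(pylist) and every inner array to the longest inner list, so shorter rows come back padded with null. A plain-Python sketch of the same padding semantics (illustrative only; the real code builds the Py4J array element by element as shown above):
>>> str_list = [["a", "b", "c"], ["d", "e"], []]
>>> width = max(len(inner) for inner in str_list)
>>> [tuple(inner + [None] * (width - len(inner))) for inner in str_list]
[('a', 'b', 'c'), ('d', 'e', None), (None, None, None)]
This is the shape the test_wrapper.py assertion checks after the round trip through _java2py, and it is what lets StringIndexerModel.from_arrays_of_labels pass a list of label lists of differing lengths across the Py4J boundary.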