You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/05/06 14:12:08 UTC
[spark] branch master updated: [SPARK-31609][ML][PYSPARK] Add
VarianceThresholdSelector to PySpark
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 09ece50 [SPARK-31609][ML][PYSPARK] Add VarianceThresholdSelector to PySpark
09ece50 is described below
commit 09ece50799222d577009a2bbd480304d1ae1e14e
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Wed May 6 09:11:03 2020 -0500
[SPARK-31609][ML][PYSPARK] Add VarianceThresholdSelector to PySpark
### What changes were proposed in this pull request?
Add VarianceThresholdSelector to PySpark
### Why are the changes needed?
parity between Scala and Python
### Does this PR introduce any user-facing change?
Yes.
VarianceThresholdSelector is added to PySpark
### How was this patch tested?
new doctest
Closes #28409 from huaxingao/variance_py.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
---
python/pyspark/ml/feature.py | 142 +++++++++++++++++++++++++++++++++++++++++++
1 file changed, 142 insertions(+)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 6df2f74..7acf8ce 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -57,6 +57,7 @@ __all__ = ['Binarizer',
'StopWordsRemover',
'StringIndexer', 'StringIndexerModel',
'Tokenizer',
+ 'VarianceThresholdSelector', 'VarianceThresholdSelectorModel',
'VectorAssembler',
'VectorIndexer', 'VectorIndexerModel',
'VectorSizeHint',
@@ -5381,6 +5382,147 @@ class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReada
return self._set(handleInvalid=value)
class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
    """
    Params for :py:class:`VarianceThresholdSelector` and
    :py:class:`VarianceThresholdSelectorModel`.

    .. versionadded:: 3.1.0
    """

    # Features with variance <= this threshold are removed by the selector.
    varianceThreshold = Param(Params._dummy(), "varianceThreshold",
                              "Param for variance threshold. Features with a variance not " +
                              "greater than this threshold will be removed. The default value " +
                              "is 0.0.", typeConverter=TypeConverters.toFloat)

    @since("3.1.0")
    def getVarianceThreshold(self):
        """
        Gets the value of varianceThreshold or its default value.
        """
        return self.getOrDefault(self.varianceThreshold)
+
+
@inherit_doc
class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable,
                                JavaMLWritable):
    """
    Feature selector that removes all low-variance features. Features with a
    variance not greater than the threshold will be removed. The default is to keep
    all features with non-zero variance, i.e. remove the features that have the
    same value in all samples.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame(
    ...     [(Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]),),
    ...      (Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]),),
    ...      (Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]),),
    ...      (Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]),),
    ...      (Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]),),
    ...      (Vectors.dense([8.0, 9.0, 6.0, 0.0, 0.0, 0.0]),)],
    ...     ["features"])
    >>> selector = VarianceThresholdSelector(varianceThreshold=8.2, outputCol="selectedFeatures")
    >>> model = selector.fit(df)
    >>> model.getFeaturesCol()
    'features'
    >>> model.setFeaturesCol("features")
    VarianceThresholdSelectorModel...
    >>> model.transform(df).head().selectedFeatures
    DenseVector([6.0, 7.0, 0.0])
    >>> model.selectedFeatures
    [0, 3, 5]
    >>> varianceThresholdSelectorPath = temp_path + "/variance-threshold-selector"
    >>> selector.save(varianceThresholdSelectorPath)
    >>> loadedSelector = VarianceThresholdSelector.load(varianceThresholdSelectorPath)
    >>> loadedSelector.getVarianceThreshold() == selector.getVarianceThreshold()
    True
    >>> modelPath = temp_path + "/variance-threshold-selector-model"
    >>> model.save(modelPath)
    >>> loadedModel = VarianceThresholdSelectorModel.load(modelPath)
    >>> loadedModel.selectedFeatures == model.selectedFeatures
    True

    .. versionadded:: 3.1.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", outputCol=None, varianceThreshold=0.0):
        """
        __init__(self, featuresCol="features", outputCol=None, varianceThreshold=0.0)
        """
        super(VarianceThresholdSelector, self).__init__()
        # Back this Python wrapper with the JVM-side estimator of the same name.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid)
        self._setDefault(varianceThreshold=0.0)
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("3.1.0")
    def setParams(self, featuresCol="features", outputCol=None, varianceThreshold=0.0):
        """
        setParams(self, featuresCol="features", outputCol=None, varianceThreshold=0.0)
        Sets params for this VarianceThresholdSelector.
        """
        return self._set(**self._input_kwargs)

    @since("3.1.0")
    def setVarianceThreshold(self, value):
        """
        Sets the value of :py:attr:`varianceThreshold`.
        """
        return self._set(varianceThreshold=value)

    @since("3.1.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("3.1.0")
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def _create_model(self, java_model):
        # Wrap the fitted JVM model in its Python counterpart.
        return VarianceThresholdSelectorModel(java_model)
+
+
class VarianceThresholdSelectorModel(JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable,
                                     JavaMLWritable):
    """
    Model fitted by :py:class:`VarianceThresholdSelector`.

    .. versionadded:: 3.1.0
    """

    @since("3.1.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("3.1.0")
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @property
    @since("3.1.0")
    def selectedFeatures(self):
        """
        List of indices to select (filter).
        """
        # Delegate to the JVM model, which holds the fitted selection result.
        return self._call_java("selectedFeatures")
+
+
if __name__ == "__main__":
import doctest
import tempfile
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org