You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/10/14 02:44:31 UTC
spark git commit: [SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python
API for RFormula forceIndexLabel.
Repository: spark
Updated Branches:
refs/heads/master 9dc0ca060 -> 44cbb61b3
[SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python API for RFormula forceIndexLabel.
## What changes were proposed in this pull request?
Follow-up to #13675: add a Python API for the `RFormula` param `forceIndexLabel`.
## How was this patch tested?
Added a unit test (`test_rformula_force_index_label` in `python/pyspark/ml/tests.py`).
Author: Yanbo Liang <yb...@gmail.com>
Closes #15430 from yanboliang/spark-15957-python.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/44cbb61b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/44cbb61b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/44cbb61b
Branch: refs/heads/master
Commit: 44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5
Parents: 9dc0ca0
Author: Yanbo Liang <yb...@gmail.com>
Authored: Thu Oct 13 19:44:24 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Thu Oct 13 19:44:24 2016 -0700
----------------------------------------------------------------------
python/pyspark/ml/feature.py | 31 +++++++++++++++++++++++++++----
python/pyspark/ml/tests.py | 16 ++++++++++++++++
2 files changed, 43 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/44cbb61b/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 64b21ca..a33c3e7 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
formula = Param(Params._dummy(), "formula", "R model formula",
typeConverter=TypeConverters.toString)
+ forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
+ "Force to index label whether it is numeric or string",
+ typeConverter=TypeConverters.toBoolean)
+
@keyword_only
- def __init__(self, formula=None, featuresCol="features", labelCol="label"):
+ def __init__(self, formula=None, featuresCol="features", labelCol="label",
+ forceIndexLabel=False):
"""
- __init__(self, formula=None, featuresCol="features", labelCol="label")
+ __init__(self, formula=None, featuresCol="features", labelCol="label", \
+ forceIndexLabel=False)
"""
super(RFormula, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
+ self._setDefault(forceIndexLabel=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.5.0")
- def setParams(self, formula=None, featuresCol="features", labelCol="label"):
+ def setParams(self, formula=None, featuresCol="features", labelCol="label",
+ forceIndexLabel=False):
"""
- setParams(self, formula=None, featuresCol="features", labelCol="label")
+ setParams(self, formula=None, featuresCol="features", labelCol="label", \
+ forceIndexLabel=False)
Sets params for RFormula.
"""
kwargs = self.setParams._input_kwargs
@@ -2528,6 +2537,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
"""
return self.getOrDefault(self.formula)
+ @since("2.1.0")
+ def setForceIndexLabel(self, value):
+ """
+ Sets the value of :py:attr:`forceIndexLabel`.
+ """
+ return self._set(forceIndexLabel=value)
+
+ @since("2.1.0")
+ def getForceIndexLabel(self):
+ """
+ Gets the value of :py:attr:`forceIndexLabel`.
+ """
+ return self.getOrDefault(self.forceIndexLabel)
+
def _create_model(self, java_model):
return RFormulaModel(java_model)
http://git-wip-us.apache.org/repos/asf/spark/blob/44cbb61b/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e233549..9d46cc3 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -477,6 +477,22 @@ class FeatureTests(SparkSessionTestCase):
feature, expected = r
self.assertEqual(feature, expected)
+ def test_rformula_force_index_label(self):
+ df = self.spark.createDataFrame([
+ (1.0, 1.0, "a"),
+ (0.0, 2.0, "b"),
+ (1.0, 0.0, "a")], ["y", "x", "s"])
+ # Does not index label by default since it's numeric type.
+ rf = RFormula(formula="y ~ x + s")
+ model = rf.fit(df)
+ transformedDF = model.transform(df)
+ self.assertEqual(transformedDF.head().label, 1.0)
+ # Force to index label.
+ rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
+ model2 = rf2.fit(df)
+ transformedDF2 = model2.transform(df)
+ self.assertEqual(transformedDF2.head().label, 0.0)
+
class HasInducedError(Params):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org