You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/02/12 01:42:48 UTC
spark git commit: [SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error
Repository: spark
Updated Branches:
refs/heads/master 30e009556 -> b35467388
[SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error
The PySpark `Params` class has a method `hasParam(paramName)` that returns `True` if the class has a parameter by that name, but raises an `AttributeError` otherwise. There is currently no way to get a Boolean indicating whether a class has a given parameter. With Spark 2.0 we could either modify the existing behavior of `hasParam` or add an additional method with this functionality. This commit takes the first approach: `hasParam` now returns `False` when the name does not refer to a `Param`, and raises a `TypeError` if `paramName` is not a string.
In Python:
```python
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
print nb.hasParam("smoothing")
print nb.hasParam("notAParam")
```
produces:
> True
> AttributeError: 'NaiveBayes' object has no attribute 'notAParam'
However, in Scala:
```scala
import org.apache.spark.ml.classification.NaiveBayes
val nb = new NaiveBayes()
nb.hasParam("smoothing")
nb.hasParam("notAParam")
```
produces:
> true
> false
cc holdenk
Author: sethah <se...@gmail.com>
Closes #10962 from sethah/SPARK-13047.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b3546738
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b3546738
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b3546738
Branch: refs/heads/master
Commit: b35467388612167f0bc3d17142c21a406f6c620d
Parents: 30e0095
Author: sethah <se...@gmail.com>
Authored: Thu Feb 11 16:42:44 2016 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Feb 11 16:42:44 2016 -0800
----------------------------------------------------------------------
python/pyspark/ml/param/__init__.py | 7 +++++--
python/pyspark/ml/tests.py | 9 +++++++--
2 files changed, 12 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b3546738/python/pyspark/ml/param/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index ea86d6a..bbf83f0 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -179,8 +179,11 @@ class Params(Identifiable):
Tests whether this instance contains a param with a given
(string) name.
"""
- param = self._resolveParam(paramName)
- return param in self.params
+ if isinstance(paramName, str):
+ p = getattr(self, paramName, None)
+ return isinstance(p, Param)
+ else:
+ raise TypeError("hasParam(): paramName must be a string")
@since("1.4.0")
def getOrDefault(self, param):
http://git-wip-us.apache.org/repos/asf/spark/blob/b3546738/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e93a4e1..5fcfa9e 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -209,6 +209,11 @@ class ParamTests(PySparkTestCase):
self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
self.assertTrue(maxIter.parent == testParams.uid)
+ def test_hasparam(self):
+ testParams = TestParams()
+ self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
+ self.assertFalse(testParams.hasParam("notAParameter"))
+
def test_params(self):
testParams = TestParams()
maxIter = testParams.maxIter
@@ -218,7 +223,7 @@ class ParamTests(PySparkTestCase):
params = testParams.params
self.assertEqual(params, [inputCol, maxIter, seed])
- self.assertTrue(testParams.hasParam(maxIter))
+ self.assertTrue(testParams.hasParam(maxIter.name))
self.assertTrue(testParams.hasDefault(maxIter))
self.assertFalse(testParams.isSet(maxIter))
self.assertTrue(testParams.isDefined(maxIter))
@@ -227,7 +232,7 @@ class ParamTests(PySparkTestCase):
self.assertTrue(testParams.isSet(maxIter))
self.assertEqual(testParams.getMaxIter(), 100)
- self.assertTrue(testParams.hasParam(inputCol))
+ self.assertTrue(testParams.hasParam(inputCol.name))
self.assertFalse(testParams.hasDefault(inputCol))
self.assertFalse(testParams.isSet(inputCol))
self.assertFalse(testParams.isDefined(inputCol))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org