You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/08/23 00:41:04 UTC

spark git commit: [SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator

Repository: spark
Updated Branches:
  refs/heads/master d56c26210 -> 41bb1ddc6


[SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator

## What changes were proposed in this pull request?

Added a call to copy the values of Params from the Estimator to the Model after fit in PySpark ML.  This copies values for any params that are also defined in the Model.  Since most Models currently do not define the same params as the Estimator, also added a method that creates new Params by inspecting the Java object when they do not exist in the Python object.  This is a temporary fix that can be removed once the PySpark models properly define the params themselves.

## How was this patch tested?

Refactored the `check_params` test to optionally check whether the model params for Python and Java match, and added this check to an existing fitted model that shares params between Estimator and Model.

Author: Bryan Cutler <cu...@gmail.com>

Closes #17849 from BryanCutler/pyspark-models-own-params-SPARK-10931.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/41bb1ddc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/41bb1ddc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/41bb1ddc

Branch: refs/heads/master
Commit: 41bb1ddc63298c004bb6a6bb6fff9fd4f6e44792
Parents: d56c262
Author: Bryan Cutler <cu...@gmail.com>
Authored: Tue Aug 22 17:40:50 2017 -0700
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Tue Aug 22 17:40:50 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py |  2 +-
 python/pyspark/ml/clustering.py     |  8 ++-
 python/pyspark/ml/tests.py          | 87 +++++++++++++++++++-------------
 python/pyspark/ml/wrapper.py        | 32 +++++++++++-
 4 files changed, 91 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/41bb1ddc/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index bccf8e7..235cee4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1434,7 +1434,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
         super(MultilayerPerceptronClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
-        self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs")
+        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/41bb1ddc/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 88ac7e2..66fb005 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -745,7 +745,13 @@ class DistributedLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
 
         WARNING: This involves collecting a large :py:func:`topicsMatrix` to the driver.
         """
-        return LocalLDAModel(self._call_java("toLocal"))
+        model = LocalLDAModel(self._call_java("toLocal"))
+
+        # SPARK-10931: Temporary fix to be removed once LDAModel defines Params
+        model._create_params_from_java()
+        model._transfer_params_from_java()
+
+        return model
 
     @since("2.0.0")
     def trainingLogLikelihood(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/41bb1ddc/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 0495973..6076b3c 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -455,6 +455,54 @@ class ParamTests(PySparkTestCase):
             LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
         )
 
+    @staticmethod
+    def check_params(test_self, py_stage, check_params_exist=True):
+        """
+        Checks common requirements for Params.params:
+          - set of params exist in Java and Python and are ordered by names
+          - param parent has the same UID as the object's UID
+          - default param value from Java matches value in Python
+          - optionally check if all params from Java also exist in Python
+        """
+        py_stage_str = "%s %s" % (type(py_stage), py_stage)
+        if not hasattr(py_stage, "_to_java"):
+            return
+        java_stage = py_stage._to_java()
+        if java_stage is None:
+            return
+        test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
+        if check_params_exist:
+            param_names = [p.name for p in py_stage.params]
+            java_params = list(java_stage.params())
+            java_param_names = [jp.name() for jp in java_params]
+            test_self.assertEqual(
+                param_names, sorted(java_param_names),
+                "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
+                % (py_stage_str, java_param_names, param_names))
+        for p in py_stage.params:
+            test_self.assertEqual(p.parent, py_stage.uid)
+            java_param = java_stage.getParam(p.name)
+            py_has_default = py_stage.hasDefault(p)
+            java_has_default = java_stage.hasDefault(java_param)
+            test_self.assertEqual(py_has_default, java_has_default,
+                                  "Default value mismatch of param %s for Params %s"
+                                  % (p.name, str(py_stage)))
+            if py_has_default:
+                if p.name == "seed":
+                    continue  # Random seeds between Spark and PySpark are different
+                java_default = _java2py(test_self.sc,
+                                        java_stage.clear(java_param).getOrDefault(java_param))
+                py_stage._clear(p)
+                py_default = py_stage.getOrDefault(p)
+                # equality test for NaN is always False
+                if isinstance(java_default, float) and np.isnan(java_default):
+                    java_default = "NaN"
+                    py_default = "NaN" if np.isnan(py_default) else "not NaN"
+                test_self.assertEqual(
+                    java_default, py_default,
+                    "Java default %s != python default %s of param %s for Params %s"
+                    % (str(java_default), str(py_default), p.name, str(py_stage)))
+
 
 class EvaluatorTests(SparkSessionTestCase):
 
@@ -511,6 +559,8 @@ class FeatureTests(SparkSessionTestCase):
                          "Model should inherit the UID from its parent estimator.")
         output = idf0m.transform(dataset)
         self.assertIsNotNone(output.head().idf)
+        # Test that parameters transferred to Python Model
+        ParamTests.check_params(self, idf0m)
 
     def test_ngram(self):
         dataset = self.spark.createDataFrame([
@@ -1656,40 +1706,6 @@ class DefaultValuesTests(PySparkTestCase):
     those in their Scala counterparts.
     """
 
-    def check_params(self, py_stage):
-        import pyspark.ml.feature
-        if not hasattr(py_stage, "_to_java"):
-            return
-        java_stage = py_stage._to_java()
-        if java_stage is None:
-            return
-        for p in py_stage.params:
-            java_param = java_stage.getParam(p.name)
-            py_has_default = py_stage.hasDefault(p)
-            java_has_default = java_stage.hasDefault(java_param)
-            self.assertEqual(py_has_default, java_has_default,
-                             "Default value mismatch of param %s for Params %s"
-                             % (p.name, str(py_stage)))
-            if py_has_default:
-                if p.name == "seed":
-                    return  # Random seeds between Spark and PySpark are different
-                java_default =\
-                    _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param))
-                py_stage._clear(p)
-                py_default = py_stage.getOrDefault(p)
-                if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue":
-                    # SPARK-15040 - default value for Imputer param 'missingValue' is NaN,
-                    # and NaN != NaN, so handle it specially here
-                    import math
-                    self.assertTrue(math.isnan(java_default) and math.isnan(py_default),
-                                    "Java default %s and python default %s are not both NaN for "
-                                    "param %s for Params %s"
-                                    % (str(java_default), str(py_default), p.name, str(py_stage)))
-                    return
-                self.assertEqual(java_default, py_default,
-                                 "Java default %s != python default %s of param %s for Params %s"
-                                 % (str(java_default), str(py_default), p.name, str(py_stage)))
-
     def test_java_params(self):
         import pyspark.ml.feature
         import pyspark.ml.classification
@@ -1703,7 +1719,8 @@ class DefaultValuesTests(PySparkTestCase):
             for name, cls in inspect.getmembers(module, inspect.isclass):
                 if not name.endswith('Model') and issubclass(cls, JavaParams)\
                         and not inspect.isabstract(cls):
-                    self.check_params(cls())
+                    # NOTE: disable check_params_exist until there is parity with Scala API
+                    ParamTests.check_params(self, cls(), check_params_exist=False)
 
 
 def _squared_distance(a, b):

http://git-wip-us.apache.org/repos/asf/spark/blob/41bb1ddc/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index ee6301e..0f846fb 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -135,6 +135,20 @@ class JavaParams(JavaWrapper, Params):
                 paramMap.put([pair])
         return paramMap
 
+    def _create_params_from_java(self):
+        """
+        SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here
+        """
+        java_params = list(self._java_obj.params())
+        from pyspark.ml.param import Param
+        for java_param in java_params:
+            java_param_name = java_param.name()
+            if not hasattr(self, java_param_name):
+                param = Param(self, java_param_name, java_param.doc())
+                setattr(param, "created_from_java_param", True)
+                setattr(self, java_param_name, param)
+                self._params = None  # need to reset so self.params will discover new params
+
     def _transfer_params_from_java(self):
         """
         Transforms the embedded params from the companion Java object.
@@ -147,6 +161,10 @@ class JavaParams(JavaWrapper, Params):
                 if self._java_obj.isSet(java_param):
                     value = _java2py(sc, self._java_obj.getOrDefault(java_param))
                     self._set(**{param.name: value})
+                # SPARK-10931: Temporary fix for params that have a default in Java
+                if self._java_obj.hasDefault(java_param) and not self.isDefined(param):
+                    value = _java2py(sc, self._java_obj.getDefault(java_param)).get()
+                    self._setDefault(**{param.name: value})
 
     def _transfer_param_map_from_java(self, javaParamMap):
         """
@@ -204,6 +222,11 @@ class JavaParams(JavaWrapper, Params):
             # Load information from java_stage to the instance.
             py_stage = py_type()
             py_stage._java_obj = java_stage
+
+            # SPARK-10931: Temporary fix so that persisted models would own params from Estimator
+            if issubclass(py_type, JavaModel):
+                py_stage._create_params_from_java()
+
             py_stage._resetUid(java_stage.uid())
             py_stage._transfer_params_from_java()
         elif hasattr(py_type, "_from_java"):
@@ -263,7 +286,8 @@ class JavaEstimator(JavaParams, Estimator):
 
     def _fit(self, dataset):
         java_model = self._fit_java(dataset)
-        return self._create_model(java_model)
+        model = self._create_model(java_model)
+        return self._copyValues(model)
 
 
 @inherit_doc
@@ -307,4 +331,10 @@ class JavaModel(JavaTransformer, Model):
         """
         super(JavaModel, self).__init__(java_model)
         if java_model is not None:
+
+            # SPARK-10931: This is a temporary fix to allow models to own params
+            # from estimators. Eventually, these params should be in models through
+            # using common base classes between estimators and models.
+            self._create_params_from_java()
+
             self._resetUid(java_model.uid())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org