Posted to commits@spark.apache.org by sr...@apache.org on 2020/07/11 15:40:44 UTC

[spark] branch branch-3.0 updated: [SPARK-32232][ML][PYSPARK] Make sure ML has the same default solver values between Scala and Python

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 8a6580c  [SPARK-32232][ML][PYSPARK] Make sure ML has the same default solver values between Scala and Python
8a6580c is described below

commit 8a6580cedf509b5f62175e4ed7b2e0882dd89976
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Sat Jul 11 10:37:26 2020 -0500

    [SPARK-32232][ML][PYSPARK] Make sure ML has the same default solver values between Scala and Python
    
    # What changes were proposed in this pull request?
    Current problems:
    ```
    import tempfile
    from pyspark.ml.classification import MultilayerPerceptronClassifier, \
        MultilayerPerceptronClassificationModel

    # Runs inside a SparkSessionTestCase method; df has "label" and "features" columns.
    mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
    model = mlp.fit(df)
    path = tempfile.mkdtemp()
    model_path = path + "/mlp"
    model.save(model_path)
    model2 = MultilayerPerceptronClassificationModel.load(model_path)
    # Fails: model2.getSolver() returns 'auto' instead of 'l-bfgs'.
    self.assertEqual(model2.getSolver(), "l-bfgs")
    # Fails with pyspark.sql.utils.IllegalArgumentException:
    # MultilayerPerceptronClassifier_dec859ed24ec parameter solver given invalid value auto.
    model2.transform(df)
    ```
    FMClassifier/FMRegressor and GeneralizedLinearRegression have the same problems.
    
    Here are the root causes of the problems:
    1. In HasSolver, both Scala and Python set the default solver to 'auto'.
    
    2. On the Scala side, MLP overrides the default solver to 'l-bfgs', FMClassifier/FMRegressor override it to 'adamW', and GLR overrides it to 'irls'.
    
    3. On the Scala side, MLP overrides the solver default in MultilayerPerceptronClassificationParams, so both MultilayerPerceptronClassifier and MultilayerPerceptronClassificationModel have 'l-bfgs' as the default.
    
    4. On the Python side, MLP overrides the solver default in MultilayerPerceptronClassifier, so the estimator defaults to 'l-bfgs', but MultilayerPerceptronClassificationModel doesn't override it and inherits 'auto' from HasSolver. In theory, we don't care about the solver value, or any other param values, of a MultilayerPerceptronClassificationModel, because the model is already fitted. That's why, on the Python side, we never set default values for any XXXModel.
    
    5. When getSolver is called on the loaded MLP model, it goes through this code underneath:
    ```
        def _transfer_params_from_java(self):
            """
            Transforms the embedded params from the companion Java object.
            """
            ......
                    # SPARK-14931: Only check set params back to avoid default params mismatch.
                    if self._java_obj.isSet(java_param):
                        value = _java2py(sc, self._java_obj.getOrDefault(java_param))
                        self._set(**{param.name: value})
            ......
    ```
    That's why model2.getSolver() returns 'auto': only params explicitly set on the Java object are copied back, so the Scala default (here 'l-bfgs') is never transferred to the Python param, and the Python-side default (here 'auto') is used instead.
    
    6. When model2.transform(df) is called, it goes through this code underneath:
    ```
        def _transfer_params_to_java(self):
            """
            Transforms the embedded params to the companion Java object.
            """
            ......
                if self.hasDefault(param):
                    pair = self._make_java_param_pair(param, self._defaultParamMap[param])
                    pair_defaults.append(pair)
            ......
    
    ```
    Again, it picks up the Python default solver, 'auto', and passes it to the Java object; 'auto' is not a valid MLP solver, which causes the exception.
    
    7. Currently, on the Scala side, some algorithms set default values in XXXParams, so both the estimator and the model get them. Other algorithms set defaults only in the estimator, so the XXXModel doesn't get them. On the Python side, we never set defaults for any XXXModel. This makes the default values inconsistent.
    
    8. My proposed solution: set default params in XXXParams for both Scala and Python, so that the estimator and the model have the same default values in both languages. This PR only covers the algorithms affected by the solver mismatch (MLP, FM and GLR). If everyone is OK with the fix, I will apply the same change to the other params as well.
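A minimal, Spark-free sketch of points 4 through 8, assuming toy stand-ins for the Python Params machinery and for the loaded Scala model. `ToyParams`, `ToyHasSolver`, `ToyJavaModel`, `make_pair`, `load_and_transform` and the Old*/New* classes are all hypothetical; only the `_setDefault`/`getOrDefault`/`isSet` names echo the real APIs, and `ValueError` stands in for the JVM's `IllegalArgumentException`. It simulates loading a model whose Scala side holds 'l-bfgs' only as a default: with the old class layout the Python model keeps 'auto' and pushing it back to Java is rejected, while moving the override into the shared params mixin makes both directions agree.

```python
# Spark-free toy model of the default-solver mismatch; every class here is hypothetical.

MLP_SOLVERS = ("l-bfgs", "gd")  # solvers the Scala MLP accepts


class ToyJavaModel:
    """Stand-in for a loaded Scala model: 'l-bfgs' is a default, never explicitly set."""

    def __init__(self):
        self.defaults = {"solver": "l-bfgs"}
        self.explicit = {}

    def isSet(self, name):
        return name in self.explicit

    def getOrDefault(self, name):
        return self.explicit.get(name, self.defaults[name])

    def make_pair(self, name, value):
        # Scala validates a value when the (param, value) pair is built;
        # ValueError stands in for IllegalArgumentException here.
        if name == "solver" and value not in MLP_SOLVERS:
            raise ValueError(f"parameter {name} given invalid value {value}.")
        return (name, value)


class ToyParams:
    """Stand-in for Python Params: a default map plus explicitly set values."""

    def __init__(self):
        self._defaultParamMap = {}
        self._paramMap = {}

    def _setDefault(self, **kwargs):
        self._defaultParamMap.update(kwargs)

    def getOrDefault(self, name):
        return self._paramMap.get(name, self._defaultParamMap[name])

    def transfer_params_from_java(self, java_obj):
        # Point 5: only params explicitly set on the Java side are copied back.
        for name in self._defaultParamMap:
            if java_obj.isSet(name):
                self._paramMap[name] = java_obj.getOrDefault(name)

    def transfer_params_to_java(self, java_obj):
        # Point 6: every Python default is turned into a Java param pair.
        return [java_obj.make_pair(n, v) for n, v in self._defaultParamMap.items()]


class ToyHasSolver(ToyParams):
    def __init__(self):
        super().__init__()
        self._setDefault(solver="auto")  # shared HasSolver default


class OldClassifier(ToyHasSolver):
    """Before the fix: only the estimator overrides the solver default."""

    def __init__(self):
        super().__init__()
        self._setDefault(solver="l-bfgs")


class OldModel(ToyHasSolver):
    """Before the fix: the model just inherits the HasSolver default."""


class NewParams(ToyHasSolver):
    """After the fix: the override lives in the shared params mixin."""

    def __init__(self):
        super().__init__()
        self._setDefault(solver="l-bfgs")


class NewClassifier(NewParams):
    pass


class NewModel(NewParams):
    pass


def load_and_transform(model_cls):
    """Simulate load() followed by transform(); return (solver, outcome)."""
    model, java_obj = model_cls(), ToyJavaModel()
    model.transfer_params_from_java(java_obj)
    try:
        model.transfer_params_to_java(java_obj)
        outcome = "ok"
    except ValueError as err:
        outcome = str(err)
    return model.getOrDefault("solver"), outcome


print(load_and_transform(OldModel))  # ('auto', 'parameter solver given invalid value auto.')
print(load_and_transform(NewModel))  # ('l-bfgs', 'ok')
```

In the toy, OldClassifier still reports 'l-bfgs'; only the model class disagrees, which is the estimator/model asymmetry described in point 4.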
    
    I hope my explanation makes sense to you folks :)
    
    ### Why are the changes needed?
    To fix the default-solver bug described above.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests, plus new save/load tests in test_persistence.py (TestDefaultSolver).
    
    Closes #29060 from huaxingao/solver_parity.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
    (cherry picked from commit 99b4b062555329d5da968ad5dbd9e2b22a193a55)
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../apache/spark/ml/regression/FMRegressor.scala   | 14 ++---
 .../regression/GeneralizedLinearRegression.scala   |  9 +--
 python/pyspark/ml/classification.py                |  8 +--
 python/pyspark/ml/regression.py                    | 16 ++++--
 python/pyspark/ml/tests/test_persistence.py        | 65 +++++++++++++++++++++-
 5 files changed, 84 insertions(+), 28 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index b017a1a..ce98817 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -113,6 +113,10 @@ private[ml] trait FactorizationMachinesParams extends PredictorParams
     "The solver algorithm for optimization. Supported options: " +
       s"${supportedSolvers.mkString(", ")}. (Default adamW)",
     ParamValidators.inArray[String](supportedSolvers))
+
+  setDefault(factorSize -> 8, fitIntercept -> true, fitLinear -> true, regParam -> 0.0,
+    miniBatchFraction -> 1.0, initStd -> 0.01, maxIter -> 100, stepSize -> 1.0, tol -> 1E-6,
+    solver -> AdamW)
 }
 
 private[ml] trait FactorizationMachines extends FactorizationMachinesParams {
@@ -309,7 +313,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setFactorSize(value: Int): this.type = set(factorSize, value)
-  setDefault(factorSize -> 8)
 
   /**
    * Set whether to fit intercept term.
@@ -319,7 +322,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
-  setDefault(fitIntercept -> true)
 
   /**
    * Set whether to fit linear term.
@@ -329,7 +331,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setFitLinear(value: Boolean): this.type = set(fitLinear, value)
-  setDefault(fitLinear -> true)
 
   /**
    * Set the L2 regularization parameter.
@@ -339,7 +340,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
-  setDefault(regParam -> 0.0)
 
   /**
    * Set the mini-batch fraction parameter.
@@ -349,7 +349,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value)
-  setDefault(miniBatchFraction -> 1.0)
 
   /**
    * Set the standard deviation of initial coefficients.
@@ -359,7 +358,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setInitStd(value: Double): this.type = set(initStd, value)
-  setDefault(initStd -> 0.01)
 
   /**
    * Set the maximum number of iterations.
@@ -369,7 +367,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setDefault(maxIter -> 100)
 
   /**
    * Set the initial step size for the first step (like learning rate).
@@ -379,7 +376,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setStepSize(value: Double): this.type = set(stepSize, value)
-  setDefault(stepSize -> 1.0)
 
   /**
    * Set the convergence tolerance of iterations.
@@ -389,7 +385,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-6)
 
   /**
    * Set the solver algorithm used for optimization.
@@ -400,7 +395,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setSolver(value: String): this.type = set(solver, value)
-  setDefault(solver -> AdamW)
 
   /**
    * Set the random seed for weight initialization.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index fa41a98..8fda44c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -180,6 +180,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
       s"${supportedSolvers.mkString(", ")}. (Default irls)",
     ParamValidators.inArray[String](supportedSolvers))
 
+  setDefault(family -> Gaussian.name, variancePower -> 0.0, maxIter -> 25, tol -> 1E-6,
+    regParam -> 0.0, solver -> IRLS)
+
   @Since("2.0.0")
   override def validateAndTransformSchema(
       schema: StructType,
@@ -256,7 +259,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setFamily(value: String): this.type = set(family, value)
-  setDefault(family -> Gaussian.name)
 
   /**
    * Sets the value of param [[variancePower]].
@@ -267,7 +269,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.2.0")
   def setVariancePower(value: Double): this.type = set(variancePower, value)
-  setDefault(variancePower -> 0.0)
 
   /**
    * Sets the value of param [[linkPower]].
@@ -304,7 +305,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setDefault(maxIter -> 25)
 
   /**
    * Sets the convergence tolerance of iterations.
@@ -315,7 +315,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-6)
 
   /**
    * Sets the regularization parameter for L2 regularization.
@@ -331,7 +330,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
-  setDefault(regParam -> 0.0)
 
   /**
    * Sets the value of param [[weightCol]].
@@ -363,7 +361,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setSolver(value: String): this.type = set(solver, value)
-  setDefault(solver -> IRLS)
 
   /**
    * Sets the link prediction (linear predictor) column name.
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 424e16a..369761e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -2172,6 +2172,10 @@ class _MultilayerPerceptronParams(_JavaProbabilisticClassifierParams, HasSeed, H
     initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
                            typeConverter=TypeConverters.toVector)
 
+    def __init__(self):
+        super(_MultilayerPerceptronParams, self).__init__()
+        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
+
     @since("1.6.0")
     def getLayers(self):
         """
@@ -2275,7 +2279,6 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
         super(MultilayerPerceptronClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
-        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -2871,9 +2874,6 @@ class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, Ja
         super(FMClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.FMClassifier", self.uid)
-        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
-                         miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
-                         tol=1e-6, solver="adamW")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index f367bb8..9c3c1e6 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1847,6 +1847,11 @@ class _GeneralizedLinearRegressionParams(_JavaPredictorParams, HasFitIntercept,
                       "or empty, we treat all instance offsets as 0.0",
                       typeConverter=TypeConverters.toString)
 
+    def __init__(self):
+        super(_GeneralizedLinearRegressionParams, self).__init__()
+        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
+                         variancePower=0.0, aggregationDepth=2)
+
     @since("2.0.0")
     def getFamily(self):
         """
@@ -1979,8 +1984,6 @@ class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionPar
         super(GeneralizedLinearRegression, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
-        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
-                         variancePower=0.0, aggregationDepth=2)
         kwargs = self._input_kwargs
 
         self.setParams(**kwargs)
@@ -2354,6 +2357,12 @@ class _FactorizationMachinesParams(_JavaPredictorParams, HasMaxIter, HasStepSize
     solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
                    "options: gd, adamW. (Default adamW)", typeConverter=TypeConverters.toString)
 
+    def __init__(self):
+        super(_FactorizationMachinesParams, self).__init__()
+        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
+                         miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
+                         tol=1e-6, solver="adamW")
+
     @since("3.0.0")
     def getFactorSize(self):
         """
@@ -2445,9 +2454,6 @@ class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, J
         super(FMRegressor, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.FMRegressor", self.uid)
-        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
-                         miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
-                         tol=1e-6, solver="adamW")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
index d4edcc2..2f6d451 100644
--- a/python/pyspark/ml/tests/test_persistence.py
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -21,19 +21,78 @@ import tempfile
 import unittest
 
 from pyspark.ml import Transformer
-from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
-    OneVsRestModel
+from pyspark.ml.classification import DecisionTreeClassifier, FMClassifier, \
+    FMClassificationModel, LogisticRegression, MultilayerPerceptronClassifier, \
+    MultilayerPerceptronClassificationModel, OneVsRest, OneVsRestModel
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.feature import Binarizer, HashingTF, PCA
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.param import Params
 from pyspark.ml.pipeline import Pipeline, PipelineModel
-from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression
+from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
+    GeneralizedLinearRegressionModel, \
+    LinearRegression
 from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter
 from pyspark.ml.wrapper import JavaParams
 from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase
 
 
+class TestDefaultSolver(SparkSessionTestCase):
+
+    def test_multilayer_load(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])),
+                                         (1.0, Vectors.dense([0.0, 1.0])),
+                                         (1.0, Vectors.dense([1.0, 0.0])),
+                                         (0.0, Vectors.dense([1.0, 1.0]))],
+                                        ["label",  "features"])
+
+        mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
+        model = mlp.fit(df)
+        self.assertEqual(model.getSolver(), "l-bfgs")
+        transformed1 = model.transform(df)
+        path = tempfile.mkdtemp()
+        model_path = path + "/mlp"
+        model.save(model_path)
+        model2 = MultilayerPerceptronClassificationModel.load(model_path)
+        self.assertEqual(model2.getSolver(), "l-bfgs")
+        transformed2 = model2.transform(df)
+        self.assertEqual(transformed1.take(4), transformed2.take(4))
+
+    def test_fm_load(self):
+        df = self.spark.createDataFrame([(1.0, Vectors.dense(1.0)),
+                                         (0.0, Vectors.sparse(1, [], []))],
+                                        ["label",  "features"])
+        fm = FMClassifier(factorSize=2, maxIter=50, stepSize=2.0)
+        model = fm.fit(df)
+        self.assertEqual(model.getSolver(), "adamW")
+        transformed1 = model.transform(df)
+        path = tempfile.mkdtemp()
+        model_path = path + "/fm"
+        model.save(model_path)
+        model2 = FMClassificationModel.load(model_path)
+        self.assertEqual(model2.getSolver(), "adamW")
+        transformed2 = model2.transform(df)
+        self.assertEqual(transformed1.take(2), transformed2.take(2))
+
+    def test_glr_load(self):
+        df = self.spark.createDataFrame([(1.0, Vectors.dense(0.0, 0.0)),
+                                         (1.0, Vectors.dense(1.0, 2.0)),
+                                         (2.0, Vectors.dense(0.0, 0.0)),
+                                         (2.0, Vectors.dense(1.0, 1.0))],
+                                        ["label",  "features"])
+        glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
+        model = glr.fit(df)
+        self.assertEqual(model.getSolver(), "irls")
+        transformed1 = model.transform(df)
+        path = tempfile.mkdtemp()
+        model_path = path + "/glr"
+        model.save(model_path)
+        model2 = GeneralizedLinearRegressionModel.load(model_path)
+        self.assertEqual(model2.getSolver(), "irls")
+        transformed2 = model2.transform(df)
+        self.assertEqual(transformed1.take(4), transformed2.take(4))
+
+
 class PersistenceTest(SparkSessionTestCase):
 
     def test_linear_regression(self):

