Posted to commits@spark.apache.org by me...@apache.org on 2015/10/28 16:54:24 UTC

spark git commit: [SPARK-11367][ML][PYSPARK] Python LinearRegression should support setting solver

Repository: spark
Updated Branches:
  refs/heads/master fba9e9545 -> f92b7b98e


[SPARK-11367][ML][PYSPARK] Python LinearRegression should support setting solver

[SPARK-10668](https://issues.apache.org/jira/browse/SPARK-10668) added the ```WeightedLeastSquares``` solver ("normal") to ```LinearRegression``` with L2 regularization in Scala and R; the Python ML ```LinearRegression``` should likewise support setting the solver ("auto", "normal", "l-bfgs").
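For users, the change means the solver can be chosen either as a keyword argument to the constructor/```setParams``` or through the accessors generated from the new ```HasSolver``` mixin. A minimal usage sketch, not part of the patch (the app name and toy data are illustrative, mirroring the module doctests):

```python
# Illustrative sketch only: exercises the new `solver` parameter on the
# Python LinearRegression added by this commit.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import Vectors
from pyspark.ml.regression import LinearRegression

sc = SparkContext(appName="LinearRegressionSolverExample")  # hypothetical app name
sqlContext = SQLContext(sc)

# Toy data in the same shape as the doctest in regression.py.
df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)),
     (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])

# The solver can be chosen at construction time (or via setParams) ...
lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
model = lr.fit(df)

# ... or through the setter generated from the HasSolver mixin.
lr2 = LinearRegression(maxIter=5, regParam=0.0).setSolver("l-bfgs")

print(lr.getSolver())   # 'normal'
print(lr2.getSolver())  # 'l-bfgs'

sc.stop()
```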

Author: Yanbo Liang <yb...@gmail.com>

Closes #9328 from yanboliang/spark-11367.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f92b7b98
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f92b7b98
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f92b7b98

Branch: refs/heads/master
Commit: f92b7b98e9998a6069996cc66ca26cbfa695fce5
Parents: fba9e95
Author: Yanbo Liang <yb...@gmail.com>
Authored: Wed Oct 28 08:54:20 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Oct 28 08:54:20 2015 -0700

----------------------------------------------------------------------
 .../pyspark/ml/param/_shared_params_code_gen.py |  4 ++-
 python/pyspark/ml/param/shared.py               | 28 ++++++++++++++++++++
 python/pyspark/ml/regression.py                 | 27 +++++--------------
 3 files changed, 37 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f92b7b98/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 7143d56..070c5db 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -135,7 +135,9 @@ if __name__ == "__main__":
          "values >= 0. The class with largest value p/t is predicted, where p is the original " +
          "probability of that class and t is the class' threshold.", None),
         ("weightCol", "weight column name. If this is not set or empty, we treat " +
-         "all instance weights as 1.0.", None)]
+         "all instance weights as 1.0.", None),
+        ("solver", "the solver algorithm for optimization. If this is not set or empty, " +
+         "default value is 'auto'.", "'auto'")]
 
     code = []
     for name, doc, defaultValueStr in shared:

http://git-wip-us.apache.org/repos/asf/spark/blob/f92b7b98/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 3a58ac8..4bdf2a8 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -597,6 +597,34 @@ class HasWeightCol(Params):
         return self.getOrDefault(self.weightCol)
 
 
+class HasSolver(Params):
+    """
+    Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.")
+
+    def __init__(self):
+        super(HasSolver, self).__init__()
+        #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
+        self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.")
+        self._setDefault(solver='auto')
+
+    def setSolver(self, value):
+        """
+        Sets the value of :py:attr:`solver`.
+        """
+        self._paramMap[self.solver] = value
+        return self
+
+    def getSolver(self):
+        """
+        Gets the value of solver or its default value.
+        """
+        return self.getOrDefault(self.solver)
+
+
 class DecisionTreeParams(Params):
     """
     Mixin for Decision Tree parameters.

http://git-wip-us.apache.org/repos/asf/spark/blob/f92b7b98/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index eeb18b3..dc68815 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -33,7 +33,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
 @inherit_doc
 class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                        HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
-                       HasStandardization):
+                       HasStandardization, HasSolver):
     """
     Linear regression.
 
@@ -50,7 +50,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     >>> df = sqlContext.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
-    >>> lr = LinearRegression(maxIter=5, regParam=0.0)
+    >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
     >>> model = lr.fit(df)
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
@@ -73,11 +73,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                 standardization=True):
+                 standardization=True, solver="auto"):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                 standardization=True)
+                 standardization=True, solver="auto")
         """
         super(LinearRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -90,11 +90,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     @since("1.4.0")
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                  standardization=True):
+                  standardization=True, solver="auto"):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                  standardization=True)
+                  standardization=True, solver="auto")
         Sets params for linear regression.
         """
         kwargs = self.setParams._input_kwargs
@@ -103,21 +103,6 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     def _create_model(self, java_model):
         return LinearRegressionModel(java_model)
 
-    @since("1.4.0")
-    def setElasticNetParam(self, value):
-        """
-        Sets the value of :py:attr:`elasticNetParam`.
-        """
-        self._paramMap[self.elasticNetParam] = value
-        return self
-
-    @since("1.4.0")
-    def getElasticNetParam(self):
-        """
-        Gets the value of elasticNetParam or its default value.
-        """
-        return self.getOrDefault(self.elasticNetParam)
-
 
 class LinearRegressionModel(JavaModel):
     """

