You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/02/26 06:09:04 UTC

spark git commit: [SPARK-13033] [ML] [PYSPARK] Add import/export for ml.regression

Repository: spark
Updated Branches:
  refs/heads/master 90d07154c -> f3be369ef


[SPARK-13033] [ML] [PYSPARK] Add import/export for ml.regression

Add export/import for all estimators and transformers (which have a Scala implementation) under pyspark/ml/regression.py.

yanboliang Please help to review.
For doctest, I thought it's enough to add one since it's common usage. But I can add it to all of them if we want.

Author: Tommy YU <tu...@163.com>

Closes #11000 from Wenpei/spark-13033-ml.regression-exprot-import and squashes the following commits:

3646b36 [Tommy YU] address review comments
9cddc98 [Tommy YU] change base on review and pr 11197
cc61d9d [Tommy YU] remove default parameter set
19535d4 [Tommy YU] add export/import to regression
44a9dc2 [Tommy YU] add import/export for ml.regression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f3be369e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f3be369e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f3be369e

Branch: refs/heads/master
Commit: f3be369ef7b78944ce9cafc94b19df52772cea58
Parents: 90d0715
Author: Tommy YU <tu...@163.com>
Authored: Thu Feb 25 21:09:02 2016 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Feb 25 21:09:02 2016 -0800

----------------------------------------------------------------------
 python/pyspark/ml/regression.py | 34 ++++++++++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f3be369e/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index de4a751..6b994fe 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
 
 @inherit_doc
 class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-                         HasWeightCol):
+                         HasWeightCol, MLWritable, MLReadable):
     """
     .. note:: Experimental
 
@@ -172,6 +172,18 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     0.0
     >>> model.boundaries
     DenseVector([0.0, 1.0])
+    >>> ir_path = temp_path + "/ir"
+    >>> ir.save(ir_path)
+    >>> ir2 = IsotonicRegression.load(ir_path)
+    >>> ir2.getIsotonic()
+    True
+    >>> model_path = temp_path + "/ir_model"
+    >>> model.save(model_path)
+    >>> model2 = IsotonicRegressionModel.load(model_path)
+    >>> model.boundaries == model2.boundaries
+    True
+    >>> model.predictions == model2.predictions
+    True
     """
 
     isotonic = \
@@ -237,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         return self.getOrDefault(self.featureIndex)
 
 
-class IsotonicRegressionModel(JavaModel):
+class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable):
     """
     .. note:: Experimental
 
@@ -663,7 +675,7 @@ class GBTRegressionModel(TreeEnsembleModels):
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-                            HasFitIntercept, HasMaxIter, HasTol):
+                            HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable):
     """
     Accelerated Failure Time (AFT) Model Survival Regression
 
@@ -690,6 +702,20 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     |  0.0|(1,[],[])|   0.0|       1.0|
     +-----+---------+------+----------+
     ...
+    >>> aftsr_path = temp_path + "/aftsr"
+    >>> aftsr.save(aftsr_path)
+    >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
+    >>> aftsr2.getMaxIter()
+    100
+    >>> model_path = temp_path + "/aftsr_model"
+    >>> model.save(model_path)
+    >>> model2 = AFTSurvivalRegressionModel.load(model_path)
+    >>> model.coefficients == model2.coefficients
+    True
+    >>> model.intercept == model2.intercept
+    True
+    >>> model.scale == model2.scale
+    True
 
     .. versionadded:: 1.6.0
     """
@@ -787,7 +813,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         return self.getOrDefault(self.quantilesCol)
 
 
-class AFTSurvivalRegressionModel(JavaModel):
+class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by AFTSurvivalRegression.
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org