You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/10/06 21:43:32 UTC

spark git commit: [SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression

Repository: spark
Updated Branches:
  refs/heads/master e97836015 -> 5952bdb7d


[SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression

Implement Python API for AFTSurvivalRegression

Author: vectorijk <ji...@gmail.com>

Closes #8926 from vectorijk/spark-10688.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5952bdb7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5952bdb7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5952bdb7

Branch: refs/heads/master
Commit: 5952bdb7df20d007d59f82261095faca3822c6f6
Parents: e978360
Author: vectorijk <ji...@gmail.com>
Authored: Tue Oct 6 12:43:28 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Oct 6 12:43:28 2015 -0700

----------------------------------------------------------------------
 python/pyspark/ml/regression.py | 171 ++++++++++++++++++++++++++++++++++-
 1 file changed, 169 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5952bdb7/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 21d454f..a0f7f54 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -22,8 +22,10 @@ from pyspark.ml.param.shared import *
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
-           'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
+# Public API: each estimator is listed next to the model it fits, in alphabetical order.
+__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'DecisionTreeRegressor',
+           'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel',
+           'LinearRegression', 'LinearRegressionModel',
            'RandomForestRegressor', 'RandomForestRegressionModel']
 
 
@@ -609,6 +611,171 @@ class GBTRegressionModel(TreeEnsembleModels):
     """
 
 
+@inherit_doc
+class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+                            HasFitIntercept, HasMaxIter, HasTol):
+    """
+    Accelerated Failure Time (AFT) Model Survival Regression
+
+    Fit a parametric AFT survival regression model based on the Weibull distribution
+    of the survival time.
+
+    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame([
+    ...     (1.0, Vectors.dense(1.0), 1.0),
+    ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
+    >>> aftsr = AFTSurvivalRegression()
+    >>> model = aftsr.fit(df)
+    >>> model.predict(Vectors.dense(6.3))
+    1.0
+    >>> model.predictQuantiles(Vectors.dense(6.3))
+    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
+    >>> model.transform(df).show()
+    +-----+---------+------+----------+
+    |label| features|censor|prediction|
+    +-----+---------+------+----------+
+    |  1.0|    [1.0]|   1.0|       1.0|
+    |  0.0|(1,[],[])|   0.0|       1.0|
+    +-----+---------+------+----------+
+    ...
+
+    .. versionadded:: 1.6.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    censorCol = Param(Params._dummy(), "censorCol",
+                      "censor column name. The value of this column could be 0 or 1. " +
+                      "If the value is 1, it means the event has occurred i.e. " +
+                      "uncensored; otherwise censored.")
+    quantileProbabilities = \
+        Param(Params._dummy(), "quantileProbabilities",
+              "quantile probabilities array. Values of the quantile probabilities array " +
+              "should be in the range (0, 1) and the array should be non-empty.")
+    quantilesCol = Param(Params._dummy(), "quantilesCol",
+                         "quantiles column name. This column will output quantiles of " +
+                         "corresponding quantileProbabilities if it is set.")
+
+    @keyword_only
+    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                 quantileProbabilities=None, quantilesCol=None):
+        """
+        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+                 quantilesCol=None):
+        """
+        super(AFTSurvivalRegression, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
+        #: Param for censor column name
+        self.censorCol = Param(self, "censorCol",
+                               "censor column name. The value of this column could be 0 or 1. " +
+                               "If the value is 1, it means the event has occurred i.e. " +
+                               "uncensored; otherwise censored.")
+        #: Param for quantile probabilities array
+        self.quantileProbabilities = \
+            Param(self, "quantileProbabilities",
+                  "quantile probabilities array. Values of the quantile probabilities array " +
+                  "should be in the range (0, 1) and the array should be non-empty.")
+        #: Param for quantiles column name
+        self.quantilesCol = Param(self, "quantilesCol",
+                                  "quantiles column name. This column will output quantiles of " +
+                                  "corresponding quantileProbabilities if it is set.")
+        self._setDefault(censorCol="censor", maxIter=100, tol=1E-6,
+                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("1.6.0")
+    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                  quantileProbabilities=None, quantilesCol=None):
+        """
+        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+                  quantilesCol=None):
+        """
+        kwargs = dict(self.setParams._input_kwargs)
+        # Set only the params passed in; quantileProbabilities=None restores its default.
+        if "quantileProbabilities" in kwargs and kwargs["quantileProbabilities"] is None:
+            del kwargs["quantileProbabilities"]
+            self._paramMap.pop(self.quantileProbabilities, None)
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return AFTSurvivalRegressionModel(java_model)
+
+    @since("1.6.0")
+    def setCensorCol(self, value):
+        """
+        Sets the value of :py:attr:`censorCol`.
+        """
+        self._paramMap[self.censorCol] = value
+        return self
+
+    @since("1.6.0")
+    def getCensorCol(self):
+        """
+        Gets the value of censorCol or its default value.
+        """
+        return self.getOrDefault(self.censorCol)
+
+    @since("1.6.0")
+    def setQuantileProbabilities(self, value):
+        """
+        Sets the value of :py:attr:`quantileProbabilities`.
+        """
+        self._paramMap[self.quantileProbabilities] = value
+        return self
+
+    @since("1.6.0")
+    def getQuantileProbabilities(self):
+        """
+        Gets the value of quantileProbabilities or its default value.
+        """
+        return self.getOrDefault(self.quantileProbabilities)
+
+    @since("1.6.0")
+    def setQuantilesCol(self, value):
+        """
+        Sets the value of :py:attr:`quantilesCol`.
+        """
+        self._paramMap[self.quantilesCol] = value
+        return self
+
+    @since("1.6.0")
+    def getQuantilesCol(self):
+        """
+        Gets the value of quantilesCol or its default value.
+        """
+        return self.getOrDefault(self.quantilesCol)
+
+
+class AFTSurvivalRegressionModel(JavaModel):
+    """
+    Model fitted by AFTSurvivalRegression.
+
+    .. versionadded:: 1.6.0
+    """
+
+    @since("1.6.0")
+    def predictQuantiles(self, features):
+        """
+        Predicted quantiles for a features vector, one per quantileProbabilities entry.
+        """
+        return self._call_java("predictQuantiles", features)
+
+    @since("1.6.0")
+    def predict(self, features):
+        """Predicted value for the given features vector."""
+        return self._call_java("predict", features)
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org