You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/10/06 21:43:32 UTC
spark git commit: [SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression
Repository: spark
Updated Branches:
refs/heads/master e97836015 -> 5952bdb7d
[SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression
Implement Python API for AFTSurvivalRegression
Author: vectorijk <ji...@gmail.com>
Closes #8926 from vectorijk/spark-10688.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5952bdb7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5952bdb7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5952bdb7
Branch: refs/heads/master
Commit: 5952bdb7df20d007d59f82261095faca3822c6f6
Parents: e978360
Author: vectorijk <ji...@gmail.com>
Authored: Tue Oct 6 12:43:28 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Oct 6 12:43:28 2015 -0700
----------------------------------------------------------------------
python/pyspark/ml/regression.py | 171 ++++++++++++++++++++++++++++++++++-
1 file changed, 169 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5952bdb7/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 21d454f..a0f7f54 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -22,8 +22,10 @@ from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
# Public estimators and their fitted models exported by
# ``from pyspark.ml.regression import *``, grouped as estimator/model pairs.
__all__ = [
    'AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
    'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
    'GBTRegressor', 'GBTRegressionModel',
    'LinearRegression', 'LinearRegressionModel',
    'RandomForestRegressor', 'RandomForestRegressionModel',
]
@@ -609,6 +611,171 @@ class GBTRegressionModel(TreeEnsembleModels):
"""
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                            HasFitIntercept, HasMaxIter, HasTol):
    """
    Accelerated Failure Time (AFT) Model Survival Regression

    Fit a parametric AFT survival regression model based on the Weibull distribution
    of the survival time.

    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0), 1.0),
    ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
    >>> aftsr = AFTSurvivalRegression()
    >>> model = aftsr.fit(df)
    >>> model.predict(Vectors.dense(6.3))
    1.0
    >>> model.predictQuantiles(Vectors.dense(6.3))
    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
    >>> model.transform(df).show()
    +-----+---------+------+----------+
    |label| features|censor|prediction|
    +-----+---------+------+----------+
    |  1.0|    [1.0]|   1.0|       1.0|
    |  0.0|(1,[],[])|   0.0|       1.0|
    +-----+---------+------+----------+
    ...

    .. versionadded:: 1.6.0
    """

    # a placeholder to make it appear in the generated doc
    censorCol = Param(Params._dummy(), "censorCol",
                      "censor column name. The value of this column could be 0 or 1. " +
                      "If the value is 1, it means the event has occurred i.e. " +
                      "uncensored; otherwise censored.")
    quantileProbabilities = \
        Param(Params._dummy(), "quantileProbabilities",
              "quantile probabilities array. Values of the quantile probabilities array " +
              "should be in the range (0, 1) and the array should be non-empty.")
    quantilesCol = Param(Params._dummy(), "quantilesCol",
                         "quantiles column name. This column will output quantiles of " +
                         "corresponding quantileProbabilities if it is set.")

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                 quantileProbabilities=None, quantilesCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                 quantilesCol=None)
        """
        super(AFTSurvivalRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
        #: Param for censor column name
        self.censorCol = Param(self, "censorCol",
                               "censor column name. The value of this column could be 0 or 1. " +
                               "If the value is 1, it means the event has occurred i.e. " +
                               "uncensored; otherwise censored.")
        #: Param for quantile probabilities array
        self.quantileProbabilities = \
            Param(self, "quantileProbabilities",
                  "quantile probabilities array. Values of the quantile probabilities array " +
                  "should be in the range (0, 1) and the array should be non-empty.")
        #: Param for quantiles column name
        self.quantilesCol = Param(self, "quantilesCol",
                                  "quantiles column name. This column will output quantiles of " +
                                  "corresponding quantileProbabilities if it is set.")
        # Register defaults once here. setParams must not re-apply them explicitly,
        # otherwise it would overwrite values the caller set earlier.
        self._setDefault(censorCol="censor",
                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
        # Forward the caller's keyword arguments (recorded by @keyword_only) to setParams.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=None, quantilesCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                  quantilesCol=None)
        Sets params for AFTSurvivalRegression. Params that are not passed keep their current
        (or default) values. Passing None for quantileProbabilities or quantilesCol is
        treated the same as not passing them.
        """
        kwargs = self.setParams._input_kwargs
        # For these two params None means "not given". Drop it rather than storing None,
        # so that a value set earlier, or the default registered in __init__, stays in
        # effect. Previously, omitting quantileProbabilities reset it to the default list.
        optional = ("quantileProbabilities", "quantilesCol")
        kwargs = dict((name, value) for name, value in kwargs.items()
                      if not (value is None and name in optional))
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return AFTSurvivalRegressionModel(java_model)

    @since("1.6.0")
    def setCensorCol(self, value):
        """
        Sets the value of :py:attr:`censorCol`.
        """
        self._paramMap[self.censorCol] = value
        return self

    @since("1.6.0")
    def getCensorCol(self):
        """
        Gets the value of censorCol or its default value.
        """
        return self.getOrDefault(self.censorCol)

    @since("1.6.0")
    def setQuantileProbabilities(self, value):
        """
        Sets the value of :py:attr:`quantileProbabilities`.
        """
        self._paramMap[self.quantileProbabilities] = value
        return self

    @since("1.6.0")
    def getQuantileProbabilities(self):
        """
        Gets the value of quantileProbabilities or its default value.
        """
        return self.getOrDefault(self.quantileProbabilities)

    @since("1.6.0")
    def setQuantilesCol(self, value):
        """
        Sets the value of :py:attr:`quantilesCol`.
        """
        self._paramMap[self.quantilesCol] = value
        return self

    @since("1.6.0")
    def getQuantilesCol(self):
        """
        Gets the value of quantilesCol or its default value.
        """
        return self.getOrDefault(self.quantilesCol)
+
class AFTSurvivalRegressionModel(JavaModel):
    """
    Model fitted by AFTSurvivalRegression.

    Both prediction methods delegate to the wrapped JVM model.

    .. versionadded:: 1.6.0
    """

    @since("1.6.0")
    def predictQuantiles(self, features):
        """
        Predicted quantiles of the survival time for a single feature vector.

        The result holds one value per entry of the quantile probabilities the model was
        fitted with (a DenseVector of 9 values for the defaults; see the estimator's doctest).
        """
        return self._call_java("predictQuantiles", features)

    @since("1.6.0")
    def predict(self, features):
        """
        Predicted value for a single feature vector.
        """
        return self._call_java("predict", features)
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org