You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/09/18 06:37:15 UTC

spark git commit: [SPARK-8518] [ML] Log-linear models for survival analysis

Repository: spark
Updated Branches:
  refs/heads/master 0f5ef6dfa -> 98f1ea67d


[SPARK-8518] [ML] Log-linear models for survival analysis

[Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easily parallelized method of survival analysis for censored survival data. It is a log-linear model based on the Weibull distribution of the survival time.
Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction; I have only implemented the type ```response```, which is the default in R.

Author: Yanbo Liang <yb...@gmail.com>

Closes #8611 from yanboliang/spark-8518.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/98f1ea67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/98f1ea67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/98f1ea67

Branch: refs/heads/master
Commit: 98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b
Parents: 0f5ef6d
Author: Yanbo Liang <yb...@gmail.com>
Authored: Thu Sep 17 21:37:10 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Sep 17 21:37:10 2015 -0700

----------------------------------------------------------------------
 .../ml/regression/AFTSurvivalRegression.scala   | 449 +++++++++++++++++++
 .../regression/AFTSurvivalRegressionSuite.scala | 311 +++++++++++++
 2 files changed, 760 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/98f1ea67/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
new file mode 100644
index 0000000..5b25db6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
+
+import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.annotation.{Since, Experimental}
+import org.apache.spark.ml.{Model, Estimator}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for accelerated failure time (AFT) regression.
+ *
+ * Shared by [[AFTSurvivalRegression]] (training) and [[AFTSurvivalRegressionModel]]
+ * (prediction).
+ */
+private[regression] trait AFTSurvivalRegressionParams extends Params
+  with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
+  with HasTol with HasFitIntercept {
+
+  /**
+   * Param for censor column name.
+   * The value of this column could be 0 or 1.
+   * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.
+   * Default: "censor".
+   * @group param
+   */
+  @Since("1.6.0")
+  final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name")
+
+  /** @group getParam */
+  @Since("1.6.0")
+  def getCensorCol: String = $(censorCol)
+  setDefault(censorCol -> "censor")
+
+  /**
+   * Param for quantile probabilities array.
+   * Every probability in the array must lie in the closed range [0, 1].
+   * No default is set, so it must be set explicitly before quantiles are predicted.
+   * @group param
+   */
+  @Since("1.6.0")
+  final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(
+    this,
+    "quantileProbabilities",
+    "quantile probabilities array",
+    (probabilities: Array[Double]) => probabilities.forall(ParamValidators.inRange(0, 1)))
+
+  /** @group getParam */
+  @Since("1.6.0")
+  def getQuantileProbabilities: Array[Double] = $(quantileProbabilities)
+
+  /** True when [[quantileProbabilities]] is defined and holds at least one probability. */
+  protected[regression] def hasQuantileProbabilities: Boolean =
+    isDefined(quantileProbabilities) && $(quantileProbabilities).nonEmpty
+
+  /**
+   * Validates and transforms the input schema with the provided param map.
+   *
+   * The features column must always be a vector column. The censor and label columns are
+   * only checked (as double columns) when fitting, because prediction does not read them.
+   *
+   * @param schema input schema
+   * @param fitting whether this is in fitting or prediction
+   * @return the input schema with a double prediction column appended
+   */
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    if (fitting) {
+      Seq($(censorCol), $(labelCol)).foreach { trainingOnlyCol =>
+        SchemaUtils.checkColumnType(schema, trainingOnlyCol, DoubleType)
+      }
+    }
+    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Fit a parametric survival regression model named accelerated failure time (AFT) model
+ * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]])
+ * based on the Weibull distribution of the survival time.
+ */
+@Experimental
+@Since("1.6.0")
+class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
+  extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging {
+
+  @Since("1.6.0")
+  def this() = this(Identifiable.randomUID("aftSurvReg"))
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setCensorCol(value: String): this.type = set(censorCol, value)
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /**
+   * Set if we should fit the intercept
+   * Default is true.
+   * @group setParam
+   */
+  @Since("1.6.0")
+  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+  setDefault(fitIntercept -> true)
+
+  /**
+   * Set the maximum number of iterations.
+   * Default is 100.
+   * @group setParam
+   */
+  @Since("1.6.0")
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setDefault(maxIter -> 100)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   * Default is 1E-6.
+   * @group setParam
+   */
+  @Since("1.6.0")
+  def setTol(value: Double): this.type = set(tol, value)
+  setDefault(tol -> 1E-6)
+
+  /**
+   * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
+   * and put them in an RDD with strong types.
+   */
+  protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
+    dataset.select($(featuresCol), $(labelCol), $(censorCol)).map {
+      case Row(features: Vector, label: Double, censor: Double) =>
+        AFTPoint(features, label, censor)
+    }
+  }
+
+  /**
+   * Fits an AFT model with Breeze L-BFGS, optimizing over (log(sigma), intercept,
+   * coefficients) starting from all zeros.
+   *
+   * @param dataset training data with features, label and censor columns
+   * @return the fitted model, with this estimator as its parent and params copied over
+   * @throws IllegalArgumentException if the dataset has no rows
+   * @throws SparkException if the optimizer produced no iteration states
+   */
+  @Since("1.6.0")
+  override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+    validateAndTransformSchema(dataset.schema, fitting = true)
+
+    // Determine the feature dimension first so that an empty dataset fails with a clear
+    // message instead of an ArrayIndexOutOfBoundsException from indexing an empty array.
+    val firstFeatures = dataset.select($(featuresCol)).take(1)
+    require(firstFeatures.nonEmpty,
+      "AFTSurvivalRegression requires a non-empty training dataset.")
+    val numFeatures = firstFeatures(0).getAs[Vector](0).size
+
+    val instances = extractAFTPoints(dataset)
+    // NOTE(review): this only detects persistence of `dataset.rdd` itself; confirm whether a
+    // DataFrame-level cache is visible here, otherwise the points may be cached twice.
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val weights = try {
+      val costFun = new AFTCostFun(instances, $(fitIntercept))
+      val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+
+      /*
+         The weights vector has three parts:
+         the first element: Double, log(sigma), the log of scale parameter
+         the second element: Double, intercept of the beta parameter
+         the third to the end elements: Doubles, regression coefficients vector of the beta parameter
+       */
+      val initialWeights = Vectors.zeros(numFeatures + 2)
+
+      val states = optimizer.iterations(new CachedDiffFunction(costFun),
+        initialWeights.toBreeze.toDenseVector)
+
+      // Only the final optimizer state is needed; drain the iterator to reach it.
+      var state: optimizer.State = null
+      while (states.hasNext) {
+        state = states.next()
+      }
+      if (state == null) {
+        val msg = s"${optimizer.getClass.getName} failed."
+        throw new SparkException(msg)
+      }
+
+      state.x.toArray.clone()
+    } finally {
+      // Release the cached training points even when optimization throws.
+      if (handlePersistence) instances.unpersist()
+    }
+
+    val coefficients = Vectors.dense(weights.slice(2, weights.length))
+    val intercept = weights(1)
+    val scale = math.exp(weights(0))
+    val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+    copyValues(model.setParent(this))
+  }
+
+  @Since("1.6.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = true)
+  }
+
+  @Since("1.6.0")
+  override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model produced by [[AFTSurvivalRegression]].
+ *
+ * Predictions are on the original time scale. The response for a feature vector x is
+ * exp(coefficients dot x + intercept). Quantiles come from a Weibull distribution whose
+ * scale is that response and whose shape is 1 / scale.
+ */
+@Experimental
+@Since("1.6.0")
+class AFTSurvivalRegressionModel private[ml] (
+    @Since("1.6.0") override val uid: String,
+    @Since("1.6.0") val coefficients: Vector,
+    @Since("1.6.0") val intercept: Double,
+    @Since("1.6.0") val scale: Double)
+  extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams {
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
+
+  /** exp of the linear predictor, i.e. the Weibull scale parameter (lambda) for `features`. */
+  private def weibullLambda(features: Vector): Double =
+    math.exp(BLAS.dot(coefficients, features) + intercept)
+
+  /**
+   * Predicts lifetime quantiles for `features`, one per entry of [[quantileProbabilities]],
+   * in the same order.
+   *
+   * @param features feature vector of the subject
+   * @return vector of predicted quantiles
+   * @throws IllegalArgumentException if quantileProbabilities is unset or empty
+   */
+  @Since("1.6.0")
+  def predictQuantiles(features: Vector): Vector = {
+    require(hasQuantileProbabilities,
+      "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array")
+    val lambda = weibullLambda(features)
+    // Weibull shape parameter of the lifetime distribution.
+    val k = 1 / scale
+    // Inverse Weibull CDF: lambda * (-log(1 - q))^(1 / k).
+    val quantiles = $(quantileProbabilities).map { prob =>
+      lambda * math.exp(math.log(-math.log(1 - prob)) / k)
+    }
+    Vectors.dense(quantiles)
+  }
+
+  /**
+   * Predicts the response for `features`: exp(coefficients dot features + intercept).
+   */
+  @Since("1.6.0")
+  def predict(features: Vector): Double = weibullLambda(features)
+
+  /** Appends the prediction column computed by [[predict]] from the features column. */
+  @Since("1.6.0")
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema)
+    val featuresColumn = col($(featuresCol))
+    val predictUDF = udf((v: Vector) => predict(v))
+    dataset.withColumn($(predictionCol), predictUDF(featuresColumn))
+  }
+
+  @Since("1.6.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = false)
+  }
+
+  @Since("1.6.0")
+  override def copy(extra: ParamMap): AFTSurvivalRegressionModel = {
+    val copied = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+    copyValues(copied, extra).setParent(parent)
+  }
+}
+
+/**
+ * AFTAggregator computes the gradient and loss for an AFT loss function,
+ * as used in AFT survival regression for samples in sparse or dense vector in an online fashion.
+ *
+ * The loss function and likelihood function under the AFT model are based on:
+ * Lawless, J. F., Statistical Models and Methods for Lifetime Data,
+ * New York: John Wiley & Sons, Inc. 2003.
+ *
+ * Two AFTAggregators can be merged together to obtain the loss and gradient of
+ * the corresponding joint dataset.
+ *
+ * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n,
+ * with possible right-censoring, the likelihood function under the AFT model is given as
+ * {{{
+ *   L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}
+ *   (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}
+ *   (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
+ * }}}
+ * where \delta_{i} indicates whether the event has occurred, i.e. whether t_{i} is uncensored.
+ * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function
+ * assumes the form
+ * {{{
+ *   \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+
+ *   \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}]
+ * }}}
+ * where S_{0}(\epsilon_{i}) is the baseline survivor function,
+ * and f_{0}(\epsilon_{i}) is the corresponding density function.
+ *
+ * The most commonly used log-linear survival regression method is based on the Weibull
+ * distribution of the survival time. A Weibull distribution for the lifetime corresponds
+ * to an extreme value distribution for the log of the lifetime,
+ * and the S_{0}(\epsilon) function is
+ * {{{
+ *   S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
+ * }}}
+ * the f_{0}(\epsilon_{i}) function is
+ * {{{
+ *   f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}})
+ * }}}
+ * The log-likelihood function for Weibull distribution of lifetime is
+ * {{{
+ *   \iota(\beta,\sigma)=
+ *   -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}]
+ * }}}
+ * Since maximizing the log-likelihood (maximum likelihood estimation) is equivalent to
+ * minimizing the negative log-likelihood, the loss function we optimize is
+ * -\iota(\beta,\sigma), averaged over the number of data points.
+ * The gradient functions for \beta and \log\sigma respectively are
+ * {{{
+ *   \frac{\partial (-\iota)}{\partial \beta}=
+ *   \sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma}
+ * }}}
+ * {{{
+ *   \frac{\partial (-\iota)}{\partial (\log\sigma)}=
+ *   \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
+ * }}}
+ * @param weights The log of scale parameter, the intercept and
+ *                regression coefficients corresponding to the features, in that order.
+ * @param fitIntercept Whether to fit an intercept term.
+ */
+private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean)
+  extends Serializable {
+
+  // beta is the intercept and regression coefficients to the covariates
+  private val beta = weights.slice(1, weights.length)
+  // sigma is the scale parameter of the AFT model
+  private val sigma = math.exp(weights(0))
+
+  private var totalCnt: Long = 0L
+  private var lossSum = 0.0
+  private var gradientBetaSum = BDV.zeros[Double](beta.length)
+  private var gradientLogSigmaSum = 0.0
+
+  /** Number of data points added to this aggregator, directly or via merge. */
+  def count: Long = totalCnt
+
+  /** Mean negative log-likelihood over the added points; 1.0 if none were added. */
+  def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
+
+  /**
+   * Mean gradient of the loss with respect to (log(sigma), beta), laid out like `weights`.
+   * Returns a zero vector when no points were added instead of dividing by zero.
+   */
+  def gradient: BDV[Double] = {
+    if (totalCnt == 0) {
+      BDV.zeros[Double](beta.length + 1)
+    } else {
+      BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
+        gradientBetaSum / totalCnt.toDouble)
+    }
+  }
+
+  /**
+   * Add a new training data to this AFTAggregator, and update the loss and gradient
+   * of the objective function.
+   *
+   * @param data The AFTPoint representation for one data point to be added into this aggregator.
+   * @return This AFTAggregator object.
+   * @throws IllegalArgumentException if the label (lifetime) is not strictly positive
+   */
+  def add(data: AFTPoint): this.type = {
+    val ti = data.label
+    val delta = data.censor
+    // log(ti) is taken below; a non-positive or NaN lifetime would silently turn the
+    // loss and gradient into NaN/Infinity.
+    require(ti > 0.0, s"The lifetime or label should be greater than 0, but got $ti.")
+
+    // Prepend the intercept term: 1.0 when fitting an intercept, otherwise 0.0 so that the
+    // intercept coefficient receives a zero gradient.
+    // TODO: Don't create a new xi vector each time.
+    val interceptTerm = if (fitIntercept) 1.0 else 0.0
+    val xi = Vectors.dense(Array(interceptTerm) ++ data.features.toArray).toBreeze
+    val epsilon = (math.log(ti) - beta.dot(xi)) / sigma
+    val expEpsilon = math.exp(epsilon)
+
+    lossSum += math.log(sigma) * delta
+    lossSum += (expEpsilon - delta * epsilon)
+
+    // Sanity check (should never occur):
+    assert(!lossSum.isInfinity,
+      s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+
+    gradientBetaSum += xi * (delta - expEpsilon) / sigma
+    gradientLogSigmaSum += delta + (delta - expEpsilon) * epsilon
+
+    totalCnt += 1
+    this
+  }
+
+  /**
+   * Merge another AFTAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that it's in place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other AFTAggregator to be merged.
+   * @return This AFTAggregator object.
+   */
+  def merge(other: AFTAggregator): this.type = {
+    // Skip empty aggregators. The check must be on `other`: checking `this` would discard
+    // everything merged into an aggregator that is itself still empty (for example the
+    // result of an empty partition).
+    if (other.totalCnt != 0) {
+      totalCnt += other.totalCnt
+      lossSum += other.lossSum
+
+      gradientBetaSum += other.gradientBetaSum
+      gradientLogSigmaSum += other.gradientLogSigmaSum
+    }
+    this
+  }
+}
+
+/**
+ * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost.
+ * It returns the loss and gradient at a particular point (coefficients).
+ * It's used in Breeze's optimization routines (L-BFGS in [[AFTSurvivalRegression]]).
+ *
+ * @param data training points, aggregated over on every call to `calculate`
+ * @param fitIntercept whether the intercept term is fitted
+ */
+private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
+  extends DiffFunction[BDV[Double]] {
+
+  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
+    // Fold the points into per-partition aggregators, then combine them with treeAggregate.
+    val zero = new AFTAggregator(coefficients, fitIntercept)
+    val summary = data.treeAggregate(zero)(
+      seqOp = (agg, point) => agg.add(point),
+      combOp = (left, right) => left.merge(right))
+
+    (summary.loss, summary.gradient)
+  }
+}
+
+/**
+ * Class that represents the (features, label, censor) of a data point.
+ *
+ * @param features Feature vector for this data point.
+ * @param label Label (lifetime) for this data point.
+ * @param censor Indicator of whether the event has occurred. If the value is 1, the event
+ *               has occurred, i.e. the point is uncensored; if 0, it is censored. Any other
+ *               value is rejected at construction.
+ */
+private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) {
+  require(censor == 0.0 || censor == 1.0, "censor of class AFTPoint must be 1.0 or 0.0")
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/98f1ea67/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
new file mode 100644
index 0000000..ca7140a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Row, DataFrame}
+
+/**
+ * Tests for [[AFTSurvivalRegression]] and [[AFTSurvivalRegressionModel]].
+ *
+ * Expected coefficients, scales and predictions are the outputs of R's
+ * `survival::survreg(..., dist='weibull')` and `predict`, quoted in each test's comments.
+ */
+class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  @transient var datasetUnivariate: DataFrame = _
+  @transient var datasetMultivariate: DataFrame = _
+
+  // Both datasets use 1000 points and seed 42. The R reference values in the tests below were
+  // presumably computed from CSV exports of exactly this data, so these arguments must not
+  // change without regenerating them.
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    datasetUnivariate = sqlContext.createDataFrame(
+      sc.parallelize(generateAFTInput(
+        1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)))
+    datasetMultivariate = sqlContext.createDataFrame(
+      sc.parallelize(generateAFTInput(
+        2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+  }
+
+  test("params") {
+    ParamsSuite.checkParams(new AFTSurvivalRegression)
+    val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0)
+    ParamsSuite.checkParams(model)
+  }
+
+  test("aft survival regression: default params") {
+    val aftr = new AFTSurvivalRegression
+    assert(aftr.getLabelCol === "label")
+    assert(aftr.getFeaturesCol === "features")
+    assert(aftr.getPredictionCol === "prediction")
+    assert(aftr.getCensorCol === "censor")
+    assert(aftr.getFitIntercept)
+    assert(aftr.getMaxIter === 100)
+    assert(aftr.getTol === 1E-6)
+    val model = aftr.fit(datasetUnivariate)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(model)
+
+    model.transform(datasetUnivariate)
+      .select("label", "prediction")
+      .collect()
+    assert(model.getFeaturesCol === "features")
+    assert(model.getPredictionCol === "prediction")
+    assert(model.intercept !== 0.0)
+    assert(model.hasParent)
+  }
+
+  /**
+   * Generates synthetic right-censored survival data.
+   *
+   * Each feature i is drawn uniformly with mean xMean(i) and variance xVariance(i), by
+   * rescaling a uniform [0, 1) draw. Each lifetime is drawn from
+   * `WeibullGenerator(weibullShape, weibullScale)` and each censoring time from
+   * `ExponentialGenerator(exponentialMean)`. The censor indicator is 1.0 when the lifetime
+   * is <= the censoring time and 0.0 otherwise.
+   *
+   * NOTE(review): the label is always the raw Weibull lifetime, even for censored points,
+   * rather than min(lifetime, censoring time). The R reference values were fitted to this
+   * same data, so changing it would require regenerating them.
+   * NOTE(review): the Weibull and exponential generators share `seed`; the R reference
+   * values depend on this exact draw sequence, so keep the call order unchanged.
+   *
+   * @param numFeatures number of features per point; xMean and xVariance need this many entries
+   * @param nPoints number of points to generate
+   * @param seed seed for the feature, lifetime and censoring-time generators
+   * @return the generated points, in generation order
+   */
+  def generateAFTInput(
+      numFeatures: Int,
+      xMean: Array[Double],
+      xVariance: Array[Double],
+      nPoints: Int,
+      seed: Int,
+      weibullShape: Double,
+      weibullScale: Double,
+      exponentialMean: Double): Seq[AFTPoint] = {
+
+    // 1.0 (event observed) when lifetime x does not exceed censoring time y, else 0.0.
+    def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 }
+
+    val weibull = new WeibullGenerator(weibullShape, weibullScale)
+    weibull.setSeed(seed)
+
+    val exponential = new ExponentialGenerator(exponentialMean)
+    exponential.setSeed(seed)
+
+    val rnd = new Random(seed)
+    val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble()))
+
+    // Rescale each uniform [0, 1) draw (variance 1/12) to the requested mean and variance.
+    x.foreach { v =>
+      var i = 0
+      val len = v.length
+      while (i < len) {
+        v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+        i += 1
+      }
+    }
+    // (lifetime, censoring time) pairs, drawn alternately from the two generators.
+    val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) }
+
+    y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) }
+  }
+
+  test("aft survival regression with univariate") {
+    val trainer = new AFTSurvivalRegression
+    val model = trainer.fit(datasetUnivariate)
+
+    /*
+       Using the following R code to load the data and train the model using survival package.
+
+       library("survival")
+       data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+       features <- data$V1
+       censor <- data$V2
+       label <- data$V3
+       sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull')
+       summary(sr.fit)
+
+                    Value Std. Error      z        p
+       (Intercept)  1.759     0.4141  4.247 2.16e-05
+       features    -0.039     0.0735 -0.531 5.96e-01
+       Log(scale)   0.344     0.0379  9.073 1.16e-19
+
+       Scale= 1.41
+
+       Weibull distribution
+       Loglik(model)= -1152.2   Loglik(intercept only)= -1152.3
+           Chisq= 0.28 on 1 degrees of freedom, p= 0.6
+       Number of Newton-Raphson Iterations: 5
+       n= 1000
+     */
+    val coefficientsR = Vectors.dense(-0.039)
+    val interceptR = 1.759
+    val scaleR = 1.41
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.coefficients ~== coefficientsR relTol 1E-3)
+    assert(model.scale ~== scaleR relTol 1E-3)
+
+    /*
+       Using the following R code to predict.
+
+       testdata <- list(features=6.559282795753792)
+       responsePredict <- predict(sr.fit, newdata=testdata)
+       responsePredict
+
+              1
+       4.494763
+
+       quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+       quantilePredict
+
+       [1]  0.1879174  2.6801195 14.5779394
+     */
+    val features = Vectors.dense(6.559282795753792)
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val responsePredictR = 4.494763
+    val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394)
+
+    assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+    model.setQuantileProbabilities(quantileProbabilities)
+    assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+    // transform() must agree row by row with the closed-form response exp(x . beta + b).
+    model.transform(datasetUnivariate).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+
+  test("aft survival regression with multivariate") {
+    val trainer = new AFTSurvivalRegression
+    val model = trainer.fit(datasetMultivariate)
+
+    /*
+       Using the following R code to load the data and train the model using survival package.
+
+       library("survival")
+       data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+       feature1 <- data$V1
+       feature2 <- data$V2
+       censor <- data$V3
+       label <- data$V4
+       sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull')
+       summary(sr.fit)
+
+                     Value Std. Error      z        p
+       (Intercept)  1.9206     0.1057 18.171 8.78e-74
+       feature1    -0.0844     0.0611 -1.381 1.67e-01
+       feature2     0.0677     0.0468  1.447 1.48e-01
+       Log(scale)  -0.0236     0.0436 -0.542 5.88e-01
+
+       Scale= 0.977
+
+       Weibull distribution
+       Loglik(model)= -1070.7   Loglik(intercept only)= -1072.7
+           Chisq= 3.91 on 2 degrees of freedom, p= 0.14
+       Number of Newton-Raphson Iterations: 5
+       n= 1000
+     */
+    val coefficientsR = Vectors.dense(-0.0844, 0.0677)
+    val interceptR = 1.9206
+    val scaleR = 0.977
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.coefficients ~== coefficientsR relTol 1E-3)
+    assert(model.scale ~== scaleR relTol 1E-3)
+
+    /*
+       Using the following R code to predict.
+       testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683)
+       responsePredict <- predict(sr.fit, newdata=testdata)
+       responsePredict
+
+              1
+       4.761219
+
+       quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+       quantilePredict
+
+       [1]  0.5287044  3.3285858 10.7517072
+     */
+    val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val responsePredictR = 4.761219
+    val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072)
+
+    assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+    model.setQuantileProbabilities(quantileProbabilities)
+    assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+    // transform() must agree row by row with the closed-form response exp(x . beta + b).
+    model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+
+  test("aft survival regression w/o intercept") {
+    val trainer = new AFTSurvivalRegression().setFitIntercept(false)
+    val model = trainer.fit(datasetMultivariate)
+
+    /*
+       Using the following R code to load the data and train the model using survival package.
+
+       library("survival")
+       data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+       feature1 <- data$V1
+       feature2 <- data$V2
+       censor <- data$V3
+       label <- data$V4
+       sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull')
+       summary(sr.fit)
+
+                   Value Std. Error     z        p
+       feature1    0.896     0.0685  13.1 3.93e-39
+       feature2   -0.709     0.0522 -13.6 5.78e-42
+       Log(scale)  0.420     0.0401  10.5 1.23e-25
+
+       Scale= 1.52
+
+       Weibull distribution
+       Loglik(model)= -1292.4   Loglik(intercept only)= -1072.7
+         Chisq= -439.57 on 1 degrees of freedom, p= 1
+       Number of Newton-Raphson Iterations: 6
+       n= 1000
+     */
+    val coefficientsR = Vectors.dense(0.896, -0.709)
+    val interceptR = 0.0
+    val scaleR = 1.52
+
+    // Without an intercept the fitted intercept must be exactly zero, not merely close.
+    assert(model.intercept === interceptR)
+    assert(model.coefficients ~== coefficientsR relTol 1E-3)
+    assert(model.scale ~== scaleR relTol 1E-3)
+
+    /*
+       Using the following R code to predict.
+       testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683)
+       responsePredict <- predict(sr.fit, newdata=testdata)
+       responsePredict
+
+              1
+       44.54465
+
+       quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+       quantilePredict
+
+       [1]   1.452103  25.506077 158.428600
+     */
+    val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val responsePredictR = 44.54465
+    val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600)
+
+    assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+    model.setQuantileProbabilities(quantileProbabilities)
+    assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+    // transform() must agree row by row with the closed-form response exp(x . beta + b).
+    model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org