You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/12 20:27:18 UTC

spark git commit: [SPARK-13322][ML] AFTSurvivalRegression supports feature standardization

Repository: spark
Updated Branches:
  refs/heads/master 75e05a5a9 -> 101663f1a


[SPARK-13322][ML] AFTSurvivalRegression supports feature standardization

## What changes were proposed in this pull request?
AFTSurvivalRegression should support feature standardization, which improves the convergence rate.
The convergence rate was tested on the [Ovarian](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/ovarian.html) data, a standard dataset that ships with the survival library in R:
* without standardization(before this PR) -> 74 iterations.
* with standardization(after this PR) -> 38 iterations.

After this fix, however, training converges to the same solution with or without ```standardization```. This means that ```standardization = false``` follows the same code path as ```standardization = true```, because leaving the features completely unstandardized causes convergence issues when the features have very different scales. This behavior is the same as in ML [```LinearRegression``` and ```LogisticRegression```](https://issues.apache.org/jira/browse/SPARK-8522). See more discussion about this topic at #11247.
cc mengxr
## How was this patch tested?
unit test.

Author: Yanbo Liang <yb...@gmail.com>

Closes #11365 from yanboliang/spark-13322.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/101663f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/101663f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/101663f1

Branch: refs/heads/master
Commit: 101663f1ae222a919fc40510aa4f2bad22d1be6f
Parents: 75e05a5
Author: Yanbo Liang <yb...@gmail.com>
Authored: Tue Apr 12 11:27:16 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Apr 12 11:27:16 2016 -0700

----------------------------------------------------------------------
 .../ml/regression/AFTSurvivalRegression.scala   | 105 +++++++++++++------
 .../regression/AFTSurvivalRegressionSuite.scala |  22 ++++
 2 files changed, 93 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/101663f1/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index afed1f3..89ba6ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -31,6 +31,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -198,10 +199,20 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept))
+    val featuresSummarizer = {
+      val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
+      val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
+        c1.merge(c2)
+      }
+      instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+    }
+
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+
+    val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
-    val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+    val numFeatures = featuresStd.size
     /*
        The parameters vector has three parts:
        the first element: Double, log(sigma), the log of scale parameter
@@ -230,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
 
     if (handlePersistence) instances.unpersist()
 
-    val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
+    val rawCoefficients = parameters.slice(2, parameters.length)
+    var i = 0
+    while (i < numFeatures) {
+      rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+      i += 1
+    }
+    val coefficients = Vectors.dense(rawCoefficients)
     val intercept = parameters(1)
     val scale = math.exp(parameters(0))
     val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
@@ -434,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
  * @param parameters including three part: The log of scale parameter, the intercept and
  *                regression coefficients corresponding to the features.
  * @param fitIntercept Whether to fit an intercept term.
+ * @param featuresStd The standard deviation values of the features.
  */
-private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
-  extends Serializable {
+private class AFTAggregator(
+    parameters: BDV[Double],
+    fitIntercept: Boolean,
+    featuresStd: Array[Double]) extends Serializable {
 
   // the regression coefficients to the covariates
   private val coefficients = parameters.slice(2, parameters.length)
-  private val intercept = parameters.valueAt(1)
+  private val intercept = parameters(1)
   // sigma is the scale parameter of the AFT model
   private val sigma = math.exp(parameters(0))
 
   private var totalCnt: Long = 0L
   private var lossSum = 0.0
-  private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
-  private var gradientInterceptSum = 0.0
-  private var gradientLogSigmaSum = 0.0
+  // Here we optimize loss function over log(sigma), intercept and coefficients
+  private val gradientSumArray = Array.ofDim[Double](parameters.length)
 
   def count: Long = totalCnt
+  def loss: Double = {
+    require(totalCnt > 0.0, s"The number of instances should be " +
+      s"greater than 0.0, but got $totalCnt.")
+    lossSum / totalCnt
+  }
+  def gradient: BDV[Double] = {
+    require(totalCnt > 0.0, s"The number of instances should be " +
+      s"greater than 0.0, but got $totalCnt.")
+    new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
+  }
 
-  def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
-
-  // Here we optimize loss function over coefficients, intercept and log(sigma)
-  def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
-    BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
 
   /**
    * Add a new training data to this AFTAggregator, and update the loss and gradient
@@ -466,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
    * @return This AFTAggregator object.
    */
   def add(data: AFTPoint): this.type = {
-
-    val interceptFlag = if (fitIntercept) 1.0 else 0.0
-
-    val xi = data.features.toBreeze
+    val xi = data.features
     val ti = data.label
     val delta = data.censor
-    val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
 
-    lossSum += math.log(sigma) * delta
-    lossSum += (math.exp(epsilon) - delta * epsilon)
+    val margin = {
+      var sum = 0.0
+      xi.foreachActive { (index, value) =>
+        if (featuresStd(index) != 0.0 && value != 0.0) {
+          sum += coefficients(index) * (value / featuresStd(index))
+        }
+      }
+      sum + intercept
+    }
+    val epsilon = (math.log(ti) - margin) / sigma
+
+    lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon)
 
-    // Sanity check (should never occur):
-    assert(!lossSum.isInfinity,
-      s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+    val multiplier = (delta - math.exp(epsilon)) / sigma
 
-    val deltaMinusExpEps = delta - math.exp(epsilon)
-    gradientCoefficientSum += xi * deltaMinusExpEps / sigma
-    gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
-    gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
+    gradientSumArray(0) += delta + multiplier * sigma * epsilon
+    gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
+    xi.foreachActive { (index, value) =>
+      if (featuresStd(index) != 0.0 && value != 0.0) {
+        gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+      }
+    }
 
     totalCnt += 1
     this
@@ -503,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
       totalCnt += other.totalCnt
       lossSum += other.lossSum
 
-      gradientCoefficientSum += other.gradientCoefficientSum
-      gradientInterceptSum += other.gradientInterceptSum
-      gradientLogSigmaSum += other.gradientLogSigmaSum
+      var i = 0
+      val len = this.gradientSumArray.length
+      while (i < len) {
+        this.gradientSumArray(i) += other.gradientSumArray(i)
+        i += 1
+      }
     }
     this
   }
@@ -516,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
  * It returns the loss and gradient at a particular point (parameters).
  * It's used in Breeze's convex optimization routines.
  */
-private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
-  extends DiffFunction[BDV[Double]] {
+private class AFTCostFun(
+    data: RDD[AFTPoint],
+    fitIntercept: Boolean,
+    featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
-    val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+    val aftAggregator = data.treeAggregate(
+      new AFTAggregator(parameters, fitIntercept, featuresStd))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, instance) => aggregator.add(instance)
       },

http://git-wip-us.apache.org/repos/asf/spark/blob/101663f1/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index f4844cc..76891ad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite
 
   @transient var datasetUnivariate: DataFrame = _
   @transient var datasetMultivariate: DataFrame = _
+  @transient var datasetUnivariateScaled: DataFrame = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite
     datasetMultivariate = sqlContext.createDataFrame(
       sc.parallelize(generateAFTInput(
         2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+    datasetUnivariateScaled = sqlContext.createDataFrame(
+      sc.parallelize(generateAFTInput(
+        1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
+          AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
+      })
   }
 
   /**
@@ -356,6 +362,22 @@ class AFTSurvivalRegressionSuite
       }
   }
 
+  test("numerical stability of standardization") {
+    val trainer = new AFTSurvivalRegression()
+    val model1 = trainer.fit(datasetUnivariate)
+    val model2 = trainer.fit(datasetUnivariateScaled)
+
+    /**
+     * During training we standardize the dataset first, so no matter how we multiply
+     * the dataset by a scaling factor, the convergence rate should be the same,
+     * and the original coefficients should equal the new coefficients multiplied by
+     * the scaling factor. It will have no effect on the intercept and scale.
+     */
+    assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01)
+    assert(model1.intercept ~== model2.intercept absTol 0.01)
+    assert(model1.scale ~== model2.scale absTol 0.01)
+  }
+
   test("read/write") {
     def checkModelData(
         model: AFTSurvivalRegressionModel,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org