You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/09/24 00:26:03 UTC

spark git commit: [SPARK-10686] [ML] Add quantilesCol to AFTSurvivalRegression

Repository: spark
Updated Branches:
  refs/heads/master 098be27ad -> ce2b056d3


[SPARK-10686] [ML] Add quantilesCol to AFTSurvivalRegression

By default `quantilesCol` is not set (empty), and no quantiles column is produced. If `quantilesCol` is set, the quantiles corresponding to `quantileProbabilities` are appended as a new column (of type Vector).

Author: Yanbo Liang <yb...@gmail.com>

Closes #8836 from yanboliang/spark-10686.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce2b056d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce2b056d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce2b056d

Branch: refs/heads/master
Commit: ce2b056d35c0c75d5c162b93680ee2d84152e911
Parents: 098be27
Author: Yanbo Liang <yb...@gmail.com>
Authored: Wed Sep 23 15:26:02 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Sep 23 15:26:02 2015 -0700

----------------------------------------------------------------------
 .../ml/regression/AFTSurvivalRegression.scala   | 51 +++++++++++---
 .../regression/AFTSurvivalRegressionSuite.scala | 74 +++++++++++++-------
 2 files changed, 91 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ce2b056d/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 5b25db6..717caac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
  */
 private[regression] trait AFTSurvivalRegressionParams extends Params
   with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
-  with HasTol with HasFitIntercept {
+  with HasTol with HasFitIntercept with Logging {
 
   /**
    * Param for censor column name.
@@ -59,21 +59,35 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
 
   /**
    * Param for quantile probabilities array.
-   * Values of the quantile probabilities array should be in the range [0, 1].
+   * Values of the quantile probabilities array should be in the range [0, 1]
+   * and the array should be non-empty.
    * @group param
    */
   @Since("1.6.0")
   final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this,
     "quantileProbabilities", "quantile probabilities array",
-    (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)))
+    (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)) && t.length > 0)
 
   /** @group getParam */
   @Since("1.6.0")
   def getQuantileProbabilities: Array[Double] = $(quantileProbabilities)
+  setDefault(quantileProbabilities -> Array(0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99))
 
-  /** Checks whether the input has quantile probabilities array. */
-  protected[regression] def hasQuantileProbabilities: Boolean = {
-    isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0
+  /**
+   * Param for quantiles column name.
+   * This column will output quantiles of corresponding quantileProbabilities if it is set.
+   * @group param
+   */
+  @Since("1.6.0")
+  final val quantilesCol: Param[String] = new Param(this, "quantilesCol", "quantiles column name")
+
+  /** @group getParam */
+  @Since("1.6.0")
+  def getQuantilesCol: String = $(quantilesCol)
+
+  /** Checks whether the input has quantiles column name. */
+  protected[regression] def hasQuantilesCol: Boolean = {
+    isDefined(quantilesCol) && $(quantilesCol) != ""
   }
 
   /**
@@ -90,6 +104,9 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
       SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
     }
+    if (hasQuantilesCol) {
+      SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
+    }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
 }
@@ -124,6 +141,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   @Since("1.6.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group setParam */
+  @Since("1.6.0")
+  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
+
+  /** @group setParam */
+  @Since("1.6.0")
+  def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
+
   /**
    * Set if we should fit the intercept
    * Default is true.
@@ -243,10 +268,12 @@ class AFTSurvivalRegressionModel private[ml] (
   @Since("1.6.0")
   def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
 
+  /** @group setParam */
+  @Since("1.6.0")
+  def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
+
   @Since("1.6.0")
   def predictQuantiles(features: Vector): Vector = {
-    require(hasQuantileProbabilities,
-      "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array")
     // scale parameter for the Weibull distribution of lifetime
     val lambda = math.exp(BLAS.dot(coefficients, features) + intercept)
     // shape parameter for the Weibull distribution of lifetime
@@ -266,7 +293,13 @@ class AFTSurvivalRegressionModel private[ml] (
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
     val predictUDF = udf { features: Vector => predict(features) }
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
+    if (hasQuantilesCol) {
+      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+        .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol))))
+    } else {
+      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    }
   }
 
   @Since("1.6.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/ce2b056d/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index ca7140a..359f310 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -22,8 +22,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -59,16 +58,20 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
     assert(aftr.getFitIntercept)
     assert(aftr.getMaxIter === 100)
     assert(aftr.getTol === 1E-6)
-    val model = aftr.fit(datasetUnivariate)
+    val model = aftr.setQuantileProbabilities(Array(0.1, 0.8))
+      .setQuantilesCol("quantiles")
+      .fit(datasetUnivariate)
 
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
 
     model.transform(datasetUnivariate)
-      .select("label", "prediction")
+      .select("label", "prediction", "quantiles")
       .collect()
     assert(model.getFeaturesCol === "features")
     assert(model.getPredictionCol === "prediction")
+    assert(model.getQuantileProbabilities === Array(0.1, 0.8))
+    assert(model.getQuantilesCol === "quantiles")
     assert(model.intercept !== 0.0)
     assert(model.hasParent)
   }
@@ -108,7 +111,10 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
   }
 
   test("aft survival regression with univariate") {
-    val trainer = new AFTSurvivalRegression
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val trainer = new AFTSurvivalRegression()
+      .setQuantileProbabilities(quantileProbabilities)
+      .setQuantilesCol("quantiles")
     val model = trainer.fit(datasetUnivariate)
 
     /*
@@ -159,23 +165,25 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
        [1]  0.1879174  2.6801195 14.5779394
      */
     val features = Vectors.dense(6.559282795753792)
-    val quantileProbabilities = Array(0.1, 0.5, 0.9)
     val responsePredictR = 4.494763
     val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394)
 
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
-    model.setQuantileProbabilities(quantileProbabilities)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetUnivariate).select("features", "prediction").collect().foreach {
-      case Row(features: DenseVector, prediction1: Double) =>
-        val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
-        assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
+      .collect().foreach {
+        case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+          assert(prediction ~== model.predict(features) relTol 1E-5)
+          assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
     }
   }
 
   test("aft survival regression with multivariate") {
-    val trainer = new AFTSurvivalRegression
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val trainer = new AFTSurvivalRegression()
+      .setQuantileProbabilities(quantileProbabilities)
+      .setQuantilesCol("quantiles")
     val model = trainer.fit(datasetMultivariate)
 
     /*
@@ -227,23 +235,26 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
        [1]  0.5287044  3.3285858 10.7517072
      */
     val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
-    val quantileProbabilities = Array(0.1, 0.5, 0.9)
     val responsePredictR = 4.761219
     val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072)
 
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
-    model.setQuantileProbabilities(quantileProbabilities)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
-      case Row(features: DenseVector, prediction1: Double) =>
-        val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
-        assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
+      .collect().foreach {
+        case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+          assert(prediction ~== model.predict(features) relTol 1E-5)
+          assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
     }
   }
 
   test("aft survival regression w/o intercept") {
-    val trainer = new AFTSurvivalRegression().setFitIntercept(false)
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val trainer = new AFTSurvivalRegression()
+      .setQuantileProbabilities(quantileProbabilities)
+      .setQuantilesCol("quantiles")
+      .setFitIntercept(false)
     val model = trainer.fit(datasetMultivariate)
 
     /*
@@ -294,18 +305,31 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
        [1]   1.452103  25.506077 158.428600
      */
     val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
-    val quantileProbabilities = Array(0.1, 0.5, 0.9)
     val responsePredictR = 44.54465
     val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600)
 
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
-    model.setQuantileProbabilities(quantileProbabilities)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
-      case Row(features: DenseVector, prediction1: Double) =>
-        val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
-        assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
+      .collect().foreach {
+        case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+          assert(prediction ~== model.predict(features) relTol 1E-5)
+          assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
+    }
+  }
+
+  test("aft survival regression w/o quantiles column") {
+    val trainer = new AFTSurvivalRegression
+    val model = trainer.fit(datasetUnivariate)
+    val outputDf = model.transform(datasetUnivariate)
+
+    assert(outputDf.schema.fieldNames.contains("quantiles") === false)
+
+    outputDf.select("features", "prediction")
+      .collect().foreach {
+        case Row(features: Vector, prediction: Double) =>
+          assert(prediction ~== model.predict(features) relTol 1E-5)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org