You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2017/03/02 11:34:00 UTC

spark git commit: [SPARK-19704][ML] AFTSurvivalRegression should support numeric censorCol

Repository: spark
Updated Branches:
  refs/heads/master 625cfe09e -> 50c08e82f


[SPARK-19704][ML] AFTSurvivalRegression should support numeric censorCol

## What changes were proposed in this pull request?
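Make `AFTSurvivalRegression` accept a censor column (`censorCol`) of any numeric type, not only `DoubleType`; the censor column is cast to `DoubleType` when extracting training points.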
## How was this patch tested?
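Existing tests, plus a new test checking that fitting with censor columns of every numeric type gives the same model, and that a string censor column is rejected with an `IllegalArgumentException`.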

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #17034 from zhengruifeng/aft_numeric_censor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/50c08e82
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/50c08e82
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/50c08e82

Branch: refs/heads/master
Commit: 50c08e82f011dd31b4ff7ff2b45fb9fb4c0e3231
Parents: 625cfe0
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Thu Mar 2 13:34:04 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Thu Mar 2 13:34:04 2017 +0200

----------------------------------------------------------------------
 .../ml/regression/AFTSurvivalRegression.scala   |  6 ++--
 .../ml/regression/IsotonicRegression.scala      |  2 +-
 .../regression/AFTSurvivalRegressionSuite.scala | 34 +++++++++++++++++++-
 3 files changed, 37 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/50c08e82/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2f78dd3..4b36083 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
       fitting: Boolean): StructType = {
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
-      SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(censorCol))
       SchemaUtils.checkNumericType(schema, $(labelCol))
     }
     if (hasQuantilesCol) {
@@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
    * and put it in an RDD with strong types.
    */
   protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
-    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
-      .rdd.map {
+    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType),
+      col($(censorCol)).cast(DoubleType)).rdd.map {
         case Row(features: Vector, label: Double, censor: Double) =>
           AFTPoint(features, label, censor)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/50c08e82/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index a6c2943..529f66e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
    */
   final val isotonic: BooleanParam =
     new BooleanParam(this, "isotonic",
-      "whether the output sequence should be isotonic/increasing (true) or" +
+      "whether the output sequence should be isotonic/increasing (true) or " +
         "antitonic/decreasing (false)")
 
   /** @group getParam */

http://git-wip-us.apache.org/repos/asf/spark/blob/50c08e82/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 0fdfdf3..3cd4b0a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types._
 
 class AFTSurvivalRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite
     }
   }
 
-  test("should support all NumericType labels") {
+  test("should support all NumericType labels, and not support other types") {
     val aft = new AFTSurvivalRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
       aft, spark, isClassification = false) { (expected, actual) =>
@@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite
       }
   }
 
+  test("should support all NumericType censors, and not support other types") {
+    val df = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0)),
+      (1, Vectors.dense(1)),
+      (2, Vectors.dense(2)),
+      (3, Vectors.dense(3)),
+      (4, Vectors.dense(4))
+    )).toDF("label", "features")
+      .withColumn("censor", lit(0.0))
+    val aft = new AFTSurvivalRegression().setMaxIter(1)
+    val expected = aft.fit(df)
+
+    val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0))
+    types.foreach { t =>
+      val actual = aft.fit(df.select(col("label"), col("features"),
+        col("censor").cast(t)))
+      assert(expected.intercept === actual.intercept)
+      assert(expected.coefficients === actual.coefficients)
+    }
+
+    val dfWithStringCensors = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0, 2, 3), "0")
+    )).toDF("label", "features", "censor")
+    val thrown = intercept[IllegalArgumentException] {
+      aft.fit(dfWithStringCensors)
+    }
+    assert(thrown.getMessage.contains(
+      "Column censor must be of type NumericType but was actually of type StringType"))
+  }
+
   test("numerical stability of standardization") {
     val trainer = new AFTSurvivalRegression()
     val model1 = trainer.fit(datasetUnivariate)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org