You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/07 12:07:28 UTC

spark git commit: [SPARK-18291][SPARKR][ML] SparkR glm predict should output original label when family = binomial.

Repository: spark
Updated Branches:
  refs/heads/master a814eeac6 -> daa975f4b


[SPARK-18291][SPARKR][ML] SparkR glm predict should output original label when family = binomial.

## What changes were proposed in this pull request?
SparkR `spark.glm` predict should output the original label, instead of the predicted probability, when family = "binomial".

## How was this patch tested?
Added a unit test for the binomial family.
You can also run the following code to test:
```R
training <- suppressWarnings(createDataFrame(iris))
training <- training[training$Species %in% c("versicolor", "virginica"), ]
model <- spark.glm(training, Species ~ Sepal_Length + Sepal_Width,family = binomial(link = "logit"))
showDF(predict(model, training))
```
Before this change:
```
+------------+-----------+------------+-----------+----------+-----+-------------------+
|Sepal_Length|Sepal_Width|Petal_Length|Petal_Width|   Species|label|         prediction|
+------------+-----------+------------+-----------+----------+-----+-------------------+
|         7.0|        3.2|         4.7|        1.4|versicolor|  0.0| 0.8271421517601544|
|         6.4|        3.2|         4.5|        1.5|versicolor|  0.0| 0.6044595910413112|
|         6.9|        3.1|         4.9|        1.5|versicolor|  0.0| 0.7916340858281998|
|         5.5|        2.3|         4.0|        1.3|versicolor|  0.0|0.16080518180591158|
|         6.5|        2.8|         4.6|        1.5|versicolor|  0.0| 0.6112229217050189|
|         5.7|        2.8|         4.5|        1.3|versicolor|  0.0| 0.2555087295500885|
|         6.3|        3.3|         4.7|        1.6|versicolor|  0.0| 0.5681507664364834|
|         4.9|        2.4|         3.3|        1.0|versicolor|  0.0|0.05990570219972002|
|         6.6|        2.9|         4.6|        1.3|versicolor|  0.0| 0.6644434078306246|
|         5.2|        2.7|         3.9|        1.4|versicolor|  0.0|0.11293577405862379|
|         5.0|        2.0|         3.5|        1.0|versicolor|  0.0|0.06152372321585971|
|         5.9|        3.0|         4.2|        1.5|versicolor|  0.0|0.35250697207602555|
|         6.0|        2.2|         4.0|        1.0|versicolor|  0.0|0.32267018290814303|
|         6.1|        2.9|         4.7|        1.4|versicolor|  0.0|  0.433391153814592|
|         5.6|        2.9|         3.6|        1.3|versicolor|  0.0| 0.2280744262436993|
|         6.7|        3.1|         4.4|        1.4|versicolor|  0.0| 0.7219848389339459|
|         5.6|        3.0|         4.5|        1.5|versicolor|  0.0|0.23527698971404695|
|         5.8|        2.7|         4.1|        1.0|versicolor|  0.0|  0.285024533520016|
|         6.2|        2.2|         4.5|        1.5|versicolor|  0.0| 0.4107047877447493|
|         5.6|        2.5|         3.9|        1.1|versicolor|  0.0|0.20083561961645083|
+------------+-----------+------------+-----------+----------+-----+-------------------+
```
After this change:
```
+------------+-----------+------------+-----------+----------+-----+----------+
|Sepal_Length|Sepal_Width|Petal_Length|Petal_Width|   Species|label|prediction|
+------------+-----------+------------+-----------+----------+-----+----------+
|         7.0|        3.2|         4.7|        1.4|versicolor|  0.0| virginica|
|         6.4|        3.2|         4.5|        1.5|versicolor|  0.0| virginica|
|         6.9|        3.1|         4.9|        1.5|versicolor|  0.0| virginica|
|         5.5|        2.3|         4.0|        1.3|versicolor|  0.0|versicolor|
|         6.5|        2.8|         4.6|        1.5|versicolor|  0.0| virginica|
|         5.7|        2.8|         4.5|        1.3|versicolor|  0.0|versicolor|
|         6.3|        3.3|         4.7|        1.6|versicolor|  0.0| virginica|
|         4.9|        2.4|         3.3|        1.0|versicolor|  0.0|versicolor|
|         6.6|        2.9|         4.6|        1.3|versicolor|  0.0| virginica|
|         5.2|        2.7|         3.9|        1.4|versicolor|  0.0|versicolor|
|         5.0|        2.0|         3.5|        1.0|versicolor|  0.0|versicolor|
|         5.9|        3.0|         4.2|        1.5|versicolor|  0.0|versicolor|
|         6.0|        2.2|         4.0|        1.0|versicolor|  0.0|versicolor|
|         6.1|        2.9|         4.7|        1.4|versicolor|  0.0|versicolor|
|         5.6|        2.9|         3.6|        1.3|versicolor|  0.0|versicolor|
|         6.7|        3.1|         4.4|        1.4|versicolor|  0.0| virginica|
|         5.6|        3.0|         4.5|        1.5|versicolor|  0.0|versicolor|
|         5.8|        2.7|         4.1|        1.0|versicolor|  0.0|versicolor|
|         6.2|        2.2|         4.5|        1.5|versicolor|  0.0|versicolor|
|         5.6|        2.5|         3.9|        1.1|versicolor|  0.0|versicolor|
+------------+-----------+------------+-----------+----------+-----+----------+
```

Author: Yanbo Liang <yb...@gmail.com>

Closes #15788 from yanboliang/spark-18291.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/daa975f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/daa975f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/daa975f4

Branch: refs/heads/master
Commit: daa975f4bfa4f904697bf3365a4be9987032e490
Parents: a814eea
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon Nov 7 04:07:19 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Mon Nov 7 04:07:19 2016 -0800

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R          | 20 +++--
 .../r/GeneralizedLinearRegressionWrapper.scala  | 77 ++++++++++++++++++--
 2 files changed, 84 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/daa975f4/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index e48df03..5f742d9 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -64,6 +64,16 @@ test_that("spark.glm and predict", {
   rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 
+  # binomial family
+  binomialTraining <- training[training$Species %in% c("versicolor", "virginica"), ]
+  model <- spark.glm(binomialTraining, Species ~ Sepal_Length + Sepal_Width,
+    family = binomial(link = "logit"))
+  prediction <- predict(model, binomialTraining)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character")
+  expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica",
+    "versicolor", "virginica", "versicolor", "virginica", "versicolor")
+  expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected)
+
   # poisson family
   model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
   family = poisson(link = identity))
@@ -128,10 +138,10 @@ test_that("spark.glm summary", {
   expect_equal(stats$aic, rStats$aic)
 
   # Test spark.glm works with weighted dataset
-  a1 <- c(0, 1, 2, 3)
-  a2 <- c(5, 2, 1, 3)
-  w <- c(1, 2, 3, 4)
-  b <- c(1, 0, 1, 0)
+  a1 <- c(0, 1, 2, 3, 4)
+  a2 <- c(5, 2, 1, 3, 2)
+  w <- c(1, 2, 3, 4, 5)
+  b <- c(1, 0, 1, 0, 0)
   data <- as.data.frame(cbind(a1, a2, w, b))
   df <- createDataFrame(data)
 
@@ -158,7 +168,7 @@ test_that("spark.glm summary", {
   data <- as.data.frame(cbind(a1, a2, b))
   df <- suppressWarnings(createDataFrame(data))
   regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
-  expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
+  expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result
 })
 
 test_that("spark.glm save/load", {

http://git-wip-us.apache.org/repos/asf/spark/blob/daa975f4/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index b1bb577..995b1ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -23,11 +23,16 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
 import org.apache.spark.ml.regression._
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
 private[r] class GeneralizedLinearRegressionWrapper private (
     val pipeline: PipelineModel,
@@ -42,6 +47,8 @@ private[r] class GeneralizedLinearRegressionWrapper private (
     val rNumIterations: Int,
     val isLoaded: Boolean = false) extends MLWritable {
 
+  import GeneralizedLinearRegressionWrapper._
+
   private val glm: GeneralizedLinearRegressionModel =
     pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
 
@@ -52,7 +59,15 @@ private[r] class GeneralizedLinearRegressionWrapper private (
   def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType)
 
   def transform(dataset: Dataset[_]): DataFrame = {
-    pipeline.transform(dataset).drop(glm.getFeaturesCol)
+    if (rFamily == "binomial") {
+      pipeline.transform(dataset)
+        .drop(PREDICTED_LABEL_PROB_COL)
+        .drop(PREDICTED_LABEL_INDEX_COL)
+        .drop(glm.getFeaturesCol)
+    } else {
+      pipeline.transform(dataset)
+        .drop(glm.getFeaturesCol)
+    }
   }
 
   override def write: MLWriter =
@@ -62,6 +77,10 @@ private[r] class GeneralizedLinearRegressionWrapper private (
 private[r] object GeneralizedLinearRegressionWrapper
   extends MLReadable[GeneralizedLinearRegressionWrapper] {
 
+  val PREDICTED_LABEL_PROB_COL = "pred_label_prob"
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
   def fit(
       formula: String,
       data: DataFrame,
@@ -71,8 +90,8 @@ private[r] object GeneralizedLinearRegressionWrapper
       maxIter: Int,
       weightCol: String,
       regParam: Double): GeneralizedLinearRegressionWrapper = {
-    val rFormula = new RFormula()
-      .setFormula(formula)
+    val rFormula = new RFormula().setFormula(formula)
+    if (family == "binomial") rFormula.setForceIndexLabel(true)
     RWrapperUtils.checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
@@ -90,9 +109,27 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setWeightCol(weightCol)
       .setRegParam(regParam)
       .setFeaturesCol(rFormula.getFeaturesCol)
-    val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, glr))
-      .fit(data)
+    val pipeline = if (family == "binomial") {
+      // Convert prediction from probability to label index.
+      val probToPred = new ProbabilityToPrediction()
+        .setInputCol(PREDICTED_LABEL_PROB_COL)
+        .setOutputCol(PREDICTED_LABEL_INDEX_COL)
+      // Convert prediction from label index to original label.
+      val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+        .asInstanceOf[NominalAttribute]
+      val labels = labelAttr.values.get
+      val idxToStr = new IndexToString()
+        .setInputCol(PREDICTED_LABEL_INDEX_COL)
+        .setOutputCol(PREDICTED_LABEL_COL)
+        .setLabels(labels)
+
+      new Pipeline()
+        .setStages(Array(rFormulaModel, glr.setPredictionCol(PREDICTED_LABEL_PROB_COL),
+          probToPred, idxToStr))
+        .fit(data)
+    } else {
+      new Pipeline().setStages(Array(rFormulaModel, glr)).fit(data)
+    }
 
     val glm: GeneralizedLinearRegressionModel =
       pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
@@ -200,3 +237,27 @@ private[r] object GeneralizedLinearRegressionWrapper
     }
   }
 }
+
+/**
+ * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel
+ * with "binomial" family from probability to prediction according to threshold 0.5.
+ */
+private[r] class ProbabilityToPrediction private[r] (override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
+
+  def this() = this(Identifiable.randomUID("probToPred"))
+
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def transformSchema(schema: StructType): StructType = {
+    StructType(schema.fields :+ StructField($(outputCol), DoubleType))
+  }
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    dataset.withColumn($(outputCol), round(col($(inputCol))))
+  }
+
+  override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org