Posted to commits@spark.apache.org by me...@apache.org on 2015/07/31 01:15:46 UTC

spark git commit: [SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula

Repository: spark
Updated Branches:
  refs/heads/master be7be6d4c -> e7905a939


[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula

Preview:

```
> summary(m)
            features coefficients
1        (Intercept)    1.6765001
2       Sepal_Length    0.3498801
3 Species.versicolor   -0.9833885
4  Species.virginica   -1.0075104

```

Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit
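
For context, a minimal SparkR snippet (illustrative only, not part of this patch) that would produce a summary of the shape shown in the preview; it mirrors the new test in test_mllib.R and assumes a SparkR shell where `sqlContext` is already available:

```
# Fit a linear model over iris (SparkR replaces "." with "_" in column names)
training <- createDataFrame(sqlContext, iris)
m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)

# summary(m) now returns a list whose 'coefficients' matrix is indexed by the
# feature names exposed through RFormula, including one-hot encoded levels
summary(m)
```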

cc mengxr

Author: Eric Liang <ek...@databricks.com>

Closes #7771 from ericl/summary and squashes the following commits:

ccd54c3 [Eric Liang] second pass
a5ca93b [Eric Liang] comments
2772111 [Eric Liang] clean up
70483ef [Eric Liang] fix test
7c247d4 [Eric Liang] Merge branch 'master' into summary
3c55024 [Eric Liang] working
8c539aa [Eric Liang] first pass


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e7905a93
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e7905a93
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e7905a93

Branch: refs/heads/master
Commit: e7905a9395c1a002f50bab29e16a729e14d4ed6f
Parents: be7be6d
Author: Eric Liang <ek...@databricks.com>
Authored: Thu Jul 30 16:15:43 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Jul 30 16:15:43 2015 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |  3 ++-
 R/pkg/R/mllib.R                                 | 26 +++++++++++++++++++
 R/pkg/inst/tests/test_mllib.R                   | 11 ++++++++
 .../apache/spark/ml/feature/OneHotEncoder.scala | 12 ++++-----
 .../org/apache/spark/ml/feature/RFormula.scala  | 12 ++++++++-
 .../org/apache/spark/ml/r/SparkRWrappers.scala  | 27 ++++++++++++++++++--
 .../spark/ml/regression/LinearRegression.scala  |  8 ++++--
 .../spark/ml/feature/OneHotEncoderSuite.scala   |  8 +++---
 .../apache/spark/ml/feature/RFormulaSuite.scala | 18 +++++++++++++
 9 files changed, 108 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f7a8a2..a329e14 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,7 +12,8 @@ export("print.jobj")
 
 # MLlib integration
 exportMethods("glm",
-              "predict")
+              "predict",
+              "summary")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 6a8baca..efddcc1 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
           })
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param model A fitted MLlib model
+#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
+#'         summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "PipelineModel"),
+          function(object) {
+            features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelFeatures", object@model)
+            weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelWeights", object@model)
+            coefficients <- as.matrix(unlist(weights))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
+            return(list(coefficients = coefficients))
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/R/pkg/inst/tests/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3bef693..f272de7 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
   rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })
+
+test_that("summary coefficients match with native glm", {
+  training <- createDataFrame(sqlContext, iris)
+  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+  coefs <- as.vector(stats$coefficients)
+  rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+  expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  expect_true(all(
+    as.character(stats$features) ==
+    c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 3825942..9c60d40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    val is = "_is_"
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 
@@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
     val outputAttrNames: Option[Array[String]] = inputAttr match {
       case nominal: NominalAttribute =>
         if (nominal.values.isDefined) {
-          nominal.values.map(_.map(v => inputColName + is + v))
+          nominal.values
         } else if (nominal.numValues.isDefined) {
-          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
         } else {
           None
         }
       case binary: BinaryAttribute =>
         if (binary.values.isDefined) {
-          binary.values.map(_.map(v => inputColName + is + v))
+          binary.values
         } else {
-          Some(Array.tabulate(2)(i => inputColName + is + i))
+          Some(Array.tabulate(2)(_.toString))
         }
       case _: NumericAttribute =>
         throw new RuntimeException(
@@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
 
   override def transform(dataset: DataFrame): DataFrame = {
     // schema transformation
-    val is = "_is_"
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
     val shouldDropLast = $(dropLast)
@@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
             math.max(m0, m1)
           }
         ).toInt + 1
-      val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+      val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
       val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
       val outputAttrs: Array[Attribute] =
         filtered.map(name => BinaryAttribute.defaultAttr.withName(name))

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 0b428d2..d172691 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.parsing.combinator.RegexParsers
 
@@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
     // TODO(ekl) add support for feature interactions
     val encoderStages = ArrayBuffer[PipelineStage]()
     val tempColumns = ArrayBuffer[String]()
+    val takenNames = mutable.Set(dataset.columns: _*)
     val encodedTerms = resolvedFormula.terms.map { term =>
       dataset.schema(term) match {
         case column if column.dataType == StringType =>
           val indexCol = term + "_idx_" + uid
-          val encodedCol = term + "_onehot_" + uid
+          val encodedCol = {
+            var tmp = term
+            while (takenNames.contains(tmp)) {
+              tmp += "_"
+            }
+            tmp
+          }
+          takenNames.add(indexCol)
+          takenNames.add(encodedCol)
           encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
           encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
           tempColumns += indexCol

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 9f70592..f5a022c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.ml.api.r
 
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.DataFrame
 
@@ -44,4 +45,26 @@ private[r] object SparkRWrappers {
     val pipeline = new Pipeline().setStages(Array(formula, estimator))
     pipeline.fit(df)
   }
+
+  def getModelWeights(model: PipelineModel): Array[Double] = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        Array(m.intercept) ++ m.weights.toArray
+      case _: LogisticRegressionModel =>
+        throw new UnsupportedOperationException(
+          "No weights available for LogisticRegressionModel")  // SPARK-9492
+    }
+  }
+
+  def getModelFeatures(model: PipelineModel): Array[String] = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+      case _: LogisticRegressionModel =>
+        throw new UnsupportedOperationException(
+          "No features names available for LogisticRegressionModel")  // SPARK-9492
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 89718e0..3b85ba0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructField
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
 
@@ -146,9 +147,10 @@ class LinearRegression(override val uid: String)
 
       val model = new LinearRegressionModel(uid, weights, intercept)
       val trainingSummary = new LinearRegressionTrainingSummary(
-        model.transform(dataset).select($(predictionCol), $(labelCol)),
+        model.transform(dataset),
         $(predictionCol),
         $(labelCol),
+        $(featuresCol),
         Array(0D))
       return copyValues(model.setSummary(trainingSummary))
     }
@@ -221,9 +223,10 @@ class LinearRegression(override val uid: String)
 
     val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
     val trainingSummary = new LinearRegressionTrainingSummary(
-      model.transform(dataset).select($(predictionCol), $(labelCol)),
+      model.transform(dataset),
       $(predictionCol),
       $(labelCol),
+      $(featuresCol),
       objectiveHistory)
     model.setSummary(trainingSummary)
   }
@@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] (
     predictions: DataFrame,
     predictionCol: String,
     labelCol: String,
+    val featuresCol: String,
     val objectiveHistory: Array[Double])
   extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 65846a8..321eeb8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     val output = encoder.transform(df)
     val group = AttributeGroup.fromStructField(output.schema("encoded"))
     assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
   }
 
   test("input column without ML attribute") {
@@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     val output = encoder.transform(df)
     val group = AttributeGroup.fromStructField(output.schema("encoded"))
     assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7905a93/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 8148c55..6aed324 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(result.schema.toString == resultSchema.toString)
     assert(result.collect() === expected.collect())
   }
+
+  test("attribute generation") {
+    val formula = new RFormula().setFormula("id ~ a + b")
+    val original = sqlContext.createDataFrame(
+      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array(
+        new BinaryAttribute(Some("a__bar"), Some(1)),
+        new BinaryAttribute(Some("a__foo"), Some(2)),
+        new NumericAttribute(Some("b"), Some(3))))
+    assert(attrs === expectedAttrs)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org