You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/14 04:25:22 UTC

spark git commit: [SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data

Repository: spark
Updated Branches:
  refs/heads/master b91a51bb2 -> 07be232ea


[SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data

## What changes were proposed in this pull request?
* Fix the following exception, which is thrown when `spark.randomForest` (classification), `spark.gbt` (classification), `spark.naiveBayes` and `spark.glm` (binomial family) are fitted on libsvm data:
```
java.lang.IllegalArgumentException: requirement failed: If label column already exists, forceIndexLabel can not be set with true.
```
See [SPARK-18412](https://issues.apache.org/jira/browse/SPARK-18412) for more details on how to reproduce this bug.
* Refactor `getFeaturesAndLabels` out into `RWrapperUtils`, since many ML algorithm wrappers use this function.
* Drop unwanted intermediate columns (including the model's label column) from the output when making predictions.

## How was this patch tested?
Add unit tests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #15851 from yanboliang/spark-18412.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/07be232e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/07be232e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/07be232e

Branch: refs/heads/master
Commit: 07be232ea12dfc8dc3701ca948814be7dbebf4ee
Parents: b91a51b
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sun Nov 13 20:25:12 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sun Nov 13 20:25:12 2016 -0800

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R          | 18 ++++++++--
 .../spark/ml/r/GBTClassificationWrapper.scala   | 18 ++++------
 .../r/GeneralizedLinearRegressionWrapper.scala  |  5 ++-
 .../apache/spark/ml/r/NaiveBayesWrapper.scala   | 14 +++-----
 .../org/apache/spark/ml/r/RWrapperUtils.scala   | 36 +++++++++++++++++---
 .../r/RandomForestClassificationWrapper.scala   | 18 ++++------
 6 files changed, 68 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/07be232e/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b76f75d..07df4b6 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -881,7 +881,8 @@ test_that("spark.kstest", {
   expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
 })
 
-test_that("spark.randomForest Regression", {
+test_that("spark.randomForest", {
+  # regression
   data <- suppressWarnings(createDataFrame(longley))
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                               numTrees = 1)
@@ -923,9 +924,8 @@ test_that("spark.randomForest Regression", {
   expect_equal(stats$treeWeights, stats2$treeWeights)
 
   unlink(modelPath)
-})
 
-test_that("spark.randomForest Classification", {
+  # classification
   data <- suppressWarnings(createDataFrame(iris))
   model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
                               maxDepth = 5, maxBins = 16)
@@ -971,6 +971,12 @@ test_that("spark.randomForest Classification", {
   predictions <- collect(predict(model, data))$prediction
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
+
+  # spark.randomForest classification can work on libsvm data
+  data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+                source = "libsvm")
+  model <- spark.randomForest(data, label ~ features, "classification")
+  expect_equal(summary(model)$numFeatures, 4)
 })
 
 test_that("spark.gbt", {
@@ -1039,6 +1045,12 @@ test_that("spark.gbt", {
   expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
   expect_equal(s$numFeatures, 5)
   expect_equal(s$numTrees, 20)
+
+  # spark.gbt classification can work on libsvm data
+  data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
+                source = "libsvm")
+  model <- spark.gbt(data, label ~ features, "classification")
+  expect_equal(summary(model)$numFeatures, 692)
 })
 
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/07be232e/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index 8946025..aacb41e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -23,10 +23,10 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
@@ -51,6 +51,7 @@ private[r] class GBTClassifierWrapper private (
     pipeline.transform(dataset)
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(gbtcModel.getFeaturesCol)
+      .drop(gbtcModel.getLabelCol)
   }
 
   override def write: MLWriter = new
@@ -81,19 +82,11 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
-    // get feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
-
-    // get label names from output schema
-    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
+    // get labels and feature names from output schema
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
 
     // assemble and fit the pipeline
     val rfc = new GBTClassifier()
@@ -109,6 +102,7 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       .setMaxMemoryInMB(maxMemoryInMB)
       .setCacheNodeIds(cacheNodeIds)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/07be232e/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 995b1ef..add4d49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -29,6 +29,7 @@ import org.apache.spark.ml.regression._
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
@@ -64,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private (
         .drop(PREDICTED_LABEL_PROB_COL)
         .drop(PREDICTED_LABEL_INDEX_COL)
         .drop(glm.getFeaturesCol)
+        .drop(glm.getLabelCol)
     } else {
       pipeline.transform(dataset)
         .drop(glm.getFeaturesCol)
@@ -92,7 +94,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       regParam: Double): GeneralizedLinearRegressionWrapper = {
     val rFormula = new RFormula().setFormula(formula)
     if (family == "binomial") rFormula.setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
     val schema = rFormulaModel.transform(data).schema
@@ -109,6 +111,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setWeightCol(weightCol)
       .setRegParam(regParam)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
     val pipeline = if (family == "binomial") {
       // Convert prediction from probability to label index.
       val probToPred = new ProbabilityToPrediction()

http://git-wip-us.apache.org/repos/asf/spark/blob/07be232e/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 4fdab2d..0afea4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
@@ -46,6 +46,7 @@ private[r] class NaiveBayesWrapper private (
     pipeline.transform(dataset)
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(naiveBayesModel.getFeaturesCol)
+      .drop(naiveBayesModel.getLabelCol)
   }
 
   override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this)
@@ -60,21 +61,16 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
     // assemble and fit the pipeline
     val naiveBayes = new NaiveBayes()
       .setSmoothing(smoothing)
       .setModelType("bernoulli")
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     val idxToStr = new IndexToString()
       .setInputCol(PREDICTED_LABEL_INDEX_COL)

http://git-wip-us.apache.org/repos/asf/spark/blob/07be232e/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
index 379007c..665e50a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.ml.r
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.Dataset
 
-object RWrapperUtils extends Logging {
+private[r] object RWrapperUtils extends Logging {
 
   /**
    * DataFrame column check.
@@ -32,14 +33,41 @@ object RWrapperUtils extends Logging {
    *
    * @param rFormula RFormula instance
    * @param data Input dataset
-   * @return Unit
    */
   def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
     if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
       val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}"
-      logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
+      logInfo(s"data containing ${rFormula.getFeaturesCol} column, " +
         s"using new name $newFeaturesName instead")
       rFormula.setFeaturesCol(newFeaturesName)
     }
+
+    if (rFormula.getForceIndexLabel && data.schema.fieldNames.contains(rFormula.getLabelCol)) {
+      val newLabelName = s"${Identifiable.randomUID(rFormula.getLabelCol)}"
+      logInfo(s"data containing ${rFormula.getLabelCol} column and we force to index label, " +
+        s"using new name $newLabelName instead")
+      rFormula.setLabelCol(newLabelName)
+    }
+  }
+
+  /**
+   * Get the feature names and original labels from the schema
+   * of DataFrame transformed by RFormulaModel.
+   *
+   * @param rFormulaModel The RFormulaModel instance.
+   * @param data Input dataset.
+   * @return The feature names and original labels.
+   */
+  def getFeaturesAndLabels(
+      rFormulaModel: RFormulaModel,
+      data: Dataset[_]): (Array[String], Array[String]) = {
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+      .asInstanceOf[NominalAttribute]
+    val labels = labelAttr.values.get
+    (features, labels)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/07be232e/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 31f846d..0b860e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -23,10 +23,10 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
@@ -51,6 +51,7 @@ private[r] class RandomForestClassifierWrapper private (
     pipeline.transform(dataset)
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(rfcModel.getFeaturesCol)
+      .drop(rfcModel.getLabelCol)
   }
 
   override def write: MLWriter = new
@@ -82,19 +83,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
-    // get feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
-
-    // get label names from output schema
-    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
+    // get labels and feature names from output schema
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
 
     // assemble and fit the pipeline
     val rfc = new RandomForestClassifier()
@@ -111,6 +104,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       .setCacheNodeIds(cacheNodeIds)
       .setProbabilityCol(probabilityCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org