You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/04/29 01:20:04 UTC

spark git commit: [SPARK-14862][ML] Updated Classifiers to not require labelCol metadata

Repository: spark
Updated Branches:
  refs/heads/master dae538a4d -> 4f4721a21


[SPARK-14862][ML] Updated Classifiers to not require labelCol metadata

## What changes were proposed in this pull request?

Updated Classifier, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier to not require input column metadata.
* They first check for metadata.
* If numClasses is not specified in metadata, they infer it from the largest label value (the inferred numClasses is capped at a limit, 100 by default).

This functionality is implemented in a new Classifier.getNumClasses method.

Also
* Updated Classifier.extractLabeledPoints to (a) check label values and (b) include a second version which takes a numClasses value for validity checking.

## How was this patch tested?

* Unit tests in ClassifierSuite for helper methods
* Unit tests for DecisionTreeClassifier, RandomForestClassifier, GBTClassifier with toy datasets lacking label metadata

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #12663 from jkbradley/trees-no-metadata.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f4721a2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f4721a2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f4721a2

Branch: refs/heads/master
Commit: 4f4721a21cc9acc2b6f685bbfc8757d29563a775
Parents: dae538a
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Thu Apr 28 16:20:00 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Apr 28 16:20:00 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/Classifier.scala    |  70 +++++++++++-
 .../classification/DecisionTreeClassifier.scala |  10 +-
 .../spark/ml/classification/GBTClassifier.scala |  25 +++--
 .../classification/RandomForestClassifier.scala |  10 +-
 .../ml/classification/ClassifierSuite.scala     | 108 +++++++++++++++++++
 .../DecisionTreeClassifierSuite.scala           |   6 ++
 .../ml/classification/GBTClassifierSuite.scala  |  40 ++++++-
 .../RandomForestClassifierSuite.scala           |   7 ++
 8 files changed, 245 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 473e801..bc5fe35 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
  * (private[spark]) Params for classification.
@@ -62,6 +65,67 @@ abstract class Classifier[
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
 
   // TODO: defaultEvaluator (follow-up PR)
+
+  /**
+   * Build a strongly typed RDD of [[LabeledPoint]] from the [[labelCol]] and [[featuresCol]]
+   * columns of the given dataset, checking every label on the way.
+   *
+   * @param dataset  DataFrame with a label column ([[org.apache.spark.sql.types.NumericType]],
+   *                 cast to [[DoubleType]] here) and a features column ([[Vector]]).
+   * @param numClasses  Number of classes; must be > 0 (else IllegalArgumentException).  Labels
+   *                    must be integers in the range [0, numClasses).
+   * @throws SparkException  if any label is not an integer in the range [0, numClasses)
+   */
+  protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
+    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+      s" $numClasses, but requires numClasses > 0.")
+    val labelAndFeatures = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol)))
+    labelAndFeatures.rdd.map { case Row(label: Double, features: Vector) =>
+      require(0 <= label && label < numClasses && label % 1 == 0, s"Classifier was given" +
+        s" dataset with invalid label $label.  Labels must be integers in range" +
+        s" [0, 1, ..., $numClasses), where numClasses=$numClasses.")
+      LabeledPoint(label, features)
+    }
+  }
+
+  /**
+   * Get the number of classes.  This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value (labels are cast to [[DoubleType]] for this).
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+   * such as in [[extractLabeledPoints()]].
+   *
+   * @param dataset  Dataset which contains a numeric column [[labelCol]]
+   * @param maxNumClasses  Maximum number of classes allowed when inferred from data.  If numClasses
+   *                       is specified in the metadata, then maxNumClasses is ignored.
+   * @return  number of classes
+   * @throws SparkException  if numClasses must be inferred but the dataset has no label values
+   * @throws IllegalArgumentException  if the inferred numClasses is invalid or > maxNumClasses
+   */
+  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) => n
+      case None =>
+        // Infer from data.  max() over no rows yields a single null row, so check for null too.
+        val maxLabelRow = dataset.select(max(col($(labelCol)).cast(DoubleType))).take(1)
+        if (maxLabelRow.isEmpty || maxLabelRow.head.isNullAt(0)) {
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+        val numClasses = maxDoubleLabel.toInt + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+          s" in column ${$(labelCol)}, but this exceeded the max numClasses ($maxNumClasses)" +
+          s" allowed to be inferred from values.  To avoid this error for labels with >" +
+          s" $maxNumClasses classes, specify numClasses explicitly in the metadata; this can be" +
+          s" done by applying StringIndexer to the label column.")
+        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+          s" labelCol=${$(labelCol)} since numClasses was not specified in the column metadata.")
+        numClasses
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ecb218e..2b2e13d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -85,14 +85,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-        // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), parentUID = Some(uid))

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index e736f01..acc0458 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -35,8 +35,9 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -126,16 +127,16 @@ class GBTClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("GBTClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    require(numClasses == 2,
-      s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
+    // 2 classes now.  This lets us provide a more precise error message.
+    val oldDataset: RDD[LabeledPoint] =
+      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+        case Row(label: Double, features: Vector) =>
+          require(label == 0 || label == 1, s"GBTClassifier was given" +
+            s" dataset with invalid label $label.  Labels must be in {0,1}; note that" +
+            s" GBTClassifier currently only supports binary classification.")
+          LabeledPoint(label, features)
+      }
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
@@ -165,6 +166,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
  * model for classification.
  * It supports binary labels, as well as both continuous and categorical features.
  * Note: Multiclass labels are not currently supported.
+ *
  * @param _trees  Decision trees in the ensemble.
  * @param _treeWeights  Weights for the decision trees in the ensemble.
  */
@@ -185,6 +187,7 @@ class GBTClassificationModel private[ml](
 
   /**
    * Construct a GBTClassificationModel
+   *
    * @param _trees  Decision trees in the ensemble.
    * @param _treeWeights  Weights for the decision trees in the ensemble.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 28364c2..fb3418d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -101,14 +101,8 @@ class RandomForestClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
     val trees =

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index d0e3fe7..89afb94 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -17,6 +17,86 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/** Tests for the helper methods of [[Classifier]], called through the public MockClassifier. */
+class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  /** Creates a DataFrame with one row per label and a constant 1-dimensional feature vector. */
+  private def getTestData(labels: Seq[Double]): DataFrame = {
+    val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+    sqlContext.createDataFrame(data)
+  }
+
+  test("extractLabeledPoints") {
+    val c = new MockClassifier
+    // Valid dataset
+    // count() forces evaluation: labels are only checked inside the RDD map.
+    val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+    c.extractLabeledPoints(df0, 6).count()
+    // Invalid datasets
+    val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
+    withClue("Classifier should fail if label is negative") {
+      val e: SparkException = intercept[SparkException] {
+        c.extractLabeledPoints(df1, 6).count()
+      }
+      assert(e.getMessage.contains("given dataset with invalid label"))
+    }
+    val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
+    withClue("Classifier should fail if label is not an integer") {
+      val e: SparkException = intercept[SparkException] {
+        c.extractLabeledPoints(df2, 6).count()
+      }
+      assert(e.getMessage.contains("given dataset with invalid label"))
+    }
+    // Otherwise-valid labels, but numClasses is too small for the label 5.0
+    withClue("Classifier should fail if label is >= numClasses") {
+      val e: SparkException = intercept[SparkException] {
+        c.extractLabeledPoints(df0, numClasses = 5).count()
+      }
+      assert(e.getMessage.contains("given dataset with invalid label"))
+    }
+    // numClasses is validated up front (not in the job), hence IllegalArgumentException.
+    withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
+      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+        c.extractLabeledPoints(df0, numClasses = 0).count()
+      }
+      assert(e.getMessage.contains("but requires numClasses > 0"))
+    }
+  }
+
+  test("getNumClasses") {
+    val c = new MockClassifier
+    // Valid dataset: numClasses is inferred as the max label (5.0) + 1
+    val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+    assert(c.getNumClasses(df0) === 6)
+    // Invalid datasets
+    val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1))
+    withClue("getNumClasses should fail if max label is not an integer") {
+      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+        c.getNumClasses(df1)
+      }
+      assert(e.getMessage.contains("requires integers in range"))
+    }
+    // Int.MaxValue + 1 is not a valid Int, so numClasses cannot be represented.
+    val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble))
+    withClue("getNumClasses should fail if max label is >= Int.MaxValue") {
+      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+        c.getNumClasses(df2)
+      }
+      assert(e.getMessage.contains("requires integers in range"))
+    }
+  }
+}
+
 object ClassifierSuite {
 
   /**
@@ -29,4 +109,32 @@ object ClassifierSuite {
     "rawPredictionCol" -> "myRawPrediction"
   )
 
+  /** Minimal [[Classifier]] that re-exposes the protected helpers under test as public. */
+  class MockClassifier(override val uid: String)
+    extends Classifier[Vector, MockClassifier, MockClassificationModel] {
+    def this() = this(Identifiable.randomUID("mockclassifier"))
+
+    // Widen visibility of the protected helpers so the suite can call them directly.
+    override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
+      super.extractLabeledPoints(dataset, numClasses)
+    def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+
+    // Stubs: both members below throw NotImplementedError when invoked.
+    override def train(dataset: Dataset[_]): MockClassificationModel =
+      throw new NotImplementedError()
+    override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError()
+  }
+
+  /** Placeholder model type for MockClassifier; every method defined here throws if called. */
+  class MockClassificationModel(override val uid: String)
+    extends ClassificationModel[Vector, MockClassificationModel] {
+
+    def this() = this(Identifiable.randomUID("mockclassificationmodel"))
+
+    // Stubs only: each member below throws NotImplementedError when invoked.
+    override def numClasses: Int = throw new NotImplementedError()
+    override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError()
+    protected def predictRaw(features: Vector): Vector = throw new NotImplementedError()
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index fe839e1..29845b5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -342,6 +342,12 @@ class DecisionTreeClassifierSuite
       }
   }
 
+  test("Fitting without numClasses in metadata") {
+    // Smoke test: fit() must not throw when the label column carries no numClasses metadata.
+    val trainingData: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    new DecisionTreeClassifier().setMaxDepth(1).fit(trainingData)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 7e6aec6..087e201 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -128,6 +129,43 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
   }
   */
 
+  test("Fitting without numClasses in metadata") {
+    // Smoke test: fit() must not throw when the label column carries no numClasses metadata.
+    val trainingData: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    new GBTClassifier().setMaxDepth(1).setMaxIter(1).fit(trainingData)
+  }
+
+  test("extractLabeledPoints with bad data") {
+    def getTestData(labels: Seq[Double]): DataFrame = {
+      val points = labels.map(label => LabeledPoint(label, Vectors.dense(0.0)))
+      sqlContext.createDataFrame(points)
+    }
+
+    val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
+
+    // Fits on the given labels and checks that the failure carries the binary-only message.
+    def assertRejected(clue: String, labels: Seq[Double]): Unit = {
+      val df = getTestData(labels)
+      withClue(clue) {
+        val e: SparkException = intercept[SparkException] {
+          gbt.fit(df)
+        }
+        assert(e.getMessage.contains("currently only supports binary classification"))
+      }
+    }
+
+    // Invalid datasets: each one contains a single label outside {0, 1}.
+    assertRejected(
+      "Classifier should fail if label is negative",
+      Seq(0.0, -1.0, 1.0, 0.0))
+    assertRejected(
+      "Classifier should fail if label is not an integer",
+      Seq(0.0, 0.1, 1.0, 0.0))
+    assertRejected(
+      "Classifier should fail if label is >= 2",
+      Seq(0.0, 2.0, 1.0, 0.0))
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of feature importance
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/4f4721a2/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index aaaa429..9074435 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -154,9 +154,16 @@ class RandomForestClassifierSuite
     }
   }
 
+  test("Fitting without numClasses in metadata") {
+    // Smoke test: fit() must not throw when the label column carries no numClasses metadata.
+    val trainingData: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    new RandomForestClassifier().setMaxDepth(1).setNumTrees(1).fit(trainingData)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of feature importance
   /////////////////////////////////////////////////////////////////////////////
+
   test("Feature importance with toy data") {
     val numClasses = 2
     val rf = new RandomForestClassifier()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org