You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/18 00:03:30 UTC

git commit: [SPARK-3934] [SPARK-3918] [mllib] Bug fixes for RandomForest, DecisionTree

Repository: spark
Updated Branches:
  refs/heads/master 23f6171d6 -> 477c6481c


[SPARK-3934] [SPARK-3918] [mllib] Bug fixes for RandomForest, DecisionTree

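SPARK-3934: RandomForest fails on multiclass classification when run with a mix of unordered categorical and continuous features. The bug is in the sanity checks in getFeatureOffset and getLeftRightFeatureOffsets, which use the wrong indices when checking whether a feature is unordered: the index they receive refers to a position in the sampled feature subset, but it is looked up as if it were an index into the full feature set.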
Fix: Remove the sanity checks. They are not strictly needed, and keeping them would require DTStatsAggregator to track an extra set of indices (for the feature subset).

Added a test to RandomForestSuite that fails with the old version and passes now.

SPARK-3918: Added a call to baggedInput.unpersist at the end of training.

Also:
* Removed DTStatsAggregator.isUnordered, since it is no longer used.
* DecisionTreeMetadata: Added a logWarning call when maxBins is automatically reduced.
* Updated DecisionTreeRunner to load the test data with the same number of features as the training data. This is a temporary fix that should eventually be replaced by pre-indexing both datasets.
* RandomForestModel: Updated toString to print the total number of nodes in the forest.
* Changed the Predict class to be a public DeveloperApi. This is necessary so that users can build their own trees by hand (e.g., for testing).

CC: mengxr manishamde chouqin codedeft. Just notifying you of these small bug fixes.

Author: Joseph K. Bradley <jo...@gmail.com>

Closes #2785 from jkbradley/dtrunner-update and squashes the following commits:

9132321 [Joseph K. Bradley] merged with master, fixed imports
9dbd000 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
e116473 [Joseph K. Bradley] Changed Predict class to be public DeveloperApi.
f502e65 [Joseph K. Bradley] bug fix for SPARK-3934
7f3d60f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
ba567ab [Joseph K. Bradley] Changed DTRunner to load test data using same number of features as in training data.
4e88c1f [Joseph K. Bradley] changed RF toString to print total number of nodes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/477c6481
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/477c6481
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/477c6481

Branch: refs/heads/master
Commit: 477c6481cca94b15c9c8b43e674f220a1cda1dd1
Parents: 23f6171
Author: Joseph K. Bradley <jo...@gmail.com>
Authored: Fri Oct 17 15:02:57 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Oct 17 15:02:57 2014 -0700

----------------------------------------------------------------------
 .../spark/examples/mllib/DecisionTreeRunner.scala   |  3 ++-
 .../spark/mllib/tree/impl/DTStatsAggregator.scala   | 16 +---------------
 .../mllib/tree/impl/DecisionTreeMetadata.scala      |  7 ++++++-
 .../org/apache/spark/mllib/tree/model/Predict.scala |  5 ++++-
 .../spark/mllib/tree/model/RandomForestModel.scala  |  4 ++--
 .../apache/spark/mllib/tree/RandomForestSuite.scala | 16 ++++++++++++++++
 6 files changed, 31 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/477c6481/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 837d059..0890e62 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -189,9 +189,10 @@ object DecisionTreeRunner {
     // Create training, test sets.
     val splits = if (params.testInput != "") {
       // Load testInput.
+      val numFeatures = examples.take(1)(0).features.size
       val origTestExamples = params.dataFormat match {
         case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
-        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
+        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
       }
       params.algo match {
         case Classification => {

http://git-wip-us.apache.org/repos/asf/spark/blob/477c6481/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 55f422d..ce8825c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -65,12 +65,6 @@ private[tree] class DTStatsAggregator(
   }
 
   /**
-   * Indicator for each feature of whether that feature is an unordered feature.
-   * TODO: Is Array[Boolean] any faster?
-   */
-  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
-
-  /**
    * Total number of elements stored in this aggregator
    */
   private val allStatsSize: Int = featureOffsets.last
@@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator(
    * Pre-compute feature offset for use with [[featureUpdate]].
    * For ordered features only.
    */
-  def getFeatureOffset(featureIndex: Int): Int = {
-    require(!isUnordered(featureIndex),
-      s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
-        s" for unordered feature $featureIndex.")
-    featureOffsets(featureIndex)
-  }
+  def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
   /**
    * Pre-compute feature offset for use with [[featureUpdate]].
    * For unordered features only.
    */
   def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    require(isUnordered(featureIndex),
-      s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
-        s" but was called for ordered feature $featureIndex.")
     val baseOffset = featureOffsets(featureIndex)
     (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/477c6481/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 212dce2..772c026 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl
 
 import scala.collection.mutable
 
+import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -82,7 +83,7 @@ private[tree] class DecisionTreeMetadata(
 
 }
 
-private[tree] object DecisionTreeMetadata {
+private[tree] object DecisionTreeMetadata extends Logging {
 
   /**
    * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
@@ -103,6 +104,10 @@ private[tree] object DecisionTreeMetadata {
     }
 
     val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+    if (maxPossibleBins < strategy.maxBins) {
+      logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+        s" (= number of training instances)")
+    }
 
     // We check the number of bins here against maxPossibleBins.
     // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified

http://git-wip-us.apache.org/repos/asf/spark/blob/477c6481/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index d8476b5..004838e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.mllib.tree.model
 
+import org.apache.spark.annotation.DeveloperApi
+
 /**
  * Predicted value for a node
  * @param predict predicted value
  * @param prob probability of the label (classification only)
  */
-private[tree] class Predict(
+@DeveloperApi
+class Predict(
     val predict: Double,
     val prob: Double = 0.0) extends Serializable {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/477c6481/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 4d66d6d..6a22e2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
    */
   override def toString: String = algo match {
     case Classification =>
-      s"RandomForestModel classifier with $numTrees trees"
+      s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
     case Regression =>
-      s"RandomForestModel regressor with $numTrees trees"
+      s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
     case _ => throw new IllegalArgumentException(
       s"RandomForestModel given unknown algo parameter: $algo.")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/477c6481/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 20d372d..fb44ceb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -173,6 +173,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
   }
 
+  test("alternating categorical and continuous features with multiclass labels to test indexing") {
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0))
+    arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+    val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4)
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
+      featureSubsetStrategy = "sqrt", seed = 12345)
+    RandomForestSuite.validateClassifier(model, arr, 1.0)
+  }
+
 }
 
 object RandomForestSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org