You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/09/29 06:44:57 UTC

[1/2] [SPARK-1545] [mllib] Add Random Forests

Repository: spark
Updated Branches:
  refs/heads/master f350cd307 -> 0dc2b6361


http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index b6d49e5..212dce2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -48,7 +48,9 @@ private[tree] class DecisionTreeMetadata(
     val quantileStrategy: QuantileStrategy,
     val maxDepth: Int,
     val minInstancesPerNode: Int,
-    val minInfoGain: Double) extends Serializable {
+    val minInfoGain: Double,
+    val numTrees: Int,
+    val numFeaturesPerNode: Int) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -73,6 +75,11 @@ private[tree] class DecisionTreeMetadata(
     numBins(featureIndex) - 1
   }
 
+  /**
+   * Indicates if feature subsampling is being used.
+   */
+  def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
+
 }
 
 private[tree] object DecisionTreeMetadata {
@@ -82,7 +89,11 @@ private[tree] object DecisionTreeMetadata {
    * This computes which categorical features will be ordered vs. unordered,
    * as well as the number of splits and bins for each feature.
    */
-  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+  def buildMetadata(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String): DecisionTreeMetadata = {
 
     val numFeatures = input.take(1)(0).features.size
     val numExamples = input.count()
@@ -128,13 +139,43 @@ private[tree] object DecisionTreeMetadata {
       }
     }
 
+    // Set number of features to use per node (for random forests).
+    val _featureSubsetStrategy = featureSubsetStrategy match {
+      case "auto" =>
+        if (numTrees == 1) {
+          "all"
+        } else {
+          if (strategy.algo == Classification) {
+            "sqrt"
+          } else {
+            "onethird"
+          }
+        }
+      case _ => featureSubsetStrategy
+    }
+    val numFeaturesPerNode: Int = _featureSubsetStrategy match {
+      case "all" => numFeatures
+      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
+      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
+      case "onethird" => (numFeatures / 3.0).ceil.toInt
+    }
+
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
       strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain)
+      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
   }
 
   /**
+   * Version of [[buildMetadata()]] for DecisionTree (single tree, using all features).
+   */
+  def buildMetadata(
+      input: RDD[LabeledPoint],
+      strategy: Strategy): DecisionTreeMetadata = {
+    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
+  }
+
+  /**
    * Given the arity of a categorical feature (arity = number of categories),
    * return the number of bins for the feature if it is to be treated as an unordered feature.
    * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 1c8afc2..0e02345 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -89,12 +89,12 @@ private[tree] class EntropyAggregator(numClasses: Int)
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
     if (label >= statsSize) {
       throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
         s" but requires label < numClasses (= $statsSize).")
     }
-    allStats(offset + label.toInt) += 1
+    allStats(offset + label.toInt) += instanceWeight
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 5cfdf34..7c83cd4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -85,12 +85,12 @@ private[tree] class GiniAggregator(numClasses: Int)
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
     if (label >= statsSize) {
       throw new IllegalArgumentException(s"GiniAggregator given label $label" +
         s" but requires label < numClasses (= $statsSize).")
     }
-    allStats(offset + label.toInt) += 1
+    allStats(offset + label.toInt) += instanceWeight
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 5a047d6..60e2ab2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -78,7 +78,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
 
   /**
    * Get an [[ImpurityCalculator]] for a (node, feature, bin).

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index e9ccecb..df9eafa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -75,10 +75,10 @@ private[tree] class VarianceAggregator()
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
-    allStats(offset) += 1
-    allStats(offset + 1) += label
-    allStats(offset + 2) += label * label
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+    allStats(offset) += instanceWeight
+    allStats(offset + 1) += instanceWeight * label
+    allStats(offset + 2) += instanceWeight * label * label
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 5f0095d..56c3e25 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -41,12 +41,12 @@ import org.apache.spark.mllib.linalg.Vector
 @DeveloperApi
 class Node (
     val id: Int,
-    val predict: Double,
-    val isLeaf: Boolean,
-    val split: Option[Split],
+    var predict: Double,
+    var isLeaf: Boolean,
+    var split: Option[Split],
     var leftNode: Option[Node],
     var rightNode: Option[Node],
-    val stats: Option[InformationGainStats]) extends Serializable with Logging {
+    var stats: Option[InformationGainStats]) extends Serializable with Logging {
 
   override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
     "split = " + split + ", stats = " + stats
@@ -168,6 +168,11 @@ class Node (
 private[tree] object Node {
 
   /**
+   * Return a node with the given node id (but nothing else set).
+   */
+  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+
+  /**
    * Return the index of the left child of this node.
    */
   def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
new file mode 100644
index 0000000..538c0e2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Random forest model for classification or regression.
+ * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
+ * aggregate predictions (most frequent class for classification, mean for regression).
+ * @param trees Trees which make up this forest.  This cannot be empty.
+ * @param algo algorithm type -- classification or regression
+ */
+@Experimental
+class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
+
+  require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
+
+  /**
+   * Predict values for a single data point.
+   *
+   * @param features vector representing a single data point
+   * @return Double prediction from the trained model
+   */
+  def predict(features: Vector): Double = {
+    algo match {
+      case Classification =>
+        val predictionToCount = new mutable.HashMap[Int, Int]()
+        trees.foreach { tree =>
+          val prediction = tree.predict(features).toInt
+          predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+        }
+        predictionToCount.maxBy(_._2)._1
+      case Regression =>
+        trees.map(_.predict(features)).sum / trees.size
+    }
+  }
+
+  /**
+   * Predict values for the given data set.
+   *
+   * @param features RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(features: RDD[Vector]): RDD[Double] = {
+    features.map(x => predict(x))
+  }
+
+  /**
+   * Get number of trees in forest.
+   */
+  def numTrees: Int = trees.size
+
+  /**
+   * Return a string describing the full model: a header line followed by every tree.
+   */
+  override def toString: String = {
+    val header = algo match {
+      case Classification =>
+        s"RandomForestModel classifier with $numTrees trees\n"
+      case Regression =>
+        s"RandomForestModel regressor with $numTrees trees\n"
+      case _ => throw new IllegalArgumentException(
+        s"RandomForestModel given unknown algo parameter: $algo.")
+    }
+    header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+      s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+    }.fold("")(_ + _)
+  }
+
+}
+
+private[tree] object RandomForestModel {
+
+  def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
+    require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
+    val algo: Algo = trees(0).algo
+    require(trees.forall(_.algo == algo),
+      "RandomForestModel cannot combine trees which have different output types" +
+      " (classification/regression).")
+    new RandomForestModel(trees, algo)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2b2e579..a48ed71 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.scalatest.FunSuite
 
@@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
-  def validateClassifier(
-      model: DecisionTreeModel,
-      input: Seq[LabeledPoint],
-      requiredAccuracy: Double) {
-    val predictions = input.map(x => model.predict(x.features))
-    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
-      prediction != expected.label
-    }
-    val accuracy = (input.length - numOffPredictions).toDouble / input.length
-    assert(accuracy >= requiredAccuracy,
-      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
-  }
-
-  def validateRegressor(
-      model: DecisionTreeModel,
-      input: Seq[LabeledPoint],
-      requiredMSE: Double) {
-    val predictions = input.map(x => model.predict(x.features))
-    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
-      val err = prediction - expected.label
-      err * err
-    }.sum
-    val mse = squaredError / input.length
-    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
-  }
-
   test("Binary classification with continuous features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
@@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
-    // 2^10 - 1 > 100, so categorical features will be ordered
+    // 2^(10-1) - 1 > 100, so categorical features will be ordered
 
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
@@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 0)
     assert(bins(0).length === 0)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode: Node, doneTraining: Boolean) =
-      DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.categories === List(1.0))
@@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.categories.length === 1)
@@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(!metadata.isUnordered(featureIndex = 1))
 
     val model = DecisionTree.train(rdd, strategy)
-    validateRegressor(model, arr, 0.0)
+    DecisionTreeSuite.validateRegressor(model, arr, 0.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
   }
@@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
       numClassesForClassification = 2, maxBins = 100)
     val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
-    val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
+    val rootNode1 = modelOneNode.topNode.deepCopy()
+    val rootNode2 = modelOneNode.topNode.deepCopy()
+    assert(rootNode1.leftNode.nonEmpty)
+    assert(rootNode1.rightNode.nonEmpty)
 
-    // Single group second level tree construction.
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
-      rootNodeCopy1, splits, bins, 10)
-    assert(rootNode.leftNode.nonEmpty)
-    assert(rootNode.rightNode.nonEmpty)
+    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+    // Single group second level tree construction.
+    val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+      (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
     val children1 = new Array[Node](2)
-    children1(0) = rootNode.leftNode.get
-    children1(1) = rootNode.rightNode.get
-
-    // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
-    // level tree construction.
-    val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
-      rootNodeCopy2, splits, bins, 0)
-    assert(rootNode2.leftNode.nonEmpty)
-    assert(rootNode2.rightNode.nonEmpty)
+    children1(0) = rootNode1.leftNode.get
+    children1(1) = rootNode1.rightNode.get
+
+    // Train one second-level node at a time.
+    val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+    val treeToNodeToIndexInfoA = Map((0, Map(
+      (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+    nodeQueue.clear()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+      nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+    val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+    val treeToNodeToIndexInfoB = Map((0, Map(
+      (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+    nodeQueue.clear()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+      nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
     val children2 = new Array[Node](2)
     children2(0) = rootNode2.leftNode.get
     children2(1) = rootNode2.rightNode.get
@@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(metadata.isUnordered(featureIndex = 0))
     assert(metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.feature === 0)
@@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 1.0)
+    DecisionTreeSuite.validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
   }
@@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 1.0)
+    DecisionTreeSuite.validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
     assert(model.topNode.split.get.feature === 1)
@@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(metadata.isUnordered(featureIndex = 1))
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 1.0)
+    DecisionTreeSuite.validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = model.topNode
 
     val split = rootNode.split.get
     assert(split.feature === 0)
@@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 0.9)
+    DecisionTreeSuite.validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = model.topNode
 
     val split = rootNode.split.get
     assert(split.feature === 1)
@@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(metadata.isUnordered(featureIndex = 0))
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 0.9)
+    DecisionTreeSuite.validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = model.topNode
 
     val split = rootNode.split.get
     assert(split.feature === 1)
@@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.feature === 0)
@@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(strategy.isMulticlassClassification)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 0.6)
+    DecisionTreeSuite.validateClassifier(model, arr, 0.6)
   }
 
   test("split must satisfy min instances per node requirements") {
@@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
     arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
 
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
       maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     assert(model.topNode.isLeaf)
     assert(model.topNode.predict == 0.0)
-    val predicts = input.map(p => model.predict(p.features)).collect()
+    val predicts = rdd.map(p => model.predict(p.features)).collect()
     predicts.foreach { predict =>
       assert(predict == 0.0)
     }
 
-    // test for findBestSplits when no valid split can be found
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    // test when no valid split can be found
+    val rootNode = model.topNode
 
     val gain = rootNode.stats.get
     assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
     arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
 
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
       maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
       numClassesForClassification = 2, minInstancesPerNode = 2)
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     val gain = rootNode.stats.get
@@ -757,12 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(predict == 0.0)
     }
 
-    // test for findBestSplits when no valid split can be found
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    // test when no valid split can be found
+    val rootNode = model.topNode
 
     val gain = rootNode.stats.get
     assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
 object DecisionTreeSuite {
 
+  def validateClassifier(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredAccuracy: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+      prediction != expected.label
+    }
+    val accuracy = (input.length - numOffPredictions).toDouble / input.length
+    assert(accuracy >= requiredAccuracy,
+      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+  }
+
+  def validateRegressor(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      val err = prediction - expected.label
+      err * err
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+  }
+
   def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
new file mode 100644
index 0000000..30669fc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.util.StatCounter
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends FunSuite with LocalSparkContext {
+
+  test("BaggedPoint RDD: without subsampling") {
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+    val rdd = sc.parallelize(arr)
+    val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
+    baggedRDD.collect().foreach { baggedPoint =>
+      assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+    }
+  }
+
+  test("BaggedPoint RDD: with subsampling") {
+    val numSubsamples = 100
+    val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+    val seeds = Array(123, 5354, 230, 349867, 23987)
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+    val rdd = sc.parallelize(arr)
+    seeds.foreach { seed =>
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
+      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+        expectedStddev, epsilon = 0.01)
+    }
+  }
+
+  test("Binary classification with continuous features:" +
+      " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+    val rdd = sc.parallelize(arr)
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val numTrees = 1
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+    val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
+      featureSubsetStrategy = "auto", seed = 123)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
+
+    val dt = DecisionTree.train(rdd, strategy)
+
+    RandomForestSuite.validateClassifier(rf, arr, 0.9)
+    DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
+
+    // Make sure trees are the same.
+    assert(rfTree.toString == dt.toString)
+  }
+
+  test("Regression with continuous features:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+    val rdd = sc.parallelize(arr)
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val numTrees = 1
+
+    val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+    val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
+      featureSubsetStrategy = "auto", seed = 123)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
+
+    val dt = DecisionTree.train(rdd, strategy)
+
+    RandomForestSuite.validateRegressor(rf, arr, 0.01)
+    DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
+
+    // Make sure trees are the same.
+    assert(rfTree.toString == dt.toString)
+  }
+
+  test("Binary classification with continuous features: subsampling features") {
+    val numFeatures = 50
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
+    val rdd = sc.parallelize(arr)
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+    // Select feature subset for top nodes.  Return true if OK.
+    def checkFeatureSubsetStrategy(
+        numTrees: Int,
+        featureSubsetStrategy: String,
+        numFeaturesPerNode: Int): Unit = {
+      val seeds = Array(123, 5354, 230, 349867, 23987)
+      val maxMemoryUsage: Long = 128 * 1024L * 1024L
+      val metadata =
+        DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+      seeds.foreach { seed =>
+        val failString = s"Failed on test with:" +
+          s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+          s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+        val nodeQueue = new mutable.Queue[(Int, Node)]()
+        val topNodes: Array[Node] = new Array[Node](numTrees)
+        Range(0, numTrees).foreach { treeIndex =>
+          topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
+          nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+        }
+        val rng = new scala.util.Random(seed = seed)
+        val (nodesForGroup: Map[Int, Array[Node]],
+            treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+          RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+        assert(nodesForGroup.size === numTrees, failString)
+        assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+        if (numFeaturesPerNode == numFeatures) {
+          // featureSubset values should all be None
+          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+            failString)
+        } else {
+          // Check number of features.
+          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+            _.featureSubset.get.size === numFeaturesPerNode)), failString)
+        }
+      }
+    }
+
+    checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 1, "log2",
+      (math.log(numFeatures) / math.log(2)).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+    checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "log2",
+      (math.log(numFeatures) / math.log(2)).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+  }
+
+}
+
+object RandomForestSuite {
+
+  /**
+   * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+   * epsilon of the expected values.
+   * @param data  Every element of the data should be an i.i.d. sample from some distribution.
+   */
+  def testRandomArrays(
+      data: Array[Array[Double]],
+      numCols: Int,
+      expectedMean: Double,
+      expectedStddev: Double,
+      epsilon: Double) {
+    val values = new mutable.ArrayBuffer[Double]()
+    data.foreach { row =>
+      assert(row.size == numCols)
+      values ++= row
+    }
+    val stats = new StatCounter(values)
+    assert(math.abs(stats.mean - expectedMean) < epsilon)
+    assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+  }
+
+  def validateClassifier(
+      model: RandomForestModel,
+      input: Seq[LabeledPoint],
+      requiredAccuracy: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+      prediction != expected.label
+    }
+    val accuracy = (input.length - numOffPredictions).toDouble / input.length
+    assert(accuracy >= requiredAccuracy,
+      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+  }
+
+  def validateRegressor(
+      model: RandomForestModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      val err = prediction - expected.label
+      err * err
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+  }
+
+  def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
+    val numInstances = 1000
+    val arr = new Array[LabeledPoint](numInstances)
+    for (i <- 0 until numInstances) {
+      val label = if (i < numInstances / 10) {
+        0.0
+      } else if (i < numInstances / 2) {
+        1.0
+      } else if (i < numInstances * 0.9) {
+        0.0
+      } else {
+        1.0
+      }
+      val features = Array.fill[Double](numFeatures)(i.toDouble)
+      arr(i) = new LabeledPoint(label, Vectors.dense(features))
+    }
+    arr
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] git commit: [SPARK-1545] [mllib] Add Random Forests

Posted by me...@apache.org.
[SPARK-1545] [mllib] Add Random Forests

This PR adds RandomForest to MLlib.  The implementation is basic, and future performance optimizations will be important.  (Note: RFs = Random Forests.)

# Overview

## RandomForest
* trains multiple trees at once to reduce the number of passes over the data
* allows feature subsets at each node
* uses a queue of nodes instead of fixed groups for each level

This implementation is based on an implementation by manishamde and the [Alpine Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in particular, the TreePoint, BaggedPoint, and node queue implementations).  Thank you for your inputs!

## Testing

Correctness: This has been tested for correctness with the test suites and with DecisionTreeRunner on example datasets.

Performance: This has been performance tested using [this branch of spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs).  Results below.

### Regression tests for DecisionTree

Summary: For training 1 tree, there are small regressions, especially from feature subsampling.

In the table below, each row is a single (random) dataset.  The 2 different sets of result columns are for 2 different RF implementations:
* (numTrees): This is from an earlier commit, after implementing RandomForest to train multiple trees at once.  It does not include any code for feature subsampling.
* (feature subsets): This is from this current PR's code, after implementing feature subsampling.
These tests were to identify regressions in DecisionTree, so they are training 1 tree with all of the features (i.e., no feature subsampling).

These were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth = 5 (= 6 levels).  Speedup values < 1 indicate slowdowns from the old DecisionTree implementation.

numInstances | numFeatures | runtime (sec) | speedup | runtime (sec) | speedup
---- | ---- | ---- | ---- | ---- | ----
 | | (numTrees) | (numTrees) | (feature subsets) | (feature subsets)
20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471
20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857
20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638
20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556
200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673
200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571
200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131
200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005
2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667
2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911
2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136
2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337
20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494
20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826
20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599

### Tests for forests

I have run other tests with numTrees=10 and with sqrt(numFeatures), and those indicate that multi-model training and feature subsets can speed up training for forests, especially when training deeper trees.

# Details on specific classes

## Changes to DecisionTree
* Main train() method is now in RandomForest.
* findBestSplits() is no longer needed.  (It split levels into groups, but we now use a queue of nodes.)
* Many small changes to support RFs.  (Note: These methods should be moved to RandomForest.scala in a later PR, but are in DecisionTree.scala to make code comparison easier.)

## RandomForest
* Main train() method is from old DecisionTree.
* selectNodesToSplit: Note that it selects nodes and feature subsets jointly to track memory usage.

## RandomForestModel
* Stores an Array[DecisionTreeModel]
* Prediction:
 * For classification, most common label.  For regression, mean.
 * We could support other methods later.

## examples/.../DecisionTreeRunner
* This now takes numTrees and featureSubsetStrategy, to support RFs.

## DTStatsAggregator
* 2 types of functionality (w/ and w/o subsampling features): These require different indexing methods.  (We could treat both as subsampling, but this is less efficient.)
  DTStatsAggregator is now abstract, and 2 child classes implement these 2 types of functionality.

## impurities
* These now take instance weights.

## Node
* Some vals changed to vars.
 * This is unfortunately a public API change (DeveloperApi).  This could be avoided by creating a LearningNode struct, but would be awkward.

## RandomForestSuite
Please let me know if there are missing tests!

## BaggedPoint
This wraps TreePoint and holds bootstrap weights/counts.

# Design decisions

* BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful for other bagging algorithms later on.

* RandomForest public API: What options should be easily supported by the train* methods?  Should ALL options be in the Java-friendly constructors?  Should there be a constructor taking Strategy?

* Feature subsampling options: What options should be supported?  scikit-learn supports the same options, except for "onethird."  One option would be to allow users to specify fractions ("0.1"): the current options could be supported, and any unrecognized values would be parsed as Doubles in [0,1].

* Splits and bins are computed before bootstrapping, so all trees use the same discretization.

* One queue, instead of one queue per tree.

CC: mengxr manishamde codedeft chouqin  Please let me know if you have suggestions---thanks!

Author: Joseph K. Bradley <jo...@gmail.com>
Author: qiping.lqp <qi...@alibaba-inc.com>
Author: chouqin <li...@gmail.com>

Closes #2435 from jkbradley/rfs-new and squashes the following commits:

c694174 [Joseph K. Bradley] Fixed typo
cc59d78 [Joseph K. Bradley] fixed imports
e25909f [Joseph K. Bradley] Simplified node group maps.  Specifically, created NodeIndexInfo to store node index in agg and feature subsets, and no longer create extra maps in findBestSplits
fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt for classification, onethird for regression.  Updated docs with references.
ef7c293 [Joseph K. Bradley] Updates based on code review.  Most substantial changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * Added test for regression to RandomForestSuite
593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num features).  Now it checks >= 1.
a1a08df [Joseph K. Bradley] Removed old comments
866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use multiple fixed random seeds.
ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and replaced with Option
bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, style issues.
d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to DecisionTreeRunner (to support RandomForest).  Fixed bugs so that RandomForest now runs.
746d43c [Joseph K. Bradley] Implemented feature subsampling.  Tested DecisionTree but not RandomForest.
6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new.  Added RandomForestModel.toString
b7ae594 [Joseph K. Bradley] Updated docs.  Small fix for bug which does not cause errors: No longer allocate unused child nodes for leaf nodes.
121c74e [Joseph K. Bradley] Basic random forests are implemented.  Random features per node not yet implemented.  Test suite not implemented.
325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new
4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet
eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training.
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0dc2b636
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0dc2b636
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0dc2b636

Branch: refs/heads/master
Commit: 0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe
Parents: f350cd3
Author: Joseph K. Bradley <jo...@gmail.com>
Authored: Sun Sep 28 21:44:50 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun Sep 28 21:44:50 2014 -0700

----------------------------------------------------------------------
 .../examples/mllib/DecisionTreeRunner.scala     |  76 ++-
 .../apache/spark/mllib/tree/DecisionTree.scala  | 457 +++++++------------
 .../apache/spark/mllib/tree/RandomForest.scala  | 451 ++++++++++++++++++
 .../spark/mllib/tree/impl/BaggedPoint.scala     |  80 ++++
 .../mllib/tree/impl/DTStatsAggregator.scala     | 219 ++++++---
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |  47 +-
 .../spark/mllib/tree/impurity/Entropy.scala     |   4 +-
 .../apache/spark/mllib/tree/impurity/Gini.scala |   4 +-
 .../spark/mllib/tree/impurity/Impurity.scala    |   2 +-
 .../spark/mllib/tree/impurity/Variance.scala    |   8 +-
 .../apache/spark/mllib/tree/model/Node.scala    |  13 +-
 .../mllib/tree/model/RandomForestModel.scala    | 105 +++++
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 210 ++++-----
 .../spark/mllib/tree/RandomForestSuite.scala    | 245 ++++++++++
 14 files changed, 1410 insertions(+), 511 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 4683e6e..96fb068 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,16 +21,18 @@ import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 /**
- * An example runner for decision tree. Run with
+ * An example runner for decision trees and random forests. Run with
  * {{{
  * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
  * }}}
@@ -57,6 +59,8 @@ object DecisionTreeRunner {
       maxBins: Int = 32,
       minInstancesPerNode: Int = 1,
       minInfoGain: Double = 0.0,
+      numTrees: Int = 1,
+      featureSubsetStrategy: String = "auto",
       fracTest: Double = 0.2)
 
   def main(args: Array[String]) {
@@ -79,11 +83,20 @@ object DecisionTreeRunner {
         .action((x, c) => c.copy(maxBins = x))
       opt[Int]("minInstancesPerNode")
         .text(s"min number of instances required at child nodes to create the parent split," +
-        s" default: ${defaultParams.minInstancesPerNode}")
+          s" default: ${defaultParams.minInstancesPerNode}")
         .action((x, c) => c.copy(minInstancesPerNode = x))
       opt[Double]("minInfoGain")
         .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
         .action((x, c) => c.copy(minInfoGain = x))
+      opt[Int]("numTrees")
+        .text(s"number of trees (1 = decision tree, 2+ = random forest)," +
+          s" default: ${defaultParams.numTrees}")
+        .action((x, c) => c.copy(numTrees = x))
+      opt[String]("featureSubsetStrategy")
+        .text(s"feature subset sampling strategy" +
+          s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " +
+          s"default: ${defaultParams.featureSubsetStrategy}")
+        .action((x, c) => c.copy(featureSubsetStrategy = x))
       opt[Double]("fracTest")
         .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
@@ -191,18 +204,35 @@ object DecisionTreeRunner {
           numClassesForClassification = numClasses,
           minInstancesPerNode = params.minInstancesPerNode,
           minInfoGain = params.minInfoGain)
-    val model = DecisionTree.train(training, strategy)
-
-    println(model)
-
-    if (params.algo == Classification) {
-      val accuracy = accuracyScore(model, test)
-      println(s"Test accuracy = $accuracy")
-    }
-
-    if (params.algo == Regression) {
-      val mse = meanSquaredError(model, test)
-      println(s"Test mean squared error = $mse")
+    if (params.numTrees == 1) {
+      val model = DecisionTree.train(training, strategy)
+      println(model)
+      if (params.algo == Classification) {
+        val accuracy =
+          new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+        println(s"Test accuracy = $accuracy")
+      }
+      if (params.algo == Regression) {
+        val mse = meanSquaredError(model, test)
+        println(s"Test mean squared error = $mse")
+      }
+    } else {
+      val randomSeed = Utils.random.nextInt()
+      if (params.algo == Classification) {
+        val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
+          params.featureSubsetStrategy, randomSeed)
+        println(model)
+        val accuracy =
+          new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+        println(s"Test accuracy = $accuracy")
+      }
+      if (params.algo == Regression) {
+        val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
+          params.featureSubsetStrategy, randomSeed)
+        println(model)
+        val mse = meanSquaredError(model, test)
+        println(s"Test mean squared error = $mse")
+      }
     }
 
     sc.stop()
@@ -211,9 +241,7 @@ object DecisionTreeRunner {
   /**
    * Calculates the classifier accuracy.
    */
-  private def accuracyScore(
-      model: DecisionTreeModel,
-      data: RDD[LabeledPoint]): Double = {
+  private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
     val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
     val count = data.count()
     correctCount.toDouble / count
@@ -228,4 +256,14 @@ object DecisionTreeRunner {
       err * err
     }.mean()
   }
+
+  /**
+   * Calculates the mean squared error for regression.
+   */
+  private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+    data.map { y =>
+      val err = tree.predict(y.features) - y.label
+      err * err
+    }.mean()
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c7f2576..b7dc373 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -18,12 +18,14 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
@@ -33,7 +35,6 @@ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.XORShiftRandom
 
 
@@ -56,99 +57,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
-
-    val timer = new TimeTracker()
-
-    timer.start("total")
-
-    timer.start("init")
-
-    val retaggedInput = input.retag(classOf[LabeledPoint])
-    val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
-    logDebug("algo = " + strategy.algo)
-    logDebug("maxBins = " + metadata.maxBins)
-
-    // Find the splits and the corresponding bins (interval between the splits) using a sample
-    // of the input data.
-    timer.start("findSplitsBins")
-    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
-    timer.stop("findSplitsBins")
-    logDebug("numBins: feature: number of bins")
-    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
-        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
-      }.mkString("\n"))
-
-    // Bin feature values (TreePoint representation).
-    // Cache input RDD for speedup during multiple passes.
-    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    // depth of the decision tree
-    val maxDepth = strategy.maxDepth
-    require(maxDepth <= 30,
-      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
-
-    // Calculate level for single group construction
-
-    // Max memory usage for aggregates
-    val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
-    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    // TODO: Calculate memory usage more precisely.
-    val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
-
-    logDebug("numElementsPerNode = " + numElementsPerNode)
-    val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
-    val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)
-    logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
-    // nodes at a level is 2^level. level is zero indexed.
-    val maxLevelForSingleGroup = math.max(
-      (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
-    logDebug("max level for single group = " + maxLevelForSingleGroup)
-
-    timer.stop("init")
-
-    /*
-     * The main idea here is to perform level-wise training of the decision tree nodes thus
-     * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
-     * Each data sample is handled by a particular node at that level (or it reaches a leaf
-     * beforehand and is not used in later levels.
-     */
-
-    var topNode: Node = null // set on first iteration
-    var level = 0
-    var break = false
-    while (level <= maxDepth && !break) {
-      logDebug("#####################################")
-      logDebug("level = " + level)
-      logDebug("#####################################")
-
-      // Find best split for all nodes at a level.
-      timer.start("findBestSplits")
-      val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
-        metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
-      timer.stop("findBestSplits")
-
-      if (level == 0) {
-        topNode = tmpTopNode
-      }
-      if (doneTraining) {
-        break = true
-        logDebug("done training")
-      }
-
-      level += 1
-    }
-
-    logDebug("#####################################")
-    logDebug("Extracting tree model")
-    logDebug("#####################################")
-
-    timer.stop("total")
-
-    logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
-
-    new DecisionTreeModel(topNode, strategy.algo)
+    // Note: random seed will not be used since numTrees = 1.
+    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
+    val rfModel = rf.train(input)
+    rfModel.trees(0)
   }
 
 }
@@ -353,57 +265,9 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Returns an array of optimal splits for all nodes at a given level. Splits the task into
-   * multiple groups if the level-wise training task could lead to memory overflow.
-   *
-   * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
-   * @param metadata Learning and dataset metadata
-   * @param level Level of the tree
-   * @param topNode Root node of the tree (or invalid node when training first level).
-   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
-   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
-   * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
-   * @return  (root, doneTraining) where:
-   *          root = Root node (which is newly created on the first iteration),
-   *          doneTraining = true if no more internal nodes were created.
-   */
-  private[tree] def findBestSplits(
-      input: RDD[TreePoint],
-      metadata: DecisionTreeMetadata,
-      level: Int,
-      topNode: Node,
-      splits: Array[Array[Split]],
-      bins: Array[Array[Bin]],
-      maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
-
-    // split into groups to avoid memory overflow during aggregation
-    if (level > maxLevelForSingleGroup) {
-      // When information for all nodes at a given level cannot be stored in memory,
-      // the nodes are divided into multiple groups at each level with the number of groups
-      // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
-      // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
-      val numGroups = 1 << level - maxLevelForSingleGroup
-      logDebug("numGroups = " + numGroups)
-      // Iterate over each group of nodes at a level.
-      var groupIndex = 0
-      var doneTraining = true
-      while (groupIndex < numGroups) {
-        val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
-          topNode, splits, bins, timer, numGroups, groupIndex)
-        doneTraining = doneTraining && doneTrainingGroup
-        groupIndex += 1
-      }
-      (topNode, doneTraining) // Not first iteration, so topNode was already set.
-    } else {
-      findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
-    }
-  }
-
-  /**
    * Get the node index corresponding to this data point.
-   * This function mimics prediction, passing an example from the root node down to a node
-   * at the current level being trained; that node's index is returned.
+   * This function mimics prediction, passing an example from the root node down to a leaf
+   * or unsplit node; that node's index is returned.
    *
    * @param node  Node in tree from which to classify the given data point.
    * @param binnedFeatures  Binned feature vector for data point.
@@ -413,14 +277,15 @@ object DecisionTree extends Serializable with Logging {
    *          Otherwise, last node reachable in tree matching this example.
    *          Note: This is the global node index, i.e., the index used in the tree.
    *                This index is different from the index used during training a particular
-   *                set of nodes in a (level, group).
+   *                group of nodes on one call to [[findBestSplits()]].
    */
   private def predictNodeIndex(
       node: Node,
       binnedFeatures: Array[Int],
       bins: Array[Array[Bin]],
       unorderedFeatures: Set[Int]): Int = {
-    if (node.isLeaf) {
+    if (node.isLeaf || node.split.isEmpty) {
+      // Node is either leaf, or has not yet been split.
       node.id
     } else {
       val featureIndex = node.split.get.feature
@@ -465,43 +330,60 @@ object DecisionTree extends Serializable with Logging {
    * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    *             each (node, feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, numNodes).
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param unorderedFeatures  Set of indices of unordered features.
+   * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def mixedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
       nodeIndex: Int,
       bins: Array[Array[Bin]],
-      unorderedFeatures: Set[Int]): Unit = {
-    // Iterate over all features.
-    val numFeatures = treePoint.binnedFeatures.size
+      unorderedFeatures: Set[Int],
+      instanceWeight: Double,
+      featuresForNode: Option[Array[Int]]): Unit = {
+    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+      // Use subsampled features
+      featuresForNode.get.size
+    } else {
+      // Use all features
+      agg.metadata.numFeatures
+    }
     val nodeOffset = agg.getNodeOffset(nodeIndex)
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
+    // Iterate over features.
+    var featureIndexIdx = 0
+    while (featureIndexIdx < numFeaturesPerNode) {
+      val featureIndex = if (featuresForNode.nonEmpty) {
+        featuresForNode.get.apply(featureIndexIdx)
+      } else {
+        featureIndexIdx
+      }
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
         val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
         // Update the left or right bin for each split.
-        val numSplits = agg.numSplits(featureIndex)
+        val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplits) {
           if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
-            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label)
+            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+              instanceWeight)
           } else {
-            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label)
+            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+              instanceWeight)
           }
           splitIndex += 1
         }
       } else {
         // Ordered feature
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label)
+        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
+          instanceWeight)
       }
-      featureIndex += 1
+      featureIndexIdx += 1
     }
   }
 
@@ -513,66 +395,77 @@ object DecisionTree extends Serializable with Logging {
    * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    *             each (node, feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-   * @return agg
+   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, numNodes).
+   * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def orderedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      nodeIndex: Int): Unit = {
+      nodeIndex: Int,
+      instanceWeight: Double,
+      featuresForNode: Option[Array[Int]]): Unit = {
     val label = treePoint.label
     val nodeOffset = agg.getNodeOffset(nodeIndex)
-    // Iterate over all features.
-    val numFeatures = agg.numFeatures
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      val binIndex = treePoint.binnedFeatures(featureIndex)
-      agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label)
-      featureIndex += 1
+    // Iterate over features.
+    if (featuresForNode.nonEmpty) {
+      // Use subsampled features
+      var featureIndexIdx = 0
+      while (featureIndexIdx < featuresForNode.get.size) {
+        val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
+        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+        featureIndexIdx += 1
+      }
+    } else {
+      // Use all features
+      val numFeatures = agg.metadata.numFeatures
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        val binIndex = treePoint.binnedFeatures(featureIndex)
+        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+        featureIndex += 1
+      }
     }
   }
 
   /**
-   * Returns an array of optimal splits for a group of nodes at a given level
+   * Given a group of nodes, this finds the best split for each node.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param metadata Learning and dataset metadata
-   * @param level Level of the tree
-   * @param topNode Root node of the tree (or invalid node when training first level).
+   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
+   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+   *                              where nodeIndexInfo stores the index in the group and the
+   *                              feature subsets (if using feature subsets).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
-   * @param numGroups total number of node groups at the current level. Default value is set to 1.
-   * @param groupIndex index of the node group being processed. Default value is set to 0.
-   * @return  (root, doneTraining) where:
-   *          root = Root node (which is newly created on the first iteration),
-   *          doneTraining = true if no more internal nodes were created.
+   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
+   *                   Updated with new non-leaf nodes which are created.
    */
-  private def findBestSplitsPerGroup(
-      input: RDD[TreePoint],
+  private[tree] def findBestSplits(
+      input: RDD[BaggedPoint[TreePoint]],
       metadata: DecisionTreeMetadata,
-      level: Int,
-      topNode: Node,
+      topNodes: Array[Node],
+      nodesForGroup: Map[Int, Array[Node]],
+      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
-      timer: TimeTracker,
-      numGroups: Int = 1,
-      groupIndex: Int = 0): (Node, Boolean) = {
+      nodeQueue: mutable.Queue[(Int, Node)],
+      timer: TimeTracker = new TimeTracker): Unit = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
      *
-     * *Level-wise training*
-     * We perform bin calculations for all nodes at the given level to avoid making multiple
-     * passes over the data. Thus, for a slightly increased computation and storage cost we save
-     * several iterations over the data especially at higher levels of the decision tree.
+     * *Group-wise training*
+     * We perform bin calculations for groups of nodes to reduce the number of
+     * passes over the data.  Each iteration requires more computation and storage,
+     * but saves several iterations over the data.
      *
      * *Bin-wise computation*
      * We use a bin-wise best split computation strategy instead of a straightforward best split
      * computation strategy. Instead of analyzing each sample for contribution to the left/right
      * child node impurity of every split, we first categorize each feature of a sample into a
-     * bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
-     * is ordered (read ordering for categorical variables in the findSplitsBins method),
-     * we exploit this structure to calculate aggregates for bins and then use these aggregates
+     * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
      * to calculate information gain for each split.
      *
      * *Aggregation over partitions*
@@ -582,42 +475,15 @@ object DecisionTree extends Serializable with Logging {
      * drastically reduce the communication overhead.
      */
 
-    // Common calculations for multiple nested methods:
-
-    // numNodes:  Number of nodes in this (level of tree, group),
-    //            where nodes at deeper (larger) levels may be divided into groups.
-    val numNodes = Node.maxNodesInLevel(level) / numGroups
+    // numNodes:  Number of nodes in this group
+    val numNodes = nodesForGroup.values.map(_.size).sum
     logDebug("numNodes = " + numNodes)
-
     logDebug("numFeatures = " + metadata.numFeatures)
     logDebug("numClasses = " + metadata.numClasses)
     logDebug("isMulticlass = " + metadata.isMulticlass)
     logDebug("isMulticlassWithCategoricalFeatures = " +
       metadata.isMulticlassWithCategoricalFeatures)
 
-    // shift when more than one group is used at deep tree level
-    val groupShift = numNodes * groupIndex
-
-    // Used for treePointToNodeIndex to get an index for this (level, group).
-    // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level.
-    // - groupShift corrects for groups in this level before the current group.
-    val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift
-
-    /**
-     * Find the node index for the given example.
-     * Nodes are indexed from 0 at the start of this (level, group).
-     * If the example does not reach this level, returns a value < 0.
-     */
-    def treePointToNodeIndex(treePoint: TreePoint): Int = {
-      if (level == 0) {
-        0
-      } else {
-        val globalNodeIndex =
-          predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
-        globalNodeIndex - globalNodeIndexOffset
-      }
-    }
-
     /**
      * Performs a sequential aggregation over a partition.
      *
@@ -626,21 +492,27 @@ object DecisionTree extends Serializable with Logging {
      *
      * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
      *             each (node, feature, bin).
-     * @param treePoint   Data point being aggregated.
+     * @param baggedPoint   Data point being aggregated.
      * @return  agg
      */
     def binSeqOp(
         agg: DTStatsAggregator,
-        treePoint: TreePoint): DTStatsAggregator = {
-      val nodeIndex = treePointToNodeIndex(treePoint)
-      // If the example does not reach this level, then nodeIndex < 0.
-      // If the example reaches this level but is handled in a different group,
-      //  then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
-      if (nodeIndex >= 0 && nodeIndex < numNodes) {
-        if (metadata.unorderedFeatures.isEmpty) {
-          orderedBinSeqOp(agg, treePoint, nodeIndex)
-        } else {
-          mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures)
+        baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+        val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
+          bins, metadata.unorderedFeatures)
+        val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
+        // If the example does not reach a node in this group, then nodeInfo = null.
+        if (nodeInfo != null) {
+          val aggNodeIndex = nodeInfo.nodeIndexInGroup
+          val featuresForNode = nodeInfo.featureSubset
+          val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+          if (metadata.unorderedFeatures.isEmpty) {
+            orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+          } else {
+            mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+              instanceWeight, featuresForNode)
+          }
         }
       }
       agg
@@ -649,71 +521,62 @@ object DecisionTree extends Serializable with Logging {
     // Calculate bin aggregates.
     timer.start("aggregation")
     val binAggregates: DTStatsAggregator = {
-      val initAgg = new DTStatsAggregator(metadata, numNodes)
+      val initAgg = if (metadata.subsamplingFeatures) {
+        new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
+      } else {
+        new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+      }
       input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
     }
     timer.stop("aggregation")
 
-    // Calculate best splits for all nodes at a given level
+    // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")
-    // On the first iteration, we need to get and return the newly created root node.
-    var newTopNode: Node = topNode
-
-    // Iterate over all nodes at this level
-    var nodeIndex = 0
-    var internalNodeCount = 0
-    while (nodeIndex < numNodes) {
-      val (split: Split, stats: InformationGainStats, predict: Predict) =
-        binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
-      logDebug("best split = " + split)
-
-      val globalNodeIndex = globalNodeIndexOffset + nodeIndex
 
-      // Extract info for this node at the current level.
-      val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
-      val node =
-        new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
-      logDebug("Node = " + node)
-
-      if (!isLeaf) {
-        internalNodeCount += 1
-      }
-      if (level == 0) {
-        newTopNode = node
-      } else {
-        // Set parent.
-        val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
-        if (Node.isLeftChild(globalNodeIndex)) {
-          parentNode.leftNode = Some(node)
-        } else {
-          parentNode.rightNode = Some(node)
+    // Iterate over all nodes in this group.
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        val nodeIndex = node.id
+        val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+        val aggNodeIndex = nodeInfo.nodeIndexInGroup
+        val featuresForNode = nodeInfo.featureSubset
+        val (split: Split, stats: InformationGainStats, predict: Predict) =
+          binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+        logDebug("best split = " + split)
+
+        // Extract info for this node.  Create children if not leaf.
+        val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
+        assert(node.id == nodeIndex)
+        node.predict = predict.predict
+        node.isLeaf = isLeaf
+        node.stats = Some(stats)
+        logDebug("Node = " + node)
+
+        if (!isLeaf) {
+          node.split = Some(split)
+          node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
+          node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
+          nodeQueue.enqueue((treeIndex, node.leftNode.get))
+          nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          logDebug("leftChildIndex = " + node.leftNode.get.id +
+            ", impurity = " + stats.leftImpurity)
+          logDebug("rightChildIndex = " + node.rightNode.get.id +
+            ", impurity = " + stats.rightImpurity)
         }
       }
-      if (level < metadata.maxDepth) {
-        logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
-          ", impurity = " + stats.leftImpurity)
-        logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
-          ", impurity = " + stats.rightImpurity)
-      }
-
-      nodeIndex += 1
     }
     timer.stop("chooseSplits")
-
-    val doneTraining = internalNodeCount == 0
-    (newTopNode, doneTraining)
   }
 
   /**
    * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
    * @param leftImpurityCalculator left node aggregates for this (feature, split)
    * @param rightImpurityCalculator right node aggregate for this (feature, split)
-   * @return information gain and statistics for all splits
+   * @return information gain and statistics for split
    */
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
@@ -753,7 +616,7 @@ object DecisionTree extends Serializable with Logging {
    * Calculate predict value for current node, given stats of any split.
    * Note that this function is called only once for each node.
    * @param leftImpurityCalculator left node aggregates for a split
-   * @param rightImpurityCalculator right node aggregates for a node
+   * @param rightImpurityCalculator right node aggregates for a split
    * @return predict value for current node
    */
   private def calculatePredict(
@@ -770,27 +633,33 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Find the best split for a node.
    * @param binAggregates Bin statistics.
-   * @param nodeIndex Index for node to split in this (level, group).
-   * @return tuple for best split: (Split, information gain)
+   * @param nodeIndex Index into aggregates for node to split in this group.
+   * @return tuple for best split: (Split, information gain, prediction at node)
    */
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       nodeIndex: Int,
-      level: Int,
-      metadata: DecisionTreeMetadata,
-      splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
+      splits: Array[Array[Split]],
+      featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+
+    val metadata: DecisionTreeMetadata = binAggregates.metadata
 
     // calculate predict only once
     var predict: Option[Predict] = None
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
+    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+      val featureIndex = if (featuresForNode.nonEmpty) {
+        featuresForNode.get.apply(featureIndexIdx)
+      } else {
+        featureIndexIdx
+      }
       val numSplits = metadata.numSplits(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
         // Afterwards, binAggregates for a bin is the sum of aggregates for
         // that bin + all preceding bins.
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
         var splitIndex = 0
         while (splitIndex < numSplits) {
           binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
@@ -803,26 +672,26 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else if (metadata.isUnordered(featureIndex)) {
         // Unordered categorical feature
         val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else {
         // Ordered categorical feature
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
         val numBins = metadata.numBins(featureIndex)
 
         /* Each bin is one category (feature value).
@@ -887,7 +756,7 @@ object DecisionTree extends Serializable with Logging {
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =
@@ -904,18 +773,6 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Get the number of values to be stored per node in the bin aggregates.
-   */
-  private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
-    val totalBins = metadata.numBins.map(_.toLong).sum
-    if (metadata.isClassification) {
-      metadata.numClasses * totalBins
-    } else {
-      3 * totalBins
-    }
-  }
-
-  /**
    * Returns splits and bins for decision tree calculation.
    * Continuous and categorical features are handled differently.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
new file mode 100644
index 0000000..7fa7725
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impurity.Impurities
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+/**
+ * :: Experimental ::
+ * A class which implements a random forest learning algorithm for classification and regression.
+ * It supports both continuous and categorical features.
+ *
+ * The settings for featureSubsetStrategy are based on the following references:
+ *  - log2: tested in Breiman (2001)
+ *  - sqrt: recommended by Breiman manual for random forests
+ *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ *    package.
+ * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf  Breiman (2001)]]
+ * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf  Breiman manual for
+ *     random forests]]
+ *
+ * @param strategy The configuration parameters for the random forest algorithm which specify
+ *                 the type of algorithm (classification, regression, etc.), feature type
+ *                 (continuous, categorical), depth of the tree, quantile calculation strategy,
+ *                 etc.
+ * @param numTrees If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ *                              If "auto" is set, this parameter is set based on numTrees:
+ *                                if numTrees == 1, set to "all";
+ *                                if numTrees > 1 (forest) set to "sqrt" for classification and
+ *                                  to "onethird" for regression.
+ * @param seed  Random seed for bootstrapping and choosing feature subsets.
+ */
+@Experimental
+private class RandomForest (
+    private val strategy: Strategy,
+    private val numTrees: Int,
+    featureSubsetStrategy: String,
+    private val seed: Int)
+  extends Serializable with Logging {
+
+  strategy.assertValid()
+  require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
+  require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+    s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
+    s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+
+  /**
+   * Method to train a random forest model over an RDD
+   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @return RandomForestModel that can be used for prediction
+   */
+  def train(input: RDD[LabeledPoint]): RandomForestModel = {
+
+    val timer = new TimeTracker()
+
+    timer.start("total")
+
+    timer.start("init")
+
+    val retaggedInput = input.retag(classOf[LabeledPoint])
+    val metadata =
+      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+    logDebug("algo = " + strategy.algo)
+    logDebug("numTrees = " + numTrees)
+    logDebug("seed = " + seed)
+    logDebug("maxBins = " + metadata.maxBins)
+    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
+    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+
+    // Find the splits and the corresponding bins (interval between the splits) using a sample
+    // of the input data.
+    timer.start("findSplitsBins")
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
+    timer.stop("findSplitsBins")
+    logDebug("numBins: feature: number of bins")
+    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+      }.mkString("\n"))
+
+    // Bin feature values (TreePoint representation).
+    // The if/else is parenthesized so that persist() applies to BOTH branches, not only `else`.
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
+    val baggedInput = (if (numTrees > 1) {
+      BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
+    } else {
+      BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+    }).persist(StorageLevel.MEMORY_AND_DISK)
+
+    // Maximum depth of each tree; values above 30 are rejected below.
+    val maxDepth = strategy.maxDepth
+    require(maxDepth <= 30,
+      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+    // Max memory usage for aggregates
+    // TODO: Calculate memory usage more precisely.
+    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+    val maxMemoryPerNode = {
+      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
+        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
+          .take(metadata.numFeaturesPerNode).map(_._2))
+      } else {
+        None
+      }
+      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+    }
+    require(maxMemoryPerNode <= maxMemoryUsage,  // message rounds the minimum UP to whole MB
+      s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
+      " which is too small for the given features." +
+      s"  Minimum value = ${(maxMemoryPerNode + 1024L * 1024L - 1L) / (1024L * 1024L)}")
+
+    timer.stop("init")
+
+    /*
+     * The main idea here is to perform group-wise training of the decision tree nodes thus
+     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+     * in lower levels).
+     */
+
+    // FIFO queue of nodes to train: (treeIndex, node)
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+
+    val rng = new scala.util.Random()
+    rng.setSeed(seed)
+
+    // Allocate one root node (node index 1) per tree, and queue all of them.
+    val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
+    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+
+    while (nodeQueue.nonEmpty) {
+      // Collect some nodes to split, and choose features for each node (if subsampling).
+      // Each group of nodes may come from one or multiple trees, and at multiple levels.
+      val (nodesForGroup, treeToNodeToIndexInfo) =
+        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+      // Sanity check (should never occur):
+      assert(nodesForGroup.size > 0,
+        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
+
+      // Choose node splits, and enqueue new nodes as needed.
+      timer.start("findBestSplits")
+      DecisionTree.findBestSplits(baggedInput,
+        metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+      timer.stop("findBestSplits")
+    }
+
+    timer.stop("total")
+
+    logInfo("Internal timing for DecisionTree:")
+    logInfo(s"$timer")
+
+    val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
+    RandomForestModel.build(trees)
+  }
+
+}
+
+object RandomForest extends Serializable with Logging {
+
+  /**
+   * Method to train a random forest model for binary or multiclass classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param strategy Parameters for training each tree in the forest.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest) set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Int): RandomForestModel = {
+    require(strategy.algo == Classification,
+      s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
+    // Construct the trainer and fit it on the input in a single expression.
+    new RandomForest(strategy, numTrees, featureSubsetStrategy, seed).train(input)
+  }
+
+  /**
+   * Method to train a random forest model for binary or multiclass classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param numClassesForClassification number of classes for classification.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest) set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "gini" (recommended) or "entropy".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                  (suggested value: 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                 (suggested value: 100)
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: Map[Int, Int],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
+    // Pack the individual settings into a Strategy, then defer to the Strategy-based overload.
+    val forestStrategy = new Strategy(Classification, Impurities.fromString(impurity), maxDepth,
+      numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
+    trainClassifier(input, forestStrategy, numTrees, featureSubsetStrategy, seed)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]]
+   */
+  def trainClassifier(
+      input: JavaRDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int): RandomForestModel = {
+    val arities = categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap
+    trainClassifier(input.rdd, numClassesForClassification, arities,
+      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+  }
+
+  /**
+   * Method to train a random forest model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param strategy Parameters for training each tree in the forest.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest) set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Int): RandomForestModel = {
+    require(strategy.algo == Regression,
+      s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
+    // Construct the trainer and fit it on the input in a single expression.
+    new RandomForest(strategy, numTrees, featureSubsetStrategy, seed).train(input)
+  }
+
+  /**
+   * Method to train a random forest model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest) set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "variance".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                  (suggested value: 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                 (suggested value: 100)
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      categoricalFeaturesInfo: Map[Int, Int],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
+    // Pack the settings into a Strategy (class-count argument given as 0), then delegate.
+    val forestStrategy = new Strategy(Regression, Impurities.fromString(impurity), maxDepth,
+      0, maxBins, Sort, categoricalFeaturesInfo)
+    trainRegressor(input, forestStrategy, numTrees, featureSubsetStrategy, seed)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]]
+   */
+  def trainRegressor(
+      input: JavaRDD[LabeledPoint],
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int): RandomForestModel = {
+    val arities = categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap
+    trainRegressor(input.rdd, arities,
+      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+  }
+
+  /**
+   * List of supported feature subset sampling strategies.
+   */
+  val supportedFeatureSubsetStrategies: Array[String] =
+    Array("auto", "all", "sqrt", "log2", "onethird")
+
+  private[tree] class NodeIndexInfo(
+      val nodeIndexInGroup: Int,
+      val featureSubset: Option[Array[Int]]) extends Serializable
+
+  /**
+   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+   * will be needed; this allows an adaptive number of nodes since different nodes may require
+   * different amounts of memory (if featureSubsetStrategy is not "all").
+   *
+   * @param nodeQueue  Queue of nodes to split.
+   * @param maxMemoryUsage  Bound on size of aggregate statistics.
+   * @return  (nodesForGroup, treeToNodeToIndexInfo).
+   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+   *          treeToNodeToIndexInfo holds the indices of selected features for each node:
+   *            treeIndex --> (global) node index --> (node index in group, feature indices).
+   *          The (global) node index is the index in the tree; the node index in group is the
+   *           index in [0, numNodesInGroup) of the node in this group.
+   *          The feature indices are None if not subsampling features.
+   */
+  private[tree] def selectNodesToSplit(
+      nodeQueue: mutable.Queue[(Int, Node)],
+      maxMemoryUsage: Long,
+      metadata: DecisionTreeMetadata,
+      rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+    // Mutable accumulators for this group, both keyed by treeIndex:
+    //  nodesForGroup(treeIndex) = nodes to split
+    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
+    val mutableTreeToNodeToIndexInfo =
+      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+    var memUsage: Long = 0L
+    var numNodesInGroup = 0
+    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
+      val (treeIndex, node) = nodeQueue.head  // peek only; dequeued below if it fits
+      // Draw a fresh random feature subset for this node (None when not subsampling).
+      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+        // TODO: Use more efficient subsampling?  (use selection-and-rejection or reservoir)
+        Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
+          .take(metadata.numFeaturesPerNode).toArray)
+      } else {
+        None
+      }
+      // Admit the node only if its aggregates (value count * 8L) fit in the remaining budget.
+      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
+        nodeQueue.dequeue()
+        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
+        mutableTreeToNodeToIndexInfo
+          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+          = new NodeIndexInfo(numNodesInGroup, featureSubset)
+      }
+      numNodesInGroup += 1
+      memUsage += nodeMemUsage  // added even on rejection: this is what ends the loop
+    }
+    // Convert mutable maps to immutable ones.
+    val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
+    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+    (nodesForGroup, treeToNodeToIndexInfo)
+  }
+
+  /**
+   * Get the number of values to be stored for this node in the bin aggregates.
+   * @param featureSubset  Indices of features which may be split at this node.
+   *                       If None, then use all features.
+   */
+  private[tree] def aggregateSizeForNode(
+      metadata: DecisionTreeMetadata,
+      featureSubset: Option[Array[Int]]): Long = {
+    // Total number of bins over the features considered at this node.
+    val binCount: Long = featureSubset match {
+      case Some(indices) =>
+        indices.map(i => metadata.numBins(i).toLong).sum
+      case None =>
+        metadata.numBins.map(_.toLong).sum
+    }
+    // Values stored per bin: numClasses for classification, otherwise 3.
+    val valuesPerBin = if (metadata.isClassification) metadata.numClasses else 3
+    valuesPerBin * binCount
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
new file mode 100644
index 0000000..937c8a2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+/**
+ * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
+ * particularly for bagging (e.g., for random forests).
+ *
+ * This holds one instance, as well as an array of weights which represent the (weighted)
+ * number of times which this instance appears in each subsample.
+ * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
+ * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
+ *
+ * @param datum  Data instance
+ * @param subsampleWeights  Weight of this instance in each subsampled dataset.
+ *
+ * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
+ *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
+ */
+private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
+  extends Serializable
+
+private[tree] object BaggedPoint {
+
+  /**
+   * Convert an input dataset into its BaggedPoint representation,
+   * choosing subsample counts for each instance.
+   * Each subsample has the same number of instances as the original dataset,
+   * and is created by subsampling with replacement.
+   * @param input     Input dataset.
+   * @param numSubsamples  Number of subsamples of this RDD to take.
+   * @param seed   Random seed.
+   * @return  BaggedPoint dataset representation
+   */
+  def convertToBaggedRDD[Datum](
+      input: RDD[Datum],
+      numSubsamples: Int,
+      seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+    input.mapPartitionsWithIndex { (partitionId, points) =>
+      // TODO: Support different sampling rates, and sampling without replacement.
+      // Seed each partition's sampler with seed + partitionId + 1 so that runs are reproducible.
+      val sampler = new Poisson(1.0, new DRand(seed + partitionId + 1))
+      points.map { point =>
+        val weights = new Array[Double](numSubsamples)
+        var i = 0
+        while (i < numSubsamples) {
+          weights(i) = sampler.nextInt()
+          i += 1
+        }
+        new BaggedPoint(point, weights)
+      }
+    }
+  }
+
+  def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+    input.map { point => new BaggedPoint(point, Array(1.0)) }  // one subsample, weight 1.0
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 61a9424..d49df7a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.mllib.tree.impl
 
+import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.impurity._
 
 /**
  * DecisionTree statistics aggregator.
  * This holds a flat array of statistics for a set of (nodes, features, bins)
  * and helps with indexing.
+ * This class is abstract to support learning with and without feature subsampling.
  */
-private[tree] class DTStatsAggregator(
-    val metadata: DecisionTreeMetadata,
-    val numNodes: Int) extends Serializable {
+private[tree] abstract class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata) extends Serializable {
 
   /**
    * [[ImpurityAggregator]] instance specifying the impurity type.
@@ -43,18 +44,6 @@ private[tree] class DTStatsAggregator(
    */
   val statsSize: Int = impurityAggregator.statsSize
 
-  val numFeatures: Int = metadata.numFeatures
-
-  /**
-   * Number of bins for each feature.  This is indexed by the feature index.
-   */
-  val numBins: Array[Int] = metadata.numBins
-
-  /**
-   * Number of splits for the given feature.
-   */
-  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
-
   /**
    * Indicator for each feature of whether that feature is an unordered feature.
    * TODO: Is Array[Boolean] any faster?
@@ -62,30 +51,14 @@ private[tree] class DTStatsAggregator(
   def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
 
   /**
-   * Offset for each feature for calculating indices into the [[allStats]] array.
-   */
-  private val featureOffsets: Array[Int] = {
-    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
-  }
-
-  /**
-   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
-   */
-  private val nodeStride: Int = featureOffsets.last
-
-  /**
    * Total number of elements stored in this aggregator.
    */
-  val allStatsSize: Int = numNodes * nodeStride
+  def allStatsSize: Int
 
   /**
-   * Flat array of elements.
-   * Index for start of stats for a (node, feature, bin) is:
-   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
-   *       and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
+   * Get flat array of elements stored in this aggregator.
    */
-  val allStats: Array[Double] = new Array[Double](allStatsSize)
+  protected def allStats: Array[Double]
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
@@ -102,36 +75,39 @@ private[tree] class DTStatsAggregator(
   /**
    * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
    */
-  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
-    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label)
+  def update(
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
   /**
    * Pre-compute node offset for use with [[nodeUpdate]].
    */
-  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+  def getNodeOffset(nodeIndex: Int): Int
 
   /**
    * Faster version of [[update]].
    * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
    * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
    */
-  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
-    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label)
-  }
+  def nodeUpdate(
+      nodeOffset: Int,
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit
 
   /**
    * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
    * For ordered features only.
    */
-  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
-    require(!isUnordered(featureIndex),
-      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
-      s" for unordered feature $featureIndex.")
-    nodeIndex * nodeStride + featureOffsets(featureIndex)
-  }
+  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
 
   /**
    * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
@@ -140,9 +116,9 @@ private[tree] class DTStatsAggregator(
   def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
     require(isUnordered(featureIndex),
       s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
-      s" but was called for ordered feature $featureIndex.")
-    val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
+        s" but was called for ordered feature $featureIndex.")
+    val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
+    (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
   }
 
   /**
@@ -154,8 +130,13 @@ private[tree] class DTStatsAggregator(
    *                           (node, feature, left/right child) offset from
    *                           [[getLeftRightNodeFeatureOffsets]].
    */
-  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
-    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
+  def nodeFeatureUpdate(
+      nodeFeatureOffset: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
+      instanceWeight)
   }
 
   /**
@@ -189,7 +170,139 @@ private[tree] class DTStatsAggregator(
     }
     this
   }
+}
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ *
+ * This instance of [[DTStatsAggregator]] is used when not subsampling features.
+ *
+ * @param numNodes  Number of nodes to collect statistics for.
+ */
+private[tree] class DTStatsAggregatorFixedFeatures(
+    metadata: DecisionTreeMetadata,
+    numNodes: Int) extends DTStatsAggregator(metadata) {
+
+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] array.
+   * Mapping: featureIndex --> offset
+   */
+  private val featureOffsets: Array[Int] = {
+    metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+  }
+
+  /**
+   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
+   */
+  private val nodeStride: Int = featureOffsets.last
 
+  override val allStatsSize: Int = numNodes * nodeStride
+
+  /**
+   * Flat array of elements.
+   * Index for start of stats for a (node, feature, bin) is:
+   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+   * Note: For unordered features, the left child stats precede the right child stats
+   *       in the binIndex order.
+   */
+  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+  override def nodeUpdate(
+      nodeOffset: Int,
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
+  }
+
+  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+    nodeIndex * nodeStride + featureOffsets(featureIndex)
+  }
+}
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ *
+ * This instance of [[DTStatsAggregator]] is used when subsampling features.
+ *
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ *                              where nodeIndexInfo stores the index in the group and the
+ *                              feature subsets (if using feature subsets).
+ */
+private[tree] class DTStatsAggregatorSubsampledFeatures(
+    metadata: DecisionTreeMetadata,
+    treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
+
+  /**
+   * For each node, offset for each feature for calculating indices into the [[allStats]] array.
+   * Mapping: nodeIndex --> featureIndex --> offset
+   */
+  private val featureOffsets: Array[Array[Int]] = {
+    val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
+    val offsets = new Array[Array[Int]](numNodes)
+    treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
+      nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
+        offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
+          .scanLeft(0)((total, nBins) => total + statsSize * nBins)
+      }
+    }
+    offsets
+  }
+
+  /**
+   * Offset of each node's stats block within the [[allStats]] array (nodeIndex --> offset).
+   */
+  protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
+
+  override val allStatsSize: Int = nodeOffsets.last
+
+  /**
+   * Flat array of elements.
+   * Index for start of stats for a (node, feature, bin) is:
+   *   nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
+   * Note: For unordered features, the left child stats precede the right child stats
+   *       in the binIndex order.
+   */
+  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
+
+  /**
+   * Faster version of [[update]].
+   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
+   * @param featureIndex  Index into this node's featureSubset (see NodeIndexInfo).
+   *                      Note: This is NOT the original feature index.
+   */
+  override def nodeUpdate(
+      nodeOffset: Int,
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
+  }
+
+  /**
+   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * For ordered features only.
+   * @param featureIndex  Index into this node's featureSubset (see NodeIndexInfo).
+   *                      Note: This is NOT the original feature index.
+   */
+  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+    nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
+  }
 }
 
 private[tree] object DTStatsAggregator extends Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org