You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/09/23 05:27:36 UTC

spark git commit: [SPARK-16719][ML] Random Forests should communicate fewer trees on each iteration

Repository: spark
Updated Branches:
  refs/heads/master a4aeb7677 -> 947b8c6e3


[SPARK-16719][ML] Random Forests should communicate fewer trees on each iteration

## What changes were proposed in this pull request?

RandomForest currently sends the entire forest to each worker on each iteration. This is because (a) the node queue is FIFO and (b) the closure references the entire array of trees (topNodes). (a) causes RFs to handle splits in many trees, especially early on in learning. (b) sends all trees explicitly.

This PR:
(a) Change the RF node queue to be FILO (a stack), so that RFs tend to focus on 1 or a few trees before focusing on others.
(b) Change topNodes to pass only the trees required on that iteration.

## How was this patch tested?

Unit tests:
* Existing tests for correctness of tree learning
* Manually modifying code and running tests to verify that a small number of trees are communicated on each iteration
  * This last item is hard to test via unit tests given the current APIs.

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #14359 from jkbradley/rfs-fewer-trees.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/947b8c6e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/947b8c6e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/947b8c6e

Branch: refs/heads/master
Commit: 947b8c6e3acd671d501f0ed6c077aac8e51ccede
Parents: a4aeb76
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Thu Sep 22 22:27:28 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Sep 22 22:27:28 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/tree/impl/RandomForest.scala       | 54 ++++++++++++--------
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 26 +++++-----
 2 files changed, 46 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/947b8c6e/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42..0b7ad92 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -51,7 +51,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
  * findSplits() method during initialization, after which each continuous feature becomes
  * an ordered discretized feature with at most maxBins possible values.
  *
- * The main loop in the algorithm operates on a queue of nodes (nodeQueue).  These nodes
+ * The main loop in the algorithm operates on a stack of nodes (nodeStack).  These nodes
  * lie at the periphery of the tree being trained.  If multiple trees are being trained at once,
  * then this queue contains nodes from all of them.  Each iteration works roughly as follows:
  *   On the master node:
@@ -161,31 +161,42 @@ private[spark] object RandomForest extends Logging {
       None
     }
 
-    // FIFO queue of nodes to train: (treeIndex, node)
-    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+    /*
+      Stack of nodes to train: (treeIndex, node)
+      The reason this is a stack is that we train many trees at once, but we want to focus on
+      completing trees, rather than training all simultaneously.  If we are splitting nodes from
+      1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
+      training the same tree in the next iteration.  This focus allows us to send fewer trees to
+      workers on each iteration; see topNodesForGroup below.
+     */
+    val nodeStack = new mutable.Stack[(Int, LearningNode)]
 
     val rng = new Random()
     rng.setSeed(seed)
 
     // Allocate and queue root nodes.
     val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
-    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+    Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))
 
     timer.stop("init")
 
-    while (nodeQueue.nonEmpty) {
+    while (nodeStack.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
       // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
-        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
 
+      // Only send trees to worker if they contain nodes being split this iteration.
+      val topNodesForGroup: Map[Int, LearningNode] =
+        nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
-      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
+      RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
+        treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
       timer.stop("findBestSplits")
     }
 
@@ -334,13 +345,14 @@ private[spark] object RandomForest extends Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
    * @param metadata Learning and dataset metadata
-   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
+   * @param topNodesForGroup For each tree in group, tree index -> root node.
+   *                         Used for matching instances with nodes.
    * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
    * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
    *                              where nodeIndexInfo stores the index in the group and the
    *                              feature subsets (if using feature subsets).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
-   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
+   * @param nodeStack  Stack of nodes to split, with values (treeIndex, node).
    *                   Updated with new non-leaf nodes which are created.
    * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
    *                    each value in the array is the data point's node Id
@@ -351,11 +363,11 @@ private[spark] object RandomForest extends Logging {
   private[tree] def findBestSplits(
       input: RDD[BaggedPoint[TreePoint]],
       metadata: DecisionTreeMetadata,
-      topNodes: Array[LearningNode],
+      topNodesForGroup: Map[Int, LearningNode],
       nodesForGroup: Map[Int, Array[LearningNode]],
       treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
-      nodeQueue: mutable.Queue[(Int, LearningNode)],
+      nodeStack: mutable.Stack[(Int, LearningNode)],
       timer: TimeTracker = new TimeTracker,
       nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
@@ -437,7 +449,8 @@ private[spark] object RandomForest extends Logging {
         agg: Array[DTStatsAggregator],
         baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
-        val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
+        val nodeIndex =
+          topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
         nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
       }
       agg
@@ -593,10 +606,10 @@ private[spark] object RandomForest extends Logging {
 
           // enqueue left child and right child if they are not leaves
           if (!leftChildIsLeaf) {
-            nodeQueue.enqueue((treeIndex, node.leftChild.get))
+            nodeStack.push((treeIndex, node.leftChild.get))
           }
           if (!rightChildIsLeaf) {
-            nodeQueue.enqueue((treeIndex, node.rightChild.get))
+            nodeStack.push((treeIndex, node.rightChild.get))
           }
 
           logDebug("leftChildIndex = " + node.leftChild.get.id +
@@ -1029,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
    * will be needed; this allows an adaptive number of nodes since different nodes may require
    * different amounts of memory (if featureSubsetStrategy is not "all").
    *
-   * @param nodeQueue  Queue of nodes to split.
+   * @param nodeStack  Stack of nodes to split.
    * @param maxMemoryUsage  Bound on size of aggregate statistics.
    * @return  (nodesForGroup, treeToNodeToIndexInfo).
    *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
@@ -1041,7 +1054,7 @@ private[spark] object RandomForest extends Logging {
    *          The feature indices are None if not subsampling features.
    */
   private[tree] def selectNodesToSplit(
-      nodeQueue: mutable.Queue[(Int, LearningNode)],
+      nodeStack: mutable.Stack[(Int, LearningNode)],
       maxMemoryUsage: Long,
       metadata: DecisionTreeMetadata,
       rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
@@ -1054,8 +1067,8 @@ private[spark] object RandomForest extends Logging {
     var numNodesInGroup = 0
     // If maxMemoryInMB is set very small, we want to still try to split 1 node,
     // so we allow one iteration if memUsage == 0.
-    while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
-      val (treeIndex, node) = nodeQueue.head
+    while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
+      val (treeIndex, node) = nodeStack.top
       // Choose subset of features for node (if subsampling).
       val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
         Some(SamplingUtils.reservoirSampleAndCount(Range(0,
@@ -1066,7 +1079,7 @@ private[spark] object RandomForest extends Logging {
       // Check if enough memory remains to add this node to the group.
       val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
       if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
-        nodeQueue.dequeue()
+        nodeStack.pop()
         mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
           node
         mutableTreeToNodeToIndexInfo
@@ -1109,5 +1122,4 @@ private[spark] object RandomForest extends Logging {
       3 * totalBins
     }
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/947b8c6e/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dcc2f30..79b19ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
+  Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.collection.OpenHashMap
@@ -239,12 +240,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val treeToNodeToIndexInfo = Map((0, Map(
       (topNode.id, new RandomForest.NodeIndexInfo(0, None))
     )))
-    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
-    RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
-      nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+    val nodeStack = new mutable.Stack[(Int, LearningNode)]
+    RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
 
     // don't enqueue leaf nodes into node queue
-    assert(nodeQueue.isEmpty)
+    assert(nodeStack.isEmpty)
 
     // set impurity and predict for topNode
     assert(topNode.stats !== null)
@@ -281,12 +282,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val treeToNodeToIndexInfo = Map((0, Map(
       (topNode.id, new RandomForest.NodeIndexInfo(0, None))
     )))
-    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
-    RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
-      nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+    val nodeStack = new mutable.Stack[(Int, LearningNode)]
+    RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
 
     // don't enqueue a node into node queue if its impurity is 0.0
-    assert(nodeQueue.isEmpty)
+    assert(nodeStack.isEmpty)
 
     // set impurity and predict for topNode
     assert(topNode.stats !== null)
@@ -393,16 +394,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         val failString = s"Failed on test with:" +
           s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
           s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
-        val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+        val nodeStack = new mutable.Stack[(Int, LearningNode)]
         val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
         Range(0, numTrees).foreach { treeIndex =>
           topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
-          nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+          nodeStack.push((treeIndex, topNodes(treeIndex)))
         }
         val rng = new scala.util.Random(seed = seed)
         val (nodesForGroup: Map[Int, Array[LearningNode]],
         treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
-          RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+          RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
 
         assert(nodesForGroup.size === numTrees, failString)
         assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
@@ -546,7 +547,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
-
 }
 
 private object RandomForestSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org