You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/11/02 00:58:35 UTC

git commit: [SPARK-3161][MLLIB] Adding a node Id caching mechanism for training deci...

Repository: spark
Updated Branches:
  refs/heads/master d8176b1c2 -> 56f2c61cd


[SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.

jkbradley mengxr chouqin Please review this.

Author: Sung Chung <sc...@alpinenow.com>

Closes #2868 from codedeft/SPARK-3161 and squashes the following commits:

5f5a156 [Sung Chung] [SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56f2c61c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56f2c61c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56f2c61c

Branch: refs/heads/master
Commit: 56f2c61cde3f5d906c2a58e9af1a661222f2c679
Parents: d8176b1
Author: Sung Chung <sc...@alpinenow.com>
Authored: Sat Nov 1 16:58:26 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sat Nov 1 16:58:26 2014 -0700

----------------------------------------------------------------------
 .../examples/mllib/DecisionTreeRunner.scala     |  25 ++-
 .../apache/spark/mllib/tree/DecisionTree.scala  | 114 +++++++++--
 .../apache/spark/mllib/tree/RandomForest.scala  |  22 +-
 .../mllib/tree/configuration/Strategy.scala     |  12 +-
 .../spark/mllib/tree/impl/NodeIdCache.scala     | 204 +++++++++++++++++++
 .../spark/mllib/tree/RandomForestSuite.scala    |  69 +++++--
 6 files changed, 405 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56f2c61c/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index f987303..49751a3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
       minInfoGain: Double = 0.0,
       numTrees: Int = 1,
       featureSubsetStrategy: String = "auto",
-      fracTest: Double = 0.2) extends AbstractParams[Params]
+      fracTest: Double = 0.2,
+      useNodeIdCache: Boolean = false,
+      checkpointDir: Option[String] = None,
+      checkpointInterval: Int = 10) extends AbstractParams[Params]
 
   def main(args: Array[String]) {
     val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
         .text(s"fraction of data to hold out for testing.  If given option testInput, " +
           s"this option is ignored. default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
+      opt[Boolean]("useNodeIdCache")
+        .text(s"whether to use node Id cache during training, " +
+          s"default: ${defaultParams.useNodeIdCache}")
+        .action((x, c) => c.copy(useNodeIdCache = x))
+      opt[String]("checkpointDir")
+        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+         s"default: ${defaultParams.checkpointDir match {
+           case Some(strVal) => strVal
+           case None => "None"
+         }}")
+        .action((x, c) => c.copy(checkpointDir = Some(x)))
+      opt[Int]("checkpointInterval")
+        .text(s"how often to checkpoint the node Id cache, " +
+         s"default: ${defaultParams.checkpointInterval}")
+        .action((x, c) => c.copy(checkpointInterval = x))
       opt[String]("testInput")
         .text(s"input path to test dataset.  If given, option fracTest is ignored." +
           s" default: ${defaultParams.testInput}")
@@ -236,7 +254,10 @@ object DecisionTreeRunner {
           maxBins = params.maxBins,
           numClassesForClassification = numClasses,
           minInstancesPerNode = params.minInstancesPerNode,
-          minInfoGain = params.minInfoGain)
+          minInfoGain = params.minInfoGain,
+          useNodeIdCache = params.useNodeIdCache,
+          checkpointDir = params.checkpointDir,
+          checkpointInterval = params.checkpointInterval)
     if (params.numTrees == 1) {
       val startTime = System.nanoTime()
       val model = DecisionTree.train(training, strategy)

http://git-wip-us.apache.org/repos/asf/spark/blob/56f2c61c/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 752ed59..78acc17 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
    *                   Updated with new non-leaf nodes which are created.
+   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+   *                    each value in the array is the data point's node Id
+   *                    for a corresponding tree. This is used to prevent the need
+   *                    to pass the entire tree to the executors during
+   *                    the node stat aggregation phase.
    */
   private[tree] def findBestSplits(
       input: RDD[BaggedPoint[TreePoint]],
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       nodeQueue: mutable.Queue[(Int, Node)],
-      timer: TimeTracker = new TimeTracker): Unit = {
+      timer: TimeTracker = new TimeTracker,
+      nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
     logDebug("isMulticlass = " + metadata.isMulticlass)
     logDebug("isMulticlassWithCategoricalFeatures = " +
       metadata.isMulticlassWithCategoricalFeatures)
+    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+    /**
+     * Performs a sequential aggregation over a partition for a particular tree and node.
+     *
+     * For each feature, the aggregate sufficient statistics are updated for the relevant
+     * bins.
+     *
+     * @param treeIndex Index of the tree that we want to perform aggregation for.
+     * @param nodeInfo The node info for the tree node.
+     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+     *            for each (node, feature, bin).
+     * @param baggedPoint Data point being aggregated.
+     */
+    def nodeBinSeqOp(
+        treeIndex: Int,
+        nodeInfo: RandomForest.NodeIndexInfo,
+        agg: Array[DTStatsAggregator],
+        baggedPoint: BaggedPoint[TreePoint]): Unit = {
+      if (nodeInfo != null) {
+        val aggNodeIndex = nodeInfo.nodeIndexInGroup
+        val featuresForNode = nodeInfo.featureSubset
+        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+        if (metadata.unorderedFeatures.isEmpty) {
+          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+        } else {
+          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
+            instanceWeight, featuresForNode)
+        }
+      }
+    }
 
     /**
      * Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
         val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
           bins, metadata.unorderedFeatures)
-        val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
-        // If the example does not reach a node in this group, then nodeIndex = null.
-        if (nodeInfo != null) {
-          val aggNodeIndex = nodeInfo.nodeIndexInGroup
-          val featuresForNode = nodeInfo.featureSubset
-          val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
-          if (metadata.unorderedFeatures.isEmpty) {
-            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
-          } else {
-            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
-              instanceWeight, featuresForNode)
-          }
-        }
+        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+      }
+
+      agg
+    }
+
+    /**
+     * Do the same thing as binSeqOp, but with nodeIdCache.
+     */
+    def binSeqOpWithNodeIdCache(
+        agg: Array[DTStatsAggregator],
+        dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+        val baggedPoint = dataPoint._1
+        val nodeIdCache = dataPoint._2
+        val nodeIndex = nodeIdCache(treeIndex)
+        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
       }
+
       agg
     }
 
@@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
     // Finally, only best Splits for nodes are collected to driver to construct decision tree.
     val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
     val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
-    val nodeToBestSplits =
+
+    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterator all instances in current partition and update aggregate stats
+        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partition using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }
+    } else {
       input.mapPartitions { points =>
         // Construct a nodeStatsAggregators array to hold node aggregate stats,
         // each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
         // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
         // which can be combined with other partition using `reduceByKey`
         nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
-      }.reduceByKey((a, b) => a.merge(b))
+      }
+    }
+
+    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
         .map { case (nodeIndex, aggStats) =>
           val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
             Some(nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
 
     timer.stop("chooseSplits")
 
+    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+    } else {
+      null
+    }
+
     // Iterate over all nodes in this group.
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
       nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
           node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
             stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
 
+          if (nodeIdCache.nonEmpty) {
+            val nodeIndexUpdater = NodeIndexUpdater(
+              split = split,
+              nodeIndex = nodeIndex)
+            nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+          }
+
           // enqueue left child and right child if they are not leaves
           if (!leftChildIsLeaf) {
             nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
+    if (nodeIdCache.nonEmpty) {
+      // Update the cache if needed.
+      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/56f2c61c/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1dcaf91..9683916 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache }
 import org.apache.spark.mllib.tree.impurity.Impurities
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -160,6 +160,19 @@ private class RandomForest (
      * in lower levels).
      */
 
+    // Create an RDD of node Id cache.
+    // At first, all the rows belong to the root nodes (node Id == 1).
+    val nodeIdCache = if (strategy.useNodeIdCache) {
+      Some(NodeIdCache.init(
+        data = baggedInput,
+        numTrees = numTrees,
+        checkpointDir = strategy.checkpointDir,
+        checkpointInterval = strategy.checkpointInterval,
+        initVal = 1))
+    } else {
+      None
+    }
+
     // FIFO queue of nodes to train: (treeIndex, node)
     val nodeQueue = new mutable.Queue[(Int, Node)]()
 
@@ -182,7 +195,7 @@ private class RandomForest (
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
       DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
       timer.stop("findBestSplits")
     }
 
@@ -193,6 +206,11 @@ private class RandomForest (
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
+    // Delete any remaining checkpoints used for node Id cache.
+    if (nodeIdCache.nonEmpty) {
+      nodeIdCache.get.deleteAllCheckpoints()
+    }
+
     val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
     val treeWeights = Array.fill[Double](numTrees)(1.0)
     new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)

http://git-wip-us.apache.org/repos/asf/spark/blob/56f2c61c/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 2ed63cf..d09295c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -60,6 +60,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
  *                      256 MB.
  * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
+ *                      maintain a separate RDD of node Id cache for each row.
+ * @param checkpointDir Directory in which the node Id cache is periodically checkpointed.
+ *                      Only relevant when useNodeIdCache is true, and ignored if the
+ *                      SparkContext already has a checkpoint directory set.
+ * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
+ *                           E.g. 10 means that the cache will get checkpointed every 10 updates.
  */
 @Experimental
 class Strategy (
@@ -73,7 +80,10 @@ class Strategy (
     @BeanProperty var minInstancesPerNode: Int = 1,
     @BeanProperty var minInfoGain: Double = 0.0,
     @BeanProperty var maxMemoryInMB: Int = 256,
-    @BeanProperty var subsamplingRate: Double = 1) extends Serializable {
+    @BeanProperty var subsamplingRate: Double = 1,
+    @BeanProperty var useNodeIdCache: Boolean = false,
+    @BeanProperty var checkpointDir: Option[String] = None,
+    @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
 
   if (algo == Classification) {
     require(numClassesForClassification >= 2)

http://git-wip-us.apache.org/repos/asf/spark/blob/56f2c61c/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000..83011b4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
+
+/**
+ * :: DeveloperApi ::
+ * Routes a data point from a node that was just split to one of that node's two children,
+ * so that the node Id cache can advance the point's node index after a round of splitting.
+ * [[NodeIdCache.updateNodeIndices]] looks instances up, per tree, by the point's current
+ * node index, so a point only moves if its current node is a key of the updater map.
+ * @param split The split chosen for the node identified by `nodeIndex`.
+ * @param nodeIndex Index of the node being split; children are derived from it.
+ */
+@DeveloperApi
+private[tree] case class NodeIndexUpdater(
+    split: Split,
+    nodeIndex: Int) {
+  /**
+   * Computes the child node index a data point moves to under `split`.
+   * @param binnedFeatures Binned feature values of the data point.
+   * @param bins Bin information, indexed (feature)(bin), used to recover approximate values
+   *             of continuous features.
+   * @return `Node.leftChildIndex(nodeIndex)` if the point satisfies the split,
+   *         otherwise `Node.rightChildIndex(nodeIndex)`.
+   */
+  def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
+    val featureIdx = split.feature
+    val binnedValue = binnedFeatures(featureIdx)
+    val goesLeft = if (split.featureType == Continuous) {
+      // Continuous: approximate the value by the upper threshold of its bin.
+      bins(featureIdx)(binnedValue).highSplit.threshold <= split.threshold
+    } else {
+      // Categorical: the binned value is tested against the split's category set.
+      split.categories.contains(binnedValue.toDouble)
+    }
+    // Translate the chosen side into the child's index in the tree numbering.
+    if (goesLeft) Node.leftChildIndex(nodeIndex) else Node.rightChildIndex(nodeIndex)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Caches, for every training row, the index of the node the row currently belongs to in
+ * each tree, so executors can aggregate node statistics without receiving the trees.
+ * Each element of `nodeIdsForInstances` is an array over trees of node indices; initially
+ * every value should be 1 (the root). The RDD is replaced by [[updateNodeIndices]] after
+ * every round of splitting; each replacement is persisted and, if a checkpoint directory
+ * is available, periodically checkpointed to truncate the growing lineage.
+ * @param nodeIdsForInstances The initial values in the cache (an array of 1's per row).
+ * @param checkpointDir Checkpoint directory to install on the SparkContext; only used if
+ *                      the SparkContext does not already have a checkpoint directory.
+ * @param checkpointInterval Checkpoint every `checkpointInterval` updates. Values <= 0
+ *                           disable checkpointing.
+ */
+@DeveloperApi
+private[tree] class NodeIdCache(
+  var nodeIdsForInstances: RDD[Array[Int]],
+  val checkpointDir: Option[String],
+  val checkpointInterval: Int) {
+
+  // The cache from the previous update. It is unpersisted on the next update, after the
+  // current cache has (presumably) been materialized from it.
+  private var prevNodeIdsForInstances: Option[RDD[Array[Int]]] = None
+
+  // Checkpointed cache RDDs, oldest first, whose files may still need to be deleted.
+  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+  private var rddUpdateCount = 0
+
+  // Install the given checkpoint directory unless the SparkContext already has one.
+  checkpointDir.foreach { dir =>
+    val sc = nodeIdsForInstances.sparkContext
+    if (sc.getCheckpointDir.isEmpty) {
+      sc.setCheckpointDir(dir)
+    }
+  }
+
+  /**
+   * Update the node index values in the cache.
+   * This replaces the RDD (extending its lineage) and persists the new one.
+   * TODO: Passing bin information to executors seems unnecessary and costly.
+   * @param data The RDD of training rows; must be zippable with the cache (same partitions
+   *             and element counts per partition).
+   * @param nodeIdUpdaters Per-tree maps of node index updaters.
+   *                       The keys are the indices of the nodes that were just split.
+   * @param bins Bin information needed to find child node indices.
+   */
+  def updateNodeIndices(
+      data: RDD[BaggedPoint[TreePoint]],
+      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+      bins: Array[Array[Bin]]): Unit = {
+    // Release the cache from two updates ago.
+    prevNodeIdsForInstances.foreach(_.unpersist())
+
+    prevNodeIdsForInstances = Some(nodeIdsForInstances)
+    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, nodeIds) =>
+      var treeId = 0
+      while (treeId < nodeIdUpdaters.length) {
+        nodeIdUpdaters(treeId).get(nodeIds(treeId)) match {
+          case Some(updater) =>
+            nodeIds(treeId) = updater.updateNodeIndex(
+              binnedFeatures = point.datum.binnedFeatures,
+              bins = bins)
+          case None => // This row's node in this tree was not split in this round.
+        }
+        treeId += 1
+      }
+      // NOTE(review): the array is updated in place, so it may be shared with the
+      // previous RDD's cached partition; copying it would keep that RDD immutable.
+      nodeIds
+    }
+
+    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+    rddUpdateCount += 1
+
+    // Checkpoint every `checkpointInterval` updates if the SparkContext has a checkpoint
+    // directory. A non-positive interval disables checkpointing instead of failing on `% 0`.
+    if (checkpointInterval > 0 &&
+        nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
+        rddUpdateCount % checkpointInterval == 0) {
+      // The oldest checkpoint may be deleted only once the next one has actually been
+      // written to the file system, so a usable checkpoint always remains.
+      while (checkpointQueue.size > 1 && checkpointQueue(1).getCheckpointFile.isDefined) {
+        deleteCheckpointFiles(checkpointQueue.dequeue())
+      }
+      nodeIdsForInstances.checkpoint()
+      checkpointQueue.enqueue(nodeIdsForInstances)
+    }
+  }
+
+  /**
+   * Call this after training is finished to delete any remaining checkpoints.
+   * It also unpersists the cached node Id RDDs, which would otherwise stay cached.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.nonEmpty) {
+      deleteCheckpointFiles(checkpointQueue.dequeue())
+    }
+    prevNodeIdsForInstances.foreach(_.unpersist())
+    prevNodeIdsForInstances = None
+    nodeIdsForInstances.unpersist()
+  }
+
+  /**
+   * Deletes `rdd`'s checkpoint files, if any (Spark does not delete them itself). The file
+   * system is resolved from the path, so non-default file systems are handled too.
+   */
+  private def deleteCheckpointFiles(rdd: RDD[Array[Int]]): Unit = {
+    rdd.getCheckpointFile.foreach { file =>
+      val path = new Path(file)
+      val fs: FileSystem = path.getFileSystem(rdd.sparkContext.hadoopConfiguration)
+      fs.delete(path, true)
+    }
+  }
+}
+
+@DeveloperApi
+private[tree] object NodeIdCache {
+  /**
+   * Builds a node Id cache in which every row starts at the same node in every tree.
+   * @param data The RDD of training rows; the cache has one entry per row.
+   * @param numTrees Number of trees, i.e. the length of each row's node index array.
+   * @param checkpointDir Optional checkpoint directory passed on to the cache.
+   * @param checkpointInterval Number of cache updates between checkpoints.
+   * @param initVal Node index every row starts at; defaults to 1, the root node.
+   * @return A node Id cache wrapping an RDD of freshly filled node index arrays.
+   */
+  def init(
+      data: RDD[BaggedPoint[TreePoint]],
+      numTrees: Int,
+      checkpointDir: Option[String],
+      checkpointInterval: Int,
+      initVal: Int = 1): NodeIdCache = {
+    // Each row gets its own array so that later in-place updates never alias rows.
+    val startingIds: RDD[Array[Int]] = data.map { _ =>
+      Array.fill[Int](numTrees)(initVal)
+    }
+    new NodeIdCache(startingIds, checkpointDir, checkpointInterval)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/56f2c61c/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 10c046e..73c4393 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -34,18 +34,11 @@ import org.apache.spark.mllib.util.LocalSparkContext
  * Test suite for [[RandomForest]].
  */
 class RandomForestSuite extends FunSuite with LocalSparkContext {
-
-  test("Binary classification with continuous features:" +
-      " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
-
+  def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
     val rdd = sc.parallelize(arr)
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
     val numTrees = 1
 
-    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
-
     val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
     assert(rf.weakHypotheses.size === 1)
@@ -60,18 +53,27 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     assert(rfTree.toString == dt.toString)
   }
 
-  test("Regression with continuous features:" +
+  test("Binary classification with continuous features:" +
     " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    binaryClassificationTestWithContinuousFeatures(strategy)
+  }
 
+  test("Binary classification with continuous features and node Id cache :" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+    binaryClassificationTestWithContinuousFeatures(strategy)
+  }
+
+  def regressionTestWithContinuousFeatures(strategy: Strategy) {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
     val rdd = sc.parallelize(arr)
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
     val numTrees = 1
 
-    val strategy = new Strategy(algo = Regression, impurity = Variance,
-      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
-      categoricalFeaturesInfo = categoricalFeaturesInfo)
-
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
     assert(rf.weakHypotheses.size === 1)
@@ -86,14 +88,28 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     assert(rfTree.toString == dt.toString)
   }
 
-  test("Binary classification with continuous features: subsampling features") {
+  test("Regression with continuous features:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Regression, impurity = Variance,
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo)
+    regressionTestWithContinuousFeatures(strategy)
+  }
+
+  test("Regression with continuous features and node Id cache :" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Regression, impurity = Variance,
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+    regressionTestWithContinuousFeatures(strategy)
+  }
+
+  def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
     val numFeatures = 50
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
     val rdd = sc.parallelize(arr)
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
-
-    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
 
     // Select feature subset for top nodes.  Return true if OK.
     def checkFeatureSubsetStrategy(
@@ -149,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
   }
 
+  test("Binary classification with continuous features: subsampling features") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+  }
+
+  test("Binary classification with continuous features and node Id cache: subsampling features") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+    binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+  }
+
   test("alternating categorical and continuous features with multiclass labels to test indexing") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
@@ -164,7 +194,6 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
       featureSubsetStrategy = "sqrt", seed = 12345)
     EnsembleTestHelper.validateClassifier(model, arr, 1.0)
   }
-
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org