You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/03 12:26:25 UTC

git commit: [SPARK-3366][MLLIB]Compute best splits distributively in decision tree

Repository: spark
Updated Branches:
  refs/heads/master 1c90347a4 -> 2e4eae3a5


[SPARK-3366][MLLIB]Compute best splits distributively in decision tree

Currently, all best splits are computed on the driver, which makes the driver a bottleneck for both communication and computation. This PR fixes this problem by computing best splits on executors.
Instead of sending all aggregate stats to the driver node, we can send the aggregate stats for a node to a particular executor using the `reduceByKey` operation, and then compute the best split for this node there.

Implementation details:

Each node now has a nodeStatsAggregator, which saves aggregate stats for all features and bins.
First, use `mapPartitions` to compute node aggregate stats for all nodes in each partition.
Then transform the node aggregate stats into (nodeIndex, nodeStatsAggregator) pairs and use the `reduceByKey` operation to combine the nodeStatsAggregators for the same node.
After all stats have been combined, best splits can be computed for each node based on the node aggregate stats. The best split results are collected to the driver to construct the decision tree.

CC: mengxr manishamde jkbradley, please help me review this, thanks.

Author: qiping.lqp <qi...@alibaba-inc.com>
Author: chouqin <li...@gmail.com>

Closes #2595 from chouqin/dt-dist-agg and squashes the following commits:

db0d24a [chouqin] fix a minor bug and adjust code
a0d9de3 [chouqin] adjust code based on comments
9f201a6 [chouqin] fix bug: statsSize -> allStatsSize
a8a7ed0 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-dist-agg
f13b346 [chouqin] adjust randomforest comments
c32636e [chouqin] adjust code based on comments
ac6a505 [chouqin] adjust code based on comments
7bbb787 [chouqin] add comments
bdd2a63 [qiping.lqp] fix test suite
a75df27 [qiping.lqp] fix test suite
b5b0bc2 [qiping.lqp] fix style
e76414f [qiping.lqp] fix testsuite
748bd45 [qiping.lqp] fix type-mismatch bug
24eacd8 [qiping.lqp] fix type-mismatch bug
5f63d6c [qiping.lqp] add multiclassification using One-Vs-All strategy
4f56496 [qiping.lqp] fix bug
f00fc22 [qiping.lqp] fix bug
532993a [qiping.lqp] Compute best splits distributively in decision tree


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2e4eae3a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2e4eae3a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2e4eae3a

Branch: refs/heads/master
Commit: 2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae
Parents: 1c90347
Author: qiping.lqp <qi...@alibaba-inc.com>
Authored: Fri Oct 3 03:26:17 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Oct 3 03:26:17 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 140 +++++----
 .../apache/spark/mllib/tree/RandomForest.scala  |   5 +-
 .../mllib/tree/impl/DTStatsAggregator.scala     | 292 +++++--------------
 .../mllib/tree/model/InformationGainStats.scala |  11 +
 .../spark/mllib/tree/RandomForestSuite.scala    |   1 +
 5 files changed, 182 insertions(+), 267 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b7dc373..b311d10 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
@@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.SparkContext._
 
 
 /**
@@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging {
    * for each subset is updated.
    *
    * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
-   *             each (node, feature, bin).
+   *             each (feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, numNodes).
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param unorderedFeatures  Set of indices of unordered features.
    * @param instanceWeight  Weight (importance) of instance in dataset.
@@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging {
   private def mixedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      nodeIndex: Int,
       bins: Array[Array[Bin]],
       unorderedFeatures: Set[Int],
       instanceWeight: Double,
@@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging {
       // Use all features
       agg.metadata.numFeatures
     }
-    val nodeOffset = agg.getNodeOffset(nodeIndex)
     // Iterate over features.
     var featureIndexIdx = 0
     while (featureIndexIdx < numFeaturesPerNode) {
@@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
         val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+          agg.getLeftRightFeatureOffsets(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplits) {
           if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
-            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
           } else {
-            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
           }
           splitIndex += 1
@@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging {
       } else {
         // Ordered feature
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
-          instanceWeight)
+        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
       }
       featureIndexIdx += 1
     }
@@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging {
    * For each feature, the sufficient statistics of one bin are updated.
    *
    * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
-   *             each (node, feature, bin).
+   *             each (feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, numNodes).
    * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def orderedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      nodeIndex: Int,
       instanceWeight: Double,
       featuresForNode: Option[Array[Int]]): Unit = {
     val label = treePoint.label
-    val nodeOffset = agg.getNodeOffset(nodeIndex)
+
     // Iterate over features.
     if (featuresForNode.nonEmpty) {
       // Use subsampled features
       var featureIndexIdx = 0
       while (featureIndexIdx < featuresForNode.get.size) {
         val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
-        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
         featureIndexIdx += 1
       }
     } else {
@@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging {
       var featureIndex = 0
       while (featureIndex < numFeatures) {
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+        agg.update(featureIndex, binIndex, label, instanceWeight)
         featureIndex += 1
       }
     }
@@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging {
      * @return  agg
      */
     def binSeqOp(
-        agg: DTStatsAggregator,
-        baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+        agg: Array[DTStatsAggregator],
+        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
         val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
           bins, metadata.unorderedFeatures)
@@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging {
           val featuresForNode = nodeInfo.featureSubset
           val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
           if (metadata.unorderedFeatures.isEmpty) {
-            orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
           } else {
-            mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
               instanceWeight, featuresForNode)
           }
         }
@@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging {
       agg
     }
 
-    // Calculate bin aggregates.
-    timer.start("aggregation")
-    val binAggregates: DTStatsAggregator = {
-      val initAgg = if (metadata.subsamplingFeatures) {
-        new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
-      } else {
-        new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+    /**
+     * Get node index in group --> features indices map,
+     * which is a short cut to find feature indices for a node given node index in group
+     * @param treeToNodeToIndexInfo
+     * @return
+     */
+    def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
+      : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
+      None
+    } else {
+      val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+      treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+        nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+          assert(nodeIndexInfo.featureSubset.isDefined)
+          mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+        }
       }
-      input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
+      Some(mutableNodeToFeatures.toMap)
     }
-    timer.stop("aggregation")
 
     // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")
 
+    // In each partition, iterate all instances and compute aggregate stats for each node,
+    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
+    // After a `reduceByKey` operation,
+    // stats of a node will be shuffled to a particular partition and be combined together,
+    // then best splits for nodes are found there.
+    // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+    val nodeToBestSplits =
+      input.mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterator all instances in current partition and update aggregate stats
+        points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partition using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }.reduceByKey((a, b) => a.merge(b))
+        .map { case (nodeIndex, aggStats) =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+
+          // find best split for each node
+          val (split: Split, stats: InformationGainStats, predict: Predict) =
+            binsToBestSplit(aggStats, splits, featuresForNode)
+          (nodeIndex, (split, stats, predict))
+        }.collectAsMap()
+
+    timer.stop("chooseSplits")
+
     // Iterate over all nodes in this group.
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
       nodesForTree.foreach { node =>
         val nodeIndex = node.id
         val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
         val aggNodeIndex = nodeInfo.nodeIndexInGroup
-        val featuresForNode = nodeInfo.featureSubset
         val (split: Split, stats: InformationGainStats, predict: Predict) =
-          binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+          nodeToBestSplits(aggNodeIndex)
         logDebug("best split = " + split)
 
         // Extract info for this node.  Create children if not leaf.
@@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
         }
       }
     }
-    timer.stop("chooseSplits")
+
   }
 
   /**
@@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Find the best split for a node.
    * @param binAggregates Bin statistics.
-   * @param nodeIndex Index into aggregates for node to split in this group.
    * @return tuple for best split: (Split, information gain, prediction at node)
    */
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
-      nodeIndex: Int,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
 
-    val metadata: DecisionTreeMetadata = binAggregates.metadata
-
     // calculate predict only once
     var predict: Option[Predict] = None
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+    val (bestSplit, bestSplitStats) =
+      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
       val featureIndex = if (featuresForNode.nonEmpty) {
         featuresForNode.get.apply(featureIndexIdx)
       } else {
         featureIndexIdx
       }
-      val numSplits = metadata.numSplits(featureIndex)
-      if (metadata.isContinuous(featureIndex)) {
+      val numSplits = binAggregates.metadata.numSplits(featureIndex)
+      if (binAggregates.metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
         // Afterwards, binAggregates for a bin is the sum of aggregates for
         // that bin + all preceding bins.
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
+        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
         var splitIndex = 0
         while (splitIndex < numSplits) {
-          binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
           splitIndex += 1
         }
         // Find best split.
@@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else if (metadata.isUnordered(featureIndex)) {
+      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
         // Unordered categorical feature
         val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else {
         // Ordered categorical feature
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
-        val numBins = metadata.numBins(featureIndex)
+        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+        val numBins = binAggregates.metadata.numBins(featureIndex)
 
         /* Each bin is one category (feature value).
          * The bins are ordered based on centroidForCategories, and this ordering determines which
@@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
          *
          * centroidForCategories is a list: (category, centroid)
          */
-        val centroidForCategories = if (metadata.isMulticlass) {
+        val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
           // For categorical variables in multiclass classification,
           // the bins are ordered by the impurity of their corresponding labels.
           Range(0, numBins).map { case featureValue =>
@@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging {
         while (splitIndex < numSplits) {
           val currentCategory = categoriesSortedByCentroid(splitIndex)._1
           val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
-          binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
           splitIndex += 1
         }
         // lastCategory = index of bin with total aggregates for this (node, feature)
@@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging {
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =

http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 7fa7725..fa7a26f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -171,8 +171,8 @@ private class RandomForest (
 
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
-      DecisionTree.findBestSplits(baggedInput,
-        metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+      DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
       timer.stop("findBestSplits")
     }
 
@@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging {
    * @param maxMemoryUsage  Bound on size of aggregate statistics.
    * @return  (nodesForGroup, treeToNodeToIndexInfo).
    *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+   *
    *          treeToNodeToIndexInfo holds indices selected features for each node:
    *            treeIndex --> (global) node index --> (node index in group, feature indices).
    *          The (global) node index is the index in the tree; the node index in group is the

http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index d49df7a..55f422d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,17 +17,19 @@
 
 package org.apache.spark.mllib.tree.impl
 
-import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.impurity._
 
+
+
 /**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * DecisionTree statistics aggregator for a node.
+ * This holds a flat array of statistics for a set of (features, bins)
  * and helps with indexing.
  * This class is abstract to support learning with and without feature subsampling.
  */
-private[tree] abstract class DTStatsAggregator(
-    val metadata: DecisionTreeMetadata) extends Serializable {
+private[tree] class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata,
+    featureSubset: Option[Array[Int]]) extends Serializable {
 
   /**
    * [[ImpurityAggregator]] instance specifying the impurity type.
@@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator(
   /**
    * Number of elements (Double values) used for the sufficient statistics of each bin.
    */
-  val statsSize: Int = impurityAggregator.statsSize
+  private val statsSize: Int = impurityAggregator.statsSize
+
+  /**
+   * Number of bins for each feature.  This is indexed by the feature index.
+   */
+  private val numBins: Array[Int] = {
+    if (featureSubset.isDefined) {
+      featureSubset.get.map(metadata.numBins(_))
+    } else {
+      metadata.numBins
+    }
+  }
+
+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] array.
+   */
+  private val featureOffsets: Array[Int] = {
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+  }
 
   /**
    * Indicator for each feature of whether that feature is an unordered feature.
@@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator(
   def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
 
   /**
-   * Total number of elements stored in this aggregator.
+   * Total number of elements stored in this aggregator
    */
-  def allStatsSize: Int
+  private val allStatsSize: Int = featureOffsets.last
 
   /**
-   * Get flat array of elements stored in this aggregator.
+   * Flat array of elements.
+   * Index for start of stats for a (feature, bin) is:
+   *   index = featureOffsets(featureIndex) + binIndex * statsSize
+   * Note: For unordered features,
+   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
+   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
    */
-  protected def allStats: Array[Double]
+  private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
-   *                           from [[getNodeFeatureOffset]].
+   * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
+   *                           from [[getFeatureOffset]].
    *                           For unordered features, this is a pre-computed
    *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightNodeFeatureOffsets]].
+   *                           [[getLeftRightFeatureOffsets]].
    */
-  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
-    impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }
 
   /**
-   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+   * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
-  def update(
-      nodeIndex: Int,
-      featureIndex: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+    val i = featureOffsets(featureIndex) + binIndex * statsSize
     impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
   /**
-   * Pre-compute node offset for use with [[nodeUpdate]].
-   */
-  def getNodeOffset(nodeIndex: Int): Int
-
-  /**
    * Faster version of [[update]].
-   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
-   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
+   * Update the stats for a given (feature, bin), using the given label.
+   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   *                           from [[getFeatureOffset]].
+   *                           For unordered features, this is a pre-computed
+   *                           (feature, left/right child) offset from
+   *                           [[getLeftRightFeatureOffsets]].
    */
-  def nodeUpdate(
-      nodeOffset: Int,
-      nodeIndex: Int,
-      featureIndex: Int,
+  def featureUpdate(
+      featureOffset: Int,
       binIndex: Int,
       label: Double,
-      instanceWeight: Double): Unit
+      instanceWeight: Double): Unit = {
+    impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
+      label, instanceWeight)
+  }
 
   /**
-   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * Pre-compute feature offset for use with [[featureUpdate]].
    * For ordered features only.
    */
-  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
+  def getFeatureOffset(featureIndex: Int): Int = {
+    require(!isUnordered(featureIndex),
+      s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
+        s" for unordered feature $featureIndex.")
+    featureOffsets(featureIndex)
+  }
 
   /**
-   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * Pre-compute feature offset for use with [[featureUpdate]].
    * For unordered features only.
    */
-  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
     require(isUnordered(featureIndex),
-      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+      s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
         s" but was called for ordered feature $featureIndex.")
-    val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
-    (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
-  }
-
-  /**
-   * Faster version of [[update]].
-   * Update the stats for a given (node, feature, bin), using the given label.
-   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
-   *                           from [[getNodeFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightNodeFeatureOffsets]].
-   */
-  def nodeFeatureUpdate(
-      nodeFeatureOffset: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
-      instanceWeight)
+    val baseOffset = featureOffsets(featureIndex)
+    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }
 
   /**
-   * For a given (node, feature), merge the stats for two bins.
-   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
-   *                           from [[getNodeFeatureOffset]].
+   * For a given feature, merge the stats for two bins.
+   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   *                           from [[getFeatureOffset]].
    *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightNodeFeatureOffsets]].
+   *                           (feature, left/right child) offset from
+   *                           [[getLeftRightFeatureOffsets]].
    * @param binIndex  The other bin is merged into this bin.
    * @param otherBinIndex  This bin is not modified.
    */
-  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
-    impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
-      nodeFeatureOffset + otherBinIndex * statsSize)
+  def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+    impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
+      featureOffset + otherBinIndex * statsSize)
   }
 
   /**
@@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator(
   def merge(other: DTStatsAggregator): DTStatsAggregator = {
     require(allStatsSize == other.allStatsSize,
       s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
-      + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
     var i = 0
     // TODO: Test BLAS.axpy
     while (i < allStatsSize) {
@@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator(
     this
   }
 }
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when not subsampling features.
- *
- * @param numNodes  Number of nodes to collect statistics for.
- */
-private[tree] class DTStatsAggregatorFixedFeatures(
-    metadata: DecisionTreeMetadata,
-    numNodes: Int) extends DTStatsAggregator(metadata) {
-
-  /**
-   * Offset for each feature for calculating indices into the [[allStats]] array.
-   * Mapping: featureIndex --> offset
-   */
-  private val featureOffsets: Array[Int] = {
-    metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
-  }
-
-  /**
-   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
-   */
-  private val nodeStride: Int = featureOffsets.last
-
-  override val allStatsSize: Int = numNodes * nodeStride
-
-  /**
-   * Flat array of elements.
-   * Index for start of stats for a (node, feature, bin) is:
-   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features, the left child stats precede the right child stats
-   *       in the binIndex order.
-   */
-  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
-  override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
-
-  override def nodeUpdate(
-      nodeOffset: Int,
-      nodeIndex: Int,
-      featureIndex: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
-  }
-
-  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
-    nodeIndex * nodeStride + featureOffsets(featureIndex)
-  }
-}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when subsampling features.
- *
- * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
- *                              where nodeIndexInfo stores the index in the group and the
- *                              feature subsets (if using feature subsets).
- */
-private[tree] class DTStatsAggregatorSubsampledFeatures(
-    metadata: DecisionTreeMetadata,
-    treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
-
-  /**
-   * For each node, offset for each feature for calculating indices into the [[allStats]] array.
-   * Mapping: nodeIndex --> featureIndex --> offset
-   */
-  private val featureOffsets: Array[Array[Int]] = {
-    val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
-    val offsets = new Array[Array[Int]](numNodes)
-    treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
-      nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
-        offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
-          .scanLeft(0)((total, nBins) => total + statsSize * nBins)
-      }
-    }
-    offsets
-  }
-
-  /**
-   * For each node, offset for each feature for calculating indices into the [[allStats]] array.
-   */
-  protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
-
-  override val allStatsSize: Int = nodeOffsets.last
-
-  /**
-   * Flat array of elements.
-   * Index for start of stats for a (node, feature, bin) is:
-   *   index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features, the left child stats precede the right child stats
-   *       in the binIndex order.
-   */
-  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
-  override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
-
-  /**
-   * Faster version of [[update]].
-   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
-   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
-   * @param featureIndex  Index of feature in featuresForNodes(nodeIndex).
-   *                      Note: This is NOT the original feature index.
-   */
-  override def nodeUpdate(
-      nodeOffset: Int,
-      nodeIndex: Int,
-      featureIndex: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
-  }
-
-  /**
-   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
-   * For ordered features only.
-   * @param featureIndex  Index of feature in featuresForNodes(nodeIndex).
-   *                      Note: This is NOT the original feature index.
-   */
-  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
-    nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
-  }
-}
-
-private[tree] object DTStatsAggregator extends Serializable {
-
-  /**
-   * Combines two aggregates (modifying the first) and returns the combination.
-   */
-  def binCombOp(
-      agg1: DTStatsAggregator,
-      agg2: DTStatsAggregator): DTStatsAggregator = {
-    agg1.merge(agg2)
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f3e2619..a89e71e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,6 +38,17 @@ class InformationGainStats(
     "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
       .format(gain, impurity, leftImpurity, rightImpurity)
   }
+
+  override def equals(o: Any) =
+    o match {
+      case other: InformationGainStats => {
+        gain == other.gain &&
+        impurity == other.impurity &&
+        leftImpurity == other.leftImpurity &&
+        rightImpurity == other.rightImpurity
+      }
+      case _ => false
+    }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 30669fc..20d372d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
 
         assert(nodesForGroup.size === numTrees, failString)
         assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+
         if (numFeaturesPerNode == numFeatures) {
           // featureSubset values should all be None
           assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org