You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by on 2014/09/08 18:47:18 UTC

[1/2] [SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregation improvements

Repository: spark
Updated Branches:
  refs/heads/master 0d1cc4ae4 -> 711356b42
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
new file mode 100644
index 0000000..866d85a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -0,0 +1,213 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.tree.impl
+import org.apache.spark.mllib.tree.impurity._
+/**
+ * DecisionTree statistics aggregator.
+ * Holds the sufficient statistics for a set of (node, feature, bin) triples in one flat array
+ * and provides the index arithmetic for reading and updating them.
+ *
+ * @param metadata  Learning and dataset metadata; supplies the impurity, the number of
+ *                  features and bins, and which features are unordered.
+ * @param numNodes  Number of nodes whose statistics this aggregator stores.
+ */
+private[tree] class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata,
+    val numNodes: Int) extends Serializable {
+  /**
+   * [[ImpurityAggregator]] for `metadata.impurity`.  It defines how many values make up one
+   * bin's statistics and how a label updates them.
+   */
+  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
+    case Variance => new VarianceAggregator()
+    case Gini => new GiniAggregator(metadata.numClasses)
+    case Entropy => new EntropyAggregator(metadata.numClasses)
+    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
+  }
+  /** Number of Double values in the sufficient statistics of a single bin. */
+  val statsSize: Int = impurityAggregator.statsSize
+  /** Number of features, as recorded in the metadata. */
+  val numFeatures: Int = metadata.numFeatures
+  /** Number of bins for each feature, indexed by feature index. */
+  val numBins: Array[Int] = metadata.numBins
+  /** Number of splits for the given feature (delegates to the metadata). */
+  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
+  /**
+   * Whether the given feature is an unordered categorical feature (delegates to the metadata).
+   * TODO: Is Array[Boolean] any faster?
+   */
+  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
+  /**
+   * Offsets of each feature's statistics within one node's section of [[allStats]].
+   * Entry `f` is where feature `f` starts; the extra final entry is the size of one node's
+   * section.  An unordered feature reserves 2 * numBins(f) bins (left- and right-child stats),
+   * an ordered feature numBins(f) bins, and every bin occupies statsSize elements.
+   */
+  private val featureOffsets: Array[Int] = {
+    val binsPerFeature = Range(0, numFeatures).map { f =>
+      if (isUnordered(f)) 2 * numBins(f) else numBins(f)
+    }
+    binsPerFeature.scanLeft(0)(_ + _).map(_ * statsSize).toArray
+  }
+  /** Number of elements per node: the stride between consecutive nodes in [[allStats]]. */
+  private val nodeStride: Int = featureOffsets.last
+  /** Total number of elements stored in this aggregator. */
+  val allStatsSize: Int = numNodes * nodeStride
+  /**
+   * Flat array of statistics.
+   * The stats for a (node, feature, bin) start at
+   *   nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+   * For unordered features, left-child stats use binIndex in [0, numBins(featureIndex)) and
+   * right-child stats use binIndex in [numBins(featureIndex), 2 * numBins(featureIndex)).
+   */
+  val allStats: Array[Double] = new Array[Double](allStatsSize)
+  /** Start index of the stats of bin `binIndex`, relative to a precomputed base offset. */
+  private def binOffset(baseOffset: Int, binIndex: Int): Int = baseOffset + binIndex * statsSize
+  /**
+   * Get an [[ImpurityCalculator]] for the stats of one (node, feature, bin).
+   * @param nodeFeatureOffset  For ordered features, a (node, feature) offset from
+   *                           [[getNodeFeatureOffset]].  For unordered features, a
+   *                           (node, feature, left/right child) offset from
+   *                           [[getLeftRightNodeFeatureOffsets]].
+   * @param binIndex           Bin index relative to `nodeFeatureOffset`.
+   */
+  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator =
+    impurityAggregator.getCalculator(allStats, binOffset(nodeFeatureOffset, binIndex))
+  /**
+   * Update the stats of one (node, feature, bin) of an ordered feature with the given label.
+   */
+  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit =
+    nodeUpdate(getNodeOffset(nodeIndex), featureIndex, binIndex, label)
+  /** Precompute the offset of a node, for use with [[nodeUpdate]]. */
+  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+  /**
+   * Faster version of [[update]]: update the stats of one (node, feature, bin) of an ordered
+   * feature with the given label.
+   * @param nodeOffset  Precomputed node offset from [[getNodeOffset]].
+   */
+  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
+    val nodeFeatureOffset = nodeOffset + featureOffsets(featureIndex)
+    impurityAggregator.update(allStats, binOffset(nodeFeatureOffset, binIndex), label)
+  }
+  /**
+   * Precompute the (node, feature) offset of an ordered feature, for use with
+   * [[nodeFeatureUpdate]], [[getImpurityCalculator]] and [[mergeForNodeFeature]].
+   * @throws IllegalArgumentException if the feature is unordered.
+   */
+  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+    require(!isUnordered(featureIndex),
+      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
+      s" for unordered feature $featureIndex.")
+    getNodeOffset(nodeIndex) + featureOffsets(featureIndex)
+  }
+  /**
+   * Precompute the left- and right-child (node, feature) offsets of an unordered feature, for
+   * use with [[nodeFeatureUpdate]], [[getImpurityCalculator]] and [[mergeForNodeFeature]].
+   * @return  (left-child offset, right-child offset); the right-child stats start
+   *          numBins(featureIndex) bins after the left-child stats.
+   * @throws IllegalArgumentException if the feature is ordered.
+   */
+  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+    require(isUnordered(featureIndex),
+      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+      s" but was called for ordered feature $featureIndex.")
+    val leftOffset = getNodeOffset(nodeIndex) + featureOffsets(featureIndex)
+    (leftOffset, binOffset(leftOffset, numBins(featureIndex)))
+  }
+  /**
+   * Faster version of [[update]]: update the stats of one (node, feature, bin) with a label.
+   * @param nodeFeatureOffset  For ordered features, an offset from [[getNodeFeatureOffset]];
+   *                           for unordered features, one of the offsets returned by
+   *                           [[getLeftRightNodeFeatureOffsets]].
+   */
+  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit =
+    impurityAggregator.update(allStats, binOffset(nodeFeatureOffset, binIndex), label)
+  /**
+   * For a given (node, feature), merge the stats of bin `otherBinIndex` into bin `binIndex`.
+   * @param nodeFeatureOffset  For ordered features, an offset from [[getNodeFeatureOffset]];
+   *                           for unordered features, one of the offsets returned by
+   *                           [[getLeftRightNodeFeatureOffsets]].
+   * @param binIndex       Bin that receives the merged stats (modified).
+   * @param otherBinIndex  Bin whose stats are merged in (not modified).
+   */
+  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit =
+    impurityAggregator.merge(allStats, binOffset(nodeFeatureOffset, binIndex),
+      binOffset(nodeFeatureOffset, otherBinIndex))
+  /**
+   * Add the stats of another aggregator element-wise into this one, in place.
+   * @return  this aggregator.
+   * @throws IllegalArgumentException if the two aggregators hold different numbers of stats.
+   */
+  def merge(other: DTStatsAggregator): DTStatsAggregator = {
+    require(allStatsSize == other.allStatsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+      + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+    val otherStats = other.allStats
+    var idx = 0
+    // TODO: Test BLAS.axpy
+    while (idx < allStatsSize) {
+      allStats(idx) = allStats(idx) + otherStats(idx)
+      idx += 1
+    }
+    this
+  }
+}
+private[tree] object DTStatsAggregator extends Serializable {
+  /**
+   * Combine operation for aggregating [[DTStatsAggregator]] instances:
+   * merges `agg2` into `agg1` in place (via `merge`) and returns `agg1`.
+   */
+  def binCombOp(agg1: DTStatsAggregator, agg2: DTStatsAggregator): DTStatsAggregator =
+    agg1.merge(agg2)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index d9eda35..e95add7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.rdd.RDD
  * Learning and dataset metadata for DecisionTree.
  * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
  *                      For regression: fixed at 0 (no meaning).
+ * @param maxBins  Maximum number of bins, for all features.
  * @param featureArity  Map: categorical feature index --> arity.
  *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins  Number of bins for each feature.
 private[tree] class DecisionTreeMetadata(
     val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
     val maxBins: Int,
     val featureArity: Map[Int, Int],
     val unorderedFeatures: Set[Int],
+    val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy) extends Serializable {
@@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata(
   def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+  /**
+   * Number of splits for the given feature.
+   * An unordered feature has 2 bins (left and right child) per split, so its split count is
+   * half its bin count; an ordered feature has one more bin than it has splits.
+   */
+  def numSplits(featureIndex: Int): Int = {
+    val unordered = isUnordered(featureIndex)
+    val bins = numBins(featureIndex)
+    if (unordered) bins >> 1 else bins - 1
+  }
 private[tree] object DecisionTreeMetadata {
+  /**
+   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+   * This computes which categorical features will be ordered vs. unordered,
+   * as well as the number of splits and bins for each feature.
+   */
   def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
     val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
       case Regression => 0
-    val maxBins = math.min(strategy.maxBins, numExamples).toInt
-    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+    // We check the number of bins here against maxPossibleBins.
+    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+    // based on the number of training examples.
+    if (strategy.categoricalFeaturesInfo.nonEmpty) {
+      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+      require(maxCategoriesPerFeature <= maxPossibleBins,
+        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
+          s"in categorical features (= $maxCategoriesPerFeature)")
+    }
     val unorderedFeatures = new mutable.HashSet[Int]()
+    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
     if (numClasses > 2) {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        if (k - 1 < log2MaxBinsp1) {
-          // Note: The above check is equivalent to checking:
-          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
-          unorderedFeatures.add(f)
+      // Multiclass classification
+      val maxCategoriesForUnorderedFeature =
+        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        // Decide if some categorical features should be treated as unordered features,
+        //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+        // We do this check with log values to prevent overflows in case numCategories is large.
+        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+        if (numCategories <= maxCategoriesForUnorderedFeature) {
+          unorderedFeatures.add(featureIndex)
+          numBins(featureIndex) = numUnorderedBins(numCategories)
         } else {
-          // TODO: Allow this case, where we simply will know nothing about some categories?
-          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
-            s"in categorical features (>= $k)")
+          numBins(featureIndex) = numCategories
     } else {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
-          s"in categorical features (>= $k)")
+      // Binary classification or regression
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        numBins(featureIndex) = numCategories
-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
       strategy.impurity, strategy.quantileCalculationStrategy)
+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+   * there are 2^(arity - 1) - 1 such splits, and each split has 2 corresponding bins
+   * (left and right child).
+   * Note: computed with Int shifts, so the result only fits in an Int for arity <= 31.
+   */
+  def numUnorderedBins(arity: Int): Int = {
+    val splits = (1 << (arity - 1)) - 1
+    2 * splits
+  }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 170e43e..35e361a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -48,54 +48,63 @@ private[tree] object TreePoint {
    * binning feature values in preparation for DecisionTree training.
    * @param input     Input dataset.
    * @param bins      Bins for features, of size (numFeatures, numBins).
-   * @param metadata Learning and dataset metadata
+   * @param metadata  Learning and dataset metadata
    * @return  TreePoint dataset representation
   def convertToTreeRDD(
       input: RDD[LabeledPoint],
       bins: Array[Array[Bin]],
       metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+    // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
+    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+    val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
+    var featureIndex = 0
+    while (featureIndex < metadata.numFeatures) {
+      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+      isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
+      featureIndex += 1
+    } { x =>
-      TreePoint.labeledPointToTreePoint(x, bins, metadata)
+      TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
    * Convert one LabeledPoint into its TreePoint representation.
    * @param bins      Bins for features, of size (numFeatures, numBins).
+   * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
+   *                      for categorical features.
+   * @param isUnordered  Array indexed by feature, with value true for unordered categorical features.
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
       bins: Array[Array[Bin]],
-      metadata: DecisionTreeMetadata): TreePoint = {
+      featureArity: Array[Int],
+      isUnordered: Array[Boolean]): TreePoint = {
     val numFeatures = labeledPoint.features.size
-    val numBins = bins(0).size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
-        metadata.isUnordered(featureIndex), bins, metadata.featureArity)
+      arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
+        isUnordered(featureIndex), bins)
       featureIndex += 1
     new TreePoint(labeledPoint.label, arr)
    * Find bin for one (labeledPoint, feature).
+   * @param featureArity  0 for continuous features; number of categories for categorical features.
    * @param isUnorderedFeature  (only applies if feature is categorical)
    * @param bins   Bins for features, of size (numFeatures, numBins).
-   * @param categoricalFeaturesInfo  Map over categorical features: feature index --> feature arity
   private def findBin(
       featureIndex: Int,
       labeledPoint: LabeledPoint,
-      isFeatureContinuous: Boolean,
+      featureArity: Int,
       isUnorderedFeature: Boolean,
-      bins: Array[Array[Bin]],
-      categoricalFeaturesInfo: Map[Int, Int]): Int = {
+      bins: Array[Array[Bin]]): Int = {
      * Binary search helper method for continuous feature.
@@ -121,44 +130,7 @@ private[tree] object TreePoint {
-    /**
-     * Sequential search helper method to find bin for categorical feature in multiclass
-     * classification. The category is returned since each category can belong to multiple
-     * splits. The actual left/right child allocation per split is performed in the
-     * sequential phase of the bin aggregate operation.
-     */
-    def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
-      labeledPoint.features(featureIndex).toInt
-    }
-    /**
-     * Sequential search helper method to find bin for categorical feature
-     * (for classification and regression).
-     */
-    def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
-      val featureCategories = categoricalFeaturesInfo(featureIndex)
-      val featureValue = labeledPoint.features(featureIndex)
-      var binIndex = 0
-      while (binIndex < featureCategories) {
-        val bin = bins(featureIndex)(binIndex)
-        val categories = bin.highSplit.categories
-        if (categories.contains(featureValue)) {
-          return binIndex
-        }
-        binIndex += 1
-      }
-      if (featureValue < 0 || featureValue >= featureCategories) {
-        throw new IllegalArgumentException(
-          s"DecisionTree given invalid data:" +
-            s" Feature $featureIndex is categorical with values in" +
-            s" {0,...,${featureCategories - 1}," +
-            s" but a data point gives it value $featureValue.\n" +
-            "  Bad data point: " + labeledPoint.toString)
-      }
-      -1
-    }
-    if (isFeatureContinuous) {
+    if (featureArity == 0) {
       // Perform binary search for finding bin for continuous features.
       val binIndex = binarySearchForBins()
       if (binIndex == -1) {
@@ -168,18 +140,17 @@ private[tree] object TreePoint {
     } else {
-      // Perform sequential search to find bin for categorical features.
-      val binIndex = if (isUnorderedFeature) {
-          sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
-        } else {
-          sequentialBinSearchForOrderedCategoricalFeature()
-        }
-      if (binIndex == -1) {
-        throw new RuntimeException("No bin was found for categorical feature." +
-          " This error can occur when given invalid data values (such as NaN)." +
-          s" Feature index: $featureIndex.  Feature value: ${labeledPoint.features(featureIndex)}")
+      // Categorical feature bins are indexed by feature values.
+      val featureValue = labeledPoint.features(featureIndex)
+      if (featureValue < 0 || featureValue >= featureArity) {
+        throw new IllegalArgumentException(
+          s"DecisionTree given invalid data:" +
+            s" Feature $featureIndex is categorical with values in" +
+            s" {0,...,${featureArity - 1}," +
+            s" but a data point gives it value $featureValue.\n" +
+            "  Bad data point: " + labeledPoint.toString)
-      binIndex
+      featureValue.toInt
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 96d2471..1c8afc2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -74,3 +74,87 @@ object Entropy extends Impurity {
   def instance = this
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * For entropy, the statistics of one bin are the per-class label counts, so each bin holds
+ * numClasses values.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses  Number of classes for label.
+ */
+private[tree] class EntropyAggregator(numClasses: Int)
+  extends ImpurityAggregator(numClasses) with Serializable {
+  /**
+   * Update stats for one (node, feature, bin) by incrementing the count of the label's class.
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   * @param label     Class label, expected in [0, numClasses); it is converted with `toInt`.
+   * @throws IllegalArgumentException if the label is >= numClasses, negative, or NaN.
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+    if (label >= statsSize) {
+      throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+        s" but requires label < numClasses (= $statsSize).")
+    }
+    // `!(label >= 0)` rejects negative labels and NaN.  Without it, NaN (and values in (-1, 0))
+    // would be counted as class 0 because `toInt` maps them to 0, and a label <= -1 would
+    // increment a slot of allStats belonging to a different bin.
+    if (!(label >= 0)) {
+      throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+        s" but requires label to be non-negative.")
+    }
+    allStats(offset + label.toInt) += 1
+  }
+  /**
+   * Get an [[ImpurityCalculator]] holding a copy of the stats of one (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
+    new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+  }
+}
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin): one label count
+ *               per class.
+ */
+private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+  /**
+   * Calculate the entropy impurity from the stored class counts.
+   */
+  def calculate(): Double = Entropy.calculate(stats, stats.sum)
+  /**
+   * Number of data points accounted for in the sufficient statistics
+   * (the sum of the class counts, truncated to a Long).
+   */
+  def count: Long = stats.sum.toLong
+  /**
+   * Predicted class: the index of the largest class count (the first one on ties),
+   * or 0 if no data points have been counted.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    indexOfLargestArrayElement(stats)
+  }
+  /**
+   * Probability of the given label: its class count divided by [[count]], or 0 if [[count]]
+   * is 0.
+   * @param label  Class label; converted with `toInt`, and must then lie in [0, stats.length).
+   * @throws IllegalArgumentException if the converted label is outside [0, stats.length).
+   */
+  override def prob(label: Double): Double = {
+    val lbl = label.toInt
+    require(lbl >= 0 && lbl < stats.length,
+      s"EntropyCalculator.prob given invalid label: $lbl (should be in [0, ${stats.length}))")
+    val cnt = count
+    if (cnt == 0) {
+      0
+    } else {
+      stats(lbl) / cnt
+    }
+  }
+  override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])"
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index d586f44..5cfdf34 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -70,3 +70,87 @@ object Gini extends Impurity {
   def instance = this
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * For Gini impurity, the statistics of one bin are the per-class label counts, so each bin
+ * holds numClasses values.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses  Number of classes for label.
+ */
+private[tree] class GiniAggregator(numClasses: Int)
+  extends ImpurityAggregator(numClasses) with Serializable {
+  /**
+   * Update stats for one (node, feature, bin) by incrementing the count of the label's class.
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   * @param label     Class label, expected in [0, numClasses); it is converted with `toInt`.
+   * @throws IllegalArgumentException if the label is >= numClasses, negative, or NaN.
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+    if (label >= statsSize) {
+      throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+        s" but requires label < numClasses (= $statsSize).")
+    }
+    // `!(label >= 0)` rejects negative labels and NaN.  Without it, NaN (and values in (-1, 0))
+    // would be counted as class 0 because `toInt` maps them to 0, and a label <= -1 would
+    // increment a slot of allStats belonging to a different bin.
+    if (!(label >= 0)) {
+      throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+        s" but requires label to be non-negative.")
+    }
+    allStats(offset + label.toInt) += 1
+  }
+  /**
+   * Get an [[ImpurityCalculator]] holding a copy of the stats of one (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
+    new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+  }
+}
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[GiniAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin): one label count
+ *               per class.
+ */
+private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: GiniCalculator = new GiniCalculator(stats.clone())
+  /**
+   * Calculate the Gini impurity from the stored class counts.
+   */
+  def calculate(): Double = Gini.calculate(stats, stats.sum)
+  /**
+   * Number of data points accounted for in the sufficient statistics
+   * (the sum of the class counts, truncated to a Long).
+   */
+  def count: Long = stats.sum.toLong
+  /**
+   * Predicted class: the index of the largest class count (the first one on ties),
+   * or 0 if no data points have been counted.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    indexOfLargestArrayElement(stats)
+  }
+  /**
+   * Probability of the given label: its class count divided by [[count]], or 0 if [[count]]
+   * is 0.
+   * @param label  Class label; converted with `toInt`, and must then lie in [0, stats.length).
+   * @throws IllegalArgumentException if the converted label is outside [0, stats.length).
+   */
+  override def prob(label: Double): Double = {
+    val lbl = label.toInt
+    require(lbl >= 0 && lbl < stats.length,
+      s"GiniCalculator.prob given invalid label: $lbl (should be in [0, ${stats.length}))")
+    val cnt = count
+    if (cnt == 0) {
+      0
+    } else {
+      stats(lbl) / cnt
+    }
+  }
+  override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])"
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 92b0c7b..5a047d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
  * :: Experimental ::
  * Trait for calculating information gain.
+ * This trait is used for
+ *  (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]]
+ *  (b) calculating impurity values from sufficient statistics.
 trait Impurity extends Serializable {
@@ -47,3 +50,127 @@ trait Impurity extends Serializable {
   def calculate(count: Double, sum: Double, sumSquares: Double): Double
+/**
+ * Base class for updating views of a flat vector of sufficient statistics, in order to
+ * compute impurity from a sample.
+ * Instances hold no data themselves; every method operates on a caller-supplied array at a
+ * caller-supplied offset.
+ * @param statsSize  Number of Double values making up the sufficient statistics of one bin.
+ */
+private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
+  /**
+   * Merge the stats of one bin into another by element-wise addition.
+   * @param allStats     Flat stats array, with the stats of each bin stored contiguously.
+   * @param offset       Start index of the bin whose stats are modified by the merge.
+   * @param otherOffset  Start index of the bin whose stats are added in; not modified.
+   */
+  def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = {
+    var k = 0
+    while (k < statsSize) {
+      val target = offset + k
+      allStats(target) = allStats(target) + allStats(otherOffset + k)
+      k += 1
+    }
+  }
+  /**
+   * Update the stats of one (node, feature, bin) with the given label.
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   * @param label     Label of the data point being added.
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit
+  /**
+   * Get an [[ImpurityCalculator]] for the stats of one (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
+}
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: ImpurityCalculator
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double
+  /**
+   * Add the stats from another calculator into this one, modifying and returning this calculator.
+   * @throws IllegalArgumentException if the two stats arrays differ in length.
+   */
+  def add(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.size == other.stats.size,
+      s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
+        s"  Sizes are ${stats.size} and ${other.stats.size}.")
+    var i = 0
+    while (i < other.stats.size) {
+      stats(i) += other.stats(i)
+      i += 1
+    }
+    this
+  }
+  /**
+   * Subtract the stats from another calculator from this one, modifying and returning this
+   * calculator.
+   * @throws IllegalArgumentException if the two stats arrays differ in length.
+   */
+  def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.size == other.stats.size,
+      s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
+        s"  Sizes are ${stats.size} and ${other.stats.size}.")
+    var i = 0
+    while (i < other.stats.size) {
+      stats(i) -= other.stats(i)
+      i += 1
+    }
+    this
+  }
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double
+  /**
+   * Probability of the given label, or -1 if this calculator does not provide probabilities.
+   * Subclasses that support probabilities override this.
+   */
+  def prob(label: Double): Double = -1
+  /**
+   * Return the index of the largest array element; on ties, the first such index.
+   * Elements that are not strictly greater than Double.MinValue (NaN, -Infinity,
+   * Double.MinValue) are never selected.
+   * @throws RuntimeException if the array is empty or no element can be selected.
+   */
+  protected def indexOfLargestArrayElement(array: Array[Double]): Int = {
+    // The initial accumulator is passed as an explicit tuple; the previous
+    // `foldLeft(-1, Double.MinValue, 0)` relied on deprecated argument auto-tupling.
+    val (maxIndex, _, _) = array.foldLeft((-1, Double.MinValue, 0)) {
+      case ((bestIndex, bestValue, currentIndex), currentValue) =>
+        if (currentValue > bestValue) {
+          (currentIndex, currentValue, currentIndex + 1)
+        } else {
+          (bestIndex, bestValue, currentIndex + 1)
+        }
+    }
+    if (maxIndex < 0) {
+      throw new RuntimeException("ImpurityCalculator internal error:" +
+        " indexOfLargestArrayElement failed")
+    }
+    maxIndex
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index f7d99a4..e9ccecb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -61,3 +61,75 @@ object Variance extends Impurity {
   def instance = this
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ */
+private[tree] class VarianceAggregator()
+  extends ImpurityAggregator(statsSize = 3) with Serializable {
+  /**
+   * Record one labeled instance in the stats of a single (node, feature, bin).
+   * The three contiguous slots starting at `offset` hold, in order: the instance count,
+   * the sum of labels, and the sum of squared labels.
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   * @param label     Label of the instance being added.
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+    val sumIndex = offset + 1
+    val sumSquaresIndex = offset + 2
+    allStats(offset) += 1
+    allStats(sumIndex) += label
+    allStats(sumSquaresIndex) += label * label
+  }
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * The returned calculator owns a copy of the stats, so later updates to `allStats`
+   * do not affect it.
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
+    // slice copies straight into a new Array[Double], avoiding the generic view round trip.
+    new VarianceCalculator(allStats.slice(offset, offset + statsSize))
+  }
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[VarianceAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+  require(stats.size == 3,
+    s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
+    s" but was given array of length ${stats.size}.")
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]]; the copy gets its own stats array.
+   */
+  def copy: VarianceCalculator =
+    new VarianceCalculator(java.util.Arrays.copyOf(stats, stats.length))
+  /**
+   * Calculate the impurity (variance) from the stored count, sum and sum of squares.
+   */
+  def calculate(): Double = {
+    val cnt = stats(0)
+    val sum = stats(1)
+    val sumSquares = stats(2)
+    Variance.calculate(cnt, sum, sumSquares)
+  }
+  /**
+   * Number of data points accounted for in the sufficient statistics (stored in stats(0)).
+   */
+  def count: Long = stats(0).toLong
+  /**
+   * Prediction which should be made based on the sufficient statistics:
+   * the mean label, or 0 when no data points have been recorded.
+   */
+  def predict: Double = if (count == 0) 0.0 else stats(1) / count
+  /**
+   * Human-readable summary of the stored statistics, labeled with this class's own name
+   * (it previously said "VarianceAggregator", a different class).
+   */
+  override def toString: String = {
+    s"VarianceCalculator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})"
+  }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index af35d88..0cad473 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model
 import org.apache.spark.mllib.tree.configuration.FeatureType._
- * Used for "binning" the features bins for faster best split calculation.
+ * Used for "binning" the feature values for faster best split calculation.
  * For a continuous feature, the bin is determined by a low and a high split,
  *  where an example with featureValue falls into the bin s.t.
@@ -30,13 +30,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
  *  bins, splits, and feature values.  The bin is determined by category/feature value.
  *  However, the bins are not necessarily ordered by feature value;
  *  they are ordered using impurity.
+ *
  * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
  *  where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ *  An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
+ *  partitionings of categories into 2 disjoint, non-empty sets.
  * @param lowSplit signifying the lower threshold for the continuous feature to be
  *                 accepted in the bin
  * @param highSplit signifying the upper threshold for the continuous feature to be
- *                 accepted in the bin
+ *                  accepted in the bin
  * @param featureType type of feature -- categorical or continuous
  * @param category categorical label value accepted in the bin for ordered features
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 0eee626..5b8a4cb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector
  * :: DeveloperApi ::
- * Node in a decision tree
- * @param id integer node id
+ * Node in a decision tree.
+ *
+ * About node indexing:
+ *   Nodes are indexed from 1.  Node 1 is the root; nodes 2, 3 are the left, right children.
+ *   Node index 0 is not used.
+ *
+ * @param id integer node id, from 1
  * @param predict predicted value at the node
 * @param isLeaf whether the node is a leaf
  * @param split split to calculate left and right nodes
@@ -51,17 +56,13 @@ class Node (
    * @param nodes array of nodes
   def build(nodes: Array[Node]): Unit = {
-    logDebug("building node " + id + " at level " +
-      (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+    logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)
     logDebug("stats = " + stats)
     logDebug("predict = " + predict)
     if (!isLeaf) {
-      val leftNodeIndex = id * 2 + 1
-      val rightNodeIndex = id * 2 + 2
-      leftNode = Some(nodes(leftNodeIndex))
-      rightNode = Some(nodes(rightNodeIndex))
+      leftNode = Some(nodes(Node.leftChildIndex(id)))
+      rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -96,24 +97,20 @@ class Node (
    * Get the number of nodes in tree below this node, including leaf nodes.
    * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 2.
-  private[tree] def numDescendants: Int = {
-    if (isLeaf) {
-      0
-    } else {
-      2 + leftNode.get.numDescendants + rightNode.get.numDescendants
-    }
+  private[tree] def numDescendants: Int = if (isLeaf) {
+    0
+  } else {
+    2 + leftNode.get.numDescendants + rightNode.get.numDescendants
    * Get depth of tree from this node.
    * E.g.: Depth 0 means this is a leaf node.
-  private[tree] def subtreeDepth: Int = {
-    if (isLeaf) {
-      0
-    } else {
-      1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
-    }
+  private[tree] def subtreeDepth: Int = if (isLeaf) {
+    0
+  } else {
+    1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
@@ -148,3 +145,49 @@ class Node (
+private[tree] object Node {
+  /**
+   * Return the index of the left child of this node.
+   * With 1-based heap indexing, the left child of node i is 2 * i.
+   */
+  def leftChildIndex(nodeIndex: Int): Int = 2 * nodeIndex
+  /**
+   * Return the index of the right child of this node (one past its left sibling).
+   */
+  def rightChildIndex(nodeIndex: Int): Int = leftChildIndex(nodeIndex) + 1
+  /**
+   * Get the parent index of the given node, or 0 if it is the root.
+   * Both children 2i and 2i + 1 map back to i.
+   */
+  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+  /**
+   * Return the level of a tree which the given node is in (the root, node 1, is level 0).
+   * This is floor(log2(nodeIndex)), i.e. the position of the highest set bit.
+   * @throws IllegalArgumentException if nodeIndex is not positive (indices start from 1).
+   */
+  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex <= 0) {
+    throw new IllegalArgumentException(s"$nodeIndex is not a valid node index.")
+  } else {
+    31 - java.lang.Integer.numberOfLeadingZeros(nodeIndex)
+  }
+  /**
+   * Returns true if this is a left child, i.e. an even index greater than 1.
+   * Note: Returns false for the root.
+   */
+  def isLeftChild(nodeIndex: Int): Boolean = (nodeIndex & 1) == 0 && nodeIndex > 1
+  /**
+   * Return the maximum number of nodes which can be in the given level of the tree.
+   * A full level L of a binary tree holds 2^L nodes.
+   * @param level  Level of tree (0 = root).
+   */
+  def maxNodesInLevel(level: Int): Int = 1 << level
+  /**
+   * Return the index of the first node in the given level.
+   * Nodes are indexed from 1 with children 2i and 2i + 1, so level L starts at index 2^L;
+   * this coincides numerically with maxNodesInLevel(level).
+   * @param level  Level of tree (0 = root).
+   */
+  def startIndexInLevel(level: Int): Int = 1 << level
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2f36fd9..8e556c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -21,15 +21,16 @@ import scala.collection.JavaConverters._
 import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
-import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.regression.LabeledPoint
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -59,12 +60,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
-  test("split and bin calculation") {
+  test("Binary classification with continuous features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
@@ -72,7 +74,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
-  test("split and bin calculation for categorical variables") {
+  test("Binary classification with binary (ordered) categorical features:" +
+    " split and bin calculation") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -83,77 +86,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     assert(splits.length === 2)
     assert(bins.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
-    // Check splits.
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(1.0))
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(1.0))
-    assert(splits(0)(1).categories.contains(0.0))
-    assert(splits(0)(2) === null)
-    assert(splits(1)(0).feature === 1)
-    assert(splits(1)(0).threshold === Double.MinValue)
-    assert(splits(1)(0).featureType === Categorical)
-    assert(splits(1)(0).categories.length === 1)
-    assert(splits(1)(0).categories.contains(0.0))
-    assert(splits(1)(1).feature === 1)
-    assert(splits(1)(1).threshold === Double.MinValue)
-    assert(splits(1)(1).featureType === Categorical)
-    assert(splits(1)(1).categories.length === 2)
-    assert(splits(1)(1).categories.contains(1.0))
-    assert(splits(1)(1).categories.contains(0.0))
-    assert(splits(1)(2) === null)
-    // Check bins.
-    assert(bins(0)(0).category === 1.0)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).category === 0.0)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).lowSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.length === 2)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.contains(0.0))
-    assert(bins(0)(2) === null)
-    assert(bins(1)(0).category === 0.0)
-    assert(bins(1)(0).lowSplit.categories.length === 0)
-    assert(bins(1)(0).highSplit.categories.length === 1)
-    assert(bins(1)(0).highSplit.categories.contains(0.0))
-    assert(bins(1)(1).category === 1.0)
-    assert(bins(1)(1).lowSplit.categories.length === 1)
-    assert(bins(1)(1).lowSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.length === 2)
-    assert(bins(1)(1).highSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.contains(1.0))
-    assert(bins(1)(2) === null)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
-  test("split and bin calculations for categorical variables with no sample for one category") {
+  test("Binary classification with 3-ary (ordered) categorical features," +
+    " with no samples for one category") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -164,104 +110,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    // Check splits.
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(1.0))
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(1.0))
-    assert(splits(0)(1).categories.contains(0.0))
-    assert(splits(0)(2).feature === 0)
-    assert(splits(0)(2).threshold === Double.MinValue)
-    assert(splits(0)(2).featureType === Categorical)
-    assert(splits(0)(2).categories.length === 3)
-    assert(splits(0)(2).categories.contains(1.0))
-    assert(splits(0)(2).categories.contains(0.0))
-    assert(splits(0)(2).categories.contains(2.0))
-    assert(splits(0)(3) === null)
-    assert(splits(1)(0).feature === 1)
-    assert(splits(1)(0).threshold === Double.MinValue)
-    assert(splits(1)(0).featureType === Categorical)
-    assert(splits(1)(0).categories.length === 1)
-    assert(splits(1)(0).categories.contains(0.0))
-    assert(splits(1)(1).feature === 1)
-    assert(splits(1)(1).threshold === Double.MinValue)
-    assert(splits(1)(1).featureType === Categorical)
-    assert(splits(1)(1).categories.length === 2)
-    assert(splits(1)(1).categories.contains(1.0))
-    assert(splits(1)(1).categories.contains(0.0))
-    assert(splits(1)(2).feature === 1)
-    assert(splits(1)(2).threshold === Double.MinValue)
-    assert(splits(1)(2).featureType === Categorical)
-    assert(splits(1)(2).categories.length === 3)
-    assert(splits(1)(2).categories.contains(1.0))
-    assert(splits(1)(2).categories.contains(0.0))
-    assert(splits(1)(2).categories.contains(2.0))
-    assert(splits(1)(3) === null)
-    // Check bins.
-    assert(bins(0)(0).category === 1.0)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).category === 0.0)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).lowSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.length === 2)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.contains(0.0))
-    assert(bins(0)(2).category === 2.0)
-    assert(bins(0)(2).lowSplit.categories.length === 2)
-    assert(bins(0)(2).lowSplit.categories.contains(1.0))
-    assert(bins(0)(2).lowSplit.categories.contains(0.0))
-    assert(bins(0)(2).highSplit.categories.length === 3)
-    assert(bins(0)(2).highSplit.categories.contains(1.0))
-    assert(bins(0)(2).highSplit.categories.contains(0.0))
-    assert(bins(0)(2).highSplit.categories.contains(2.0))
-    assert(bins(0)(3) === null)
-    assert(bins(1)(0).category === 0.0)
-    assert(bins(1)(0).lowSplit.categories.length === 0)
-    assert(bins(1)(0).highSplit.categories.length === 1)
-    assert(bins(1)(0).highSplit.categories.contains(0.0))
-    assert(bins(1)(1).category === 1.0)
-    assert(bins(1)(1).lowSplit.categories.length === 1)
-    assert(bins(1)(1).lowSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.length === 2)
-    assert(bins(1)(1).highSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.contains(1.0))
-    assert(bins(1)(2).category === 2.0)
-    assert(bins(1)(2).lowSplit.categories.length === 2)
-    assert(bins(1)(2).lowSplit.categories.contains(0.0))
-    assert(bins(1)(2).lowSplit.categories.contains(1.0))
-    assert(bins(1)(2).highSplit.categories.length === 3)
-    assert(bins(1)(2).highSplit.categories.contains(0.0))
-    assert(bins(1)(2).highSplit.categories.contains(1.0))
-    assert(bins(1)(2).highSplit.categories.contains(2.0))
-    assert(bins(1)(3) === null)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
   test("extract categories from a number for multiclass classification") {
@@ -270,8 +128,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
-  test("split and bin calculations for unordered categorical variables with multiclass " +
-    "classification") {
+  test("Multiclass classification with unordered categorical features:" +
+      " split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -282,8 +140,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    assert(splits(0).length === 3)
+    assert(bins(0).length === 6)
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
@@ -321,10 +186,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
-    assert(splits(0)(3) === null)
-    assert(splits(1)(3) === null)
     // Check bins.
     assert(bins(0)(0).category === Double.MinValue)
@@ -360,13 +221,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
-    assert(bins(0)(3) === null)
-    assert(bins(1)(3) === null)
-  test("split and bin calculations for ordered categorical variables with multiclass " +
-    "classification") {
+  test("Multiclass classification with ordered categorical features: split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     assert(arr.length === 3000)
     val rdd = sc.parallelize(arr)
@@ -377,52 +234,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
+    // 2^10 - 1 > 100, so categorical features will be ordered
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    // 2^10 - 1 > 100, so categorical variables will be ordered
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(1.0))
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(2.0))
-    assert(splits(0)(2).feature === 0)
-    assert(splits(0)(2).threshold === Double.MinValue)
-    assert(splits(0)(2).featureType === Categorical)
-    assert(splits(0)(2).categories.length === 3)
-    assert(splits(0)(2).categories.contains(2.0))
-    assert(splits(0)(2).categories.contains(1.0))
-    assert(splits(0)(10) === null)
-    assert(splits(1)(10) === null)
-    // Check bins.
-    assert(bins(0)(0).category === 1.0)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).category === 2.0)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).highSplit.categories.length === 2)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.contains(2.0))
-    assert(bins(0)(10) === null)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
-  test("classification stump with all categorical variables") {
+  test("Binary classification stump with ordered categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -433,15 +259,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     val split = bestSplits(0)._1
-    assert(split.categories.length === 1)
-    assert(split.categories.contains(1.0))
+    assert(split.categories === List(1.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
@@ -452,7 +286,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
-  test("regression stump with all categorical variables") {
+  test("Regression stump with 3-ary (ordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -462,10 +296,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     val split = bestSplits(0)._1
@@ -480,7 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
-  test("regression stump with categorical variables of arity 2") {
+  test("Regression stump with binary (ordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -490,6 +328,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val model = DecisionTree.train(rdd, strategy)
     validateRegressor(model, arr, 0.0)
@@ -497,12 +338,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
-  test("stump with fixed label 0 for Gini") {
+  test("Binary classification stump with fixed label 0 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+    val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -512,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -521,12 +366,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.rightImpurity === 0)
-  test("stump with fixed label 1 for Gini") {
+  test("Binary classification stump with fixed label 1 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+    val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -536,7 +385,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -546,12 +395,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
-  test("stump with fixed label 0 for Entropy") {
+  test("Binary classification stump with fixed label 0 for Entropy") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+    val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -561,7 +414,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -571,12 +424,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 0)
-  test("stump with fixed label 1 for Entropy") {
+  test("Binary classification stump with fixed label 1 for Entropy") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+    val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -586,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -596,7 +453,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
-  test("second level node building with/without groups") {
+  test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -613,12 +470,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     // Train a 1-node model
     val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
     val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val nodes: Array[Node] = new Array[Node](7)
-    nodes(0) = modelOneNode.topNode
-    nodes(0).leftNode = None
-    nodes(0).rightNode = None
+    val nodes: Array[Node] = new Array[Node](8)
+    nodes(1) = modelOneNode.topNode
+    nodes(1).leftNode = None
+    nodes(1).rightNode = None
-    val parentImpurities = Array(0.5, 0.5, 0.5)
+    val parentImpurities = Array(0, 0.5, 0.5, 0.5)
     // Single group second level tree construction.
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
@@ -648,16 +505,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
-  test("stump with categorical variables for multiclass classification") {
+  test("Multiclass classification stump with 3-ary (unordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
@@ -668,7 +528,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.featureType === Categorical)
-  test("stump with 1 continuous variable for binary classification, to check off-by-1 error") {
+  test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
     arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
@@ -684,26 +544,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
-  test("stump with 2 continuous variables for binary classification") {
+  test("Binary classification stump with 2 continuous features") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
     arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
     arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
     arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 2)
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
     assert(model.topNode.split.get.feature === 1)
-  test("stump with categorical variables for multiclass classification, with just enough bins") {
-    val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
+  test("Multiclass classification stump with unordered categorical features," +
+    " with just enough bins") {
+    val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -711,6 +572,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
@@ -719,7 +582,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
@@ -733,7 +596,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(gain.rightImpurity === 0)
-  test("stump with continuous variables for multiclass classification") {
+  test("Multiclass classification stump with continuous features") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -746,7 +609,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
@@ -759,20 +622,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
-  test("stump with continuous + categorical variables for multiclass classification") {
+  test("Multiclass classification stump with continuous + unordered categorical features") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
@@ -784,17 +648,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.threshold < 2020)
-  test("stump with categorical variables for ordered multiclass classification") {
+  test("Multiclass classification stump with 10-ary (ordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
@@ -805,6 +671,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.featureType === Categorical)
+  test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+      " with just enough bins") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+      numClassesForClassification = 3, maxBins = 10,
+      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    assert(strategy.isMulticlassClassification)
+    val model = DecisionTree.train(rdd, strategy)
+    validateClassifier(model, arr, 0.6)
+  }
@@ -899,5 +777,4 @@ object DecisionTreeSuite {

To unsubscribe, e-mail:
For additional commands, e-mail:

[2/2] git commit: [SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregation improvements

Posted by
[SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib]  DecisionTree aggregation improvements

1. Variable numBins for each feature [SPARK-3043]
2. Reduced data reshaping in aggregation [SPARK-3043]
3. Choose ordering for ordered categorical features adaptively [SPARK-3156]
4. Changed nodes to use 1-indexing [SPARK-3086]
5. Small clean-ups

Note: This PR looks bigger than it is since I moved several functions from inside findBestSplitsPerGroup to outside of it (to make it clear what was being serialized in the aggregation).

Speedups: This update helps most when many features use few bins but a few features use many bins.  Some example results on speedups with 2M examples, 3.5K features (15-worker EC2 cluster):
* Example where old code was reasonably efficient (1/2 continuous, 1/4 binary, 1/4 20-category): 164.813 --> 116.491 sec
* Example where old code wasted many bins (1/10 continuous, 81/100 binary, 9/100 20-category): 128.701 --> 39.334 sec


(1) Variable numBins for each feature [SPARK-3043]

DecisionTreeMetadata now computes a variable numBins for each feature.  It also tracks numSplits.

(2) Reduced data reshaping in aggregation [SPARK-3043]

Added DTStatsAggregator, a wrapper around the aggregate statistics array for easy but efficient indexing.
* Added ImpurityAggregator and ImpurityCalculator classes, to make DecisionTree code more oblivious to the type of impurity.
* Design note: I originally tried creating Impurity classes which stored data and storing the aggregates in an Array[Array[Array[Impurity]]].  However, this led to significant slowdowns, perhaps because of overhead in creating so many objects.

The aggregate statistics are never reshaped, and cumulative sums are computed in-place.

Updated the layout of aggregation functions.  The update simplifies things by (1) dividing features into ordered/unordered (instead of ordered/unordered/continuous) and (2) making use of the DTStatsAggregator for indexing.
For this update, the following functions were refactored:
* updateBinForOrderedFeature
* updateBinForUnorderedFeature
* binaryOrNotCategoricalBinSeqOp
* multiclassWithCategoricalBinSeqOp
* regressionBinSeqOp
The above 5 functions were replaced with:
* orderedBinSeqOp
* someUnorderedBinSeqOp

Other changes:
* calculateGainForSplit now treats all feature types the same way.
* Eliminated extractLeftRightNodeAggregates.

(3) Choose ordering for ordered categorical features adaptively [SPARK-3156]

Updated binsToBestSplit():
* This now computes cumulative sums of stats for ordered features.
* For ordered categorical features, it chooses an ordering for categories. (This used to be done by findSplitsBins.)
* Uses iterators to shorten code and avoid building an Array[Array[InformationGainStats]].

Side effects:
* In findSplitsBins: the data is sampled only when there are continuous features.  Sampling is not needed for data with only categorical features.
* In findSplitsBins: splits and bins are no longer pre-computed for ordered categorical features since they are not needed.
* TreePoint binning is simpler for categorical features.

(4) Changed nodes to use 1-indexing [SPARK-3086]

Nodes used to be indexed from 0.  Now they are indexed from 1.
Node indexing functions are now collected in object Node (Node.scala).

(5) Small clean-ups

Eliminated functions extractNodeInfo() and extractInfoForLowerLevels() to reduce duplicate code.
Eliminated InvalidBinIndex since it is no longer used.

CC: mengxr  manishamde  Please let me know if you have thoughts on this—thanks!

Author: Joseph K. Bradley <>

Closes #2125 from jkbradley/dt-opt3alt and squashes the following commits:

42c192a [Joseph K. Bradley] Merge branch 'rfs' into dt-opt3alt
d3cc46b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
00e4404 [Joseph K. Bradley] optimization for TreePoint construction (pre-computing featureArity and isUnordered as arrays)
425716c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs
a2acea5 [Joseph K. Bradley] Small optimizations based on profiling
aa4e4df [Joseph K. Bradley] Updated DTStatsAggregator with bug fix (nodeString should not be multiplied by statsSize)
4651154 [Joseph K. Bradley] Changed numBins semantics for unordered features. * Before: numBins = numSplits = (1 << k - 1) - 1 * Now: numBins = 2 * numSplits = 2 * [(1 << k - 1) - 1] * This also involved changing the semantics of: ** DecisionTreeMetadata.numUnorderedBins()
1e3b1c7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
1485fcc [Joseph K. Bradley] Made some DecisionTree methods private.
92f934f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
e676da1 [Joseph K. Bradley] Updated documentation for DecisionTree
37ca845 [Joseph K. Bradley] Fixed problem with how DecisionTree handles ordered categorical	features.
105f8ab [Joseph K. Bradley] Removed commented-out getEmptyBinAggregates from DecisionTree
062c31d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
6d32ccd [Joseph K. Bradley] In DecisionTree.binsToBestSplit, changed loops to iterators to shorten code.
807cd00 [Joseph K. Bradley] Finished DTStatsAggregator, a wrapper around the aggregate statistics for easy but hopefully efficient indexing.  Modified old ImpurityAggregator classes and renamed them ImpurityCalculator; added ImpurityAggregator classes which work with DTStatsAggregator but do not store data.  Unit tests all succeed.
f2166fd [Joseph K. Bradley] still working on DTStatsAggregator
92f7118 [Joseph K. Bradley] Added partly written DTStatsAggregator
fd8df30 [Joseph K. Bradley] Moved some aggregation helpers outside of findBestSplitsPerGroup
d7c53ee [Joseph K. Bradley] Added more doc for ImpurityAggregator
a40f8f1 [Joseph K. Bradley] Changed nodes to be indexed from 1.  Tests work.
95cad7c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
5f94342 [Joseph K. Bradley] Added treeAggregate since not yet merged from master.  Moved node indexing functions to Node.
61c4509 [Joseph K. Bradley] Fixed bugs from merge: missing DT timer call, and numBins setting.  Cleaned up DT Suite some.
3ba7166 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
b314659 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
9c83363 [Joseph K. Bradley] partial merge but not done yet
45f7ea7 [Joseph K. Bradley] partial merge, not yet done
5fce635 [Joseph K. Bradley] Merge branch 'dt-opt2' into dt-opt3
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used.  Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error.  Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up.  Removed debugging println calls from DecisionTree.  Made TreePoint extend Serialiable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
dd4d3aa [Joseph K. Bradley] Mid-process in bug fix: bug for binary classification with categorical features * Bug: Categorical features were all treated as ordered for binary classification.  This is possible but would require the bin ordering to be determined on-the-fly after the aggregation.  Currently, the ordering is determined a priori and fixed for all splits. * (Temp) Fix: Treat low-arity categorical features as unordered for binary classification. * Related change: I removed most tests for isMulticlass in the code.  I instead test metadata for whether there are unordered features. * Status: The bug may be fixed, but more testing needs to be done.
438a660 [Joseph K. Bradley] removed subsampling for mnist8m from DT
86e217f [Joseph K. Bradley] added cache to DT input
e3c84cc [Joseph K. Bradley] Added stuff fro mnist8m to D T Runner
51ef781 [Joseph K. Bradley] Fixed bug introduced by last commit: Variance impurity calculation was incorrect since counts were swapped accidentally
fd65372 [Joseph K. Bradley] Major changes: * Created ImpurityAggregator classes, rather than old aggregates. * Feature split/bin semantics are based on ordered vs. unordered ** E.g.: numSplits = numBins for all unordered features, and numSplits = numBins - 1 for all ordered features. * numBins can differ for each feature
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals


Branch: refs/heads/master
Commit: 711356b422c66e2a80377a9f43fce97282460520
Parents: 0d1cc4a
Author: Joseph K. Bradley <>
Authored: Mon Sep 8 09:47:13 2014 -0700
Committer: Xiangrui Meng <>
Committed: Mon Sep 8 09:47:13 2014 -0700

 .../apache/spark/mllib/tree/DecisionTree.scala  | 1341 ++++++------------
 .../mllib/tree/impl/DTStatsAggregator.scala     |  213 +++
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |   73 +-
 .../spark/mllib/tree/impl/TreePoint.scala       |   93 +-
 .../spark/mllib/tree/impurity/Entropy.scala     |   84 ++
 .../apache/spark/mllib/tree/impurity/Gini.scala |   84 ++
 .../spark/mllib/tree/impurity/Impurity.scala    |  127 ++
 .../spark/mllib/tree/impurity/Variance.scala    |   72 +
 .../org/apache/spark/mllib/tree/model/Bin.scala |    7 +-
 .../apache/spark/mllib/tree/model/Node.scala    |   85 +-
 .../spark/mllib/tree/DecisionTreeSuite.scala    |  391 ++---
 11 files changed, 1322 insertions(+), 1248 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 5cdd258..dd766c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl._
 import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
+import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -65,36 +66,41 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     val retaggedInput = input.retag(classOf[LabeledPoint])
     val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
     logDebug("algo = " + strategy.algo)
+    logDebug("maxBins = " + metadata.maxBins)
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
-    val numBins = bins(0).length
-    logDebug("numBins = " + numBins)
+    logDebug("numBins: feature: number of bins")
+    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+      }.mkString("\n"))
     // Bin feature values (TreePoint representation).
     // Cache input RDD for speedup during multiple passes.
     val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
-    val numFeatures = metadata.numFeatures
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
-    // the max number of nodes possible given the depth of the tree
-    val maxNumNodes = (2 << maxDepth) - 1
+    require(maxDepth <= 30,
+      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+    // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
+    val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
     // Initialize an array to hold parent impurity calculations for each node.
-    val parentImpurities = new Array[Double](maxNumNodes)
+    val parentImpurities = new Array[Double](maxNumNodesPlus1)
     // dummy value for top node (updated during first split calculation)
-    val nodes = new Array[Node](maxNumNodes)
+    val nodes = new Array[Node](maxNumNodesPlus1)
     // Calculate level for single group construction
     // Max memory usage for aggregates
     val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins)
+    // TODO: Calculate memory usage more precisely.
+    val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
     logDebug("numElementsPerNode = " + numElementsPerNode)
     val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -124,26 +130,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
-        metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
+      val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+        DecisionTree.findBestSplits(treeInput, parentImpurities,
+          metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
-      val levelNodeIndexOffset = (1 << level) - 1
+      val levelNodeIndexOffset = Node.startIndexInLevel(level)
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
         val nodeIndex = levelNodeIndexOffset + index
-        val isLeftChild = level != 0 && nodeIndex % 2 == 1
-        val parentNodeIndex = if (isLeftChild) { // -1 for root node
-            (nodeIndex - 1) / 2
-          } else {
-            (nodeIndex - 2) / 2
-          }
         // Extract info for this node (index) at the current level.
-        extractNodeInfo(nodeSplitStats, level, index, nodes)
+        val split = nodeSplitStats._1
+        val stats = nodeSplitStats._2
+        val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
+        val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+        logDebug("Node = " + node)
+        nodes(nodeIndex) = node
         if (level != 0) {
           // Set parent.
-          if (isLeftChild) {
+          val parentNodeIndex = Node.parentIndex(nodeIndex)
+          if (Node.isLeftChild(nodeIndex)) {
             nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
           } else {
             nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
@@ -151,11 +160,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
         // Extract info for nodes at the next lower level.
-        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
+        if (level < maxDepth) {
+          val leftChildIndex = Node.leftChildIndex(nodeIndex)
+          val leftImpurity = stats.leftImpurity
+          logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
+          parentImpurities(leftChildIndex) = leftImpurity
+          val rightChildIndex = Node.rightChildIndex(nodeIndex)
+          val rightImpurity = stats.rightImpurity
+          logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
+          parentImpurities(rightChildIndex) = rightImpurity
+        }
-        logDebug("final best split = " + nodeSplitStats._1)
+        logDebug("final best split = " + split)
-      require((1 << level) == splitsStatsForLevel.length)
+      require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
       // Check whether all the nodes at the current level at leaves.
       val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
       logDebug("all leaf = " + allLeaf)
@@ -171,7 +190,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Initialize the top or root node of the tree.
-    val topNode = nodes(0)
+    val topNode = nodes(1)
     // Build the full tree using the node info calculated in the level-wise best split calculations.
@@ -183,47 +202,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     new DecisionTreeModel(topNode, strategy.algo)
-  /**
-   * Extract the decision tree node information for the given tree level and node index
-   */
-  private def extractNodeInfo(
-      nodeSplitStats: (Split, InformationGainStats),
-      level: Int,
-      index: Int,
-      nodes: Array[Node]): Unit = {
-    val split = nodeSplitStats._1
-    val stats = nodeSplitStats._2
-    val nodeIndex = (1 << level) - 1 + index
-    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-    val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
-    logDebug("Node = " + node)
-    nodes(nodeIndex) = node
-  }
-  /**
-   *  Extract the decision tree node information for the children of the node
-   */
-  private def extractInfoForLowerLevels(
-      level: Int,
-      index: Int,
-      maxDepth: Int,
-      nodeSplitStats: (Split, InformationGainStats),
-      parentImpurities: Array[Double]): Unit = {
-    if (level >= maxDepth) {
-      return
-    }
-    val leftNodeIndex = (2 << level) - 1 + 2 * index
-    val leftImpurity = nodeSplitStats._2.leftImpurity
-    logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
-    parentImpurities(leftNodeIndex) = leftImpurity
-    val rightNodeIndex = leftNodeIndex + 1
-    val rightImpurity = nodeSplitStats._2.rightImpurity
-    logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
-    parentImpurities(rightNodeIndex) = rightImpurity
-  }
 object DecisionTree extends Serializable with Logging {
@@ -425,9 +403,6 @@ object DecisionTree extends Serializable with Logging {
       impurity, maxDepth, maxBins)
-  private val InvalidBinIndex = -1
    * Returns an array of optimal splits for all nodes at a given level. Splits the task into
    * multiple groups if the level-wise training task could lead to memory overflow.
@@ -436,12 +411,12 @@ object DecisionTree extends Serializable with Logging {
    * @param parentImpurities Impurities for all parent nodes for the current level
    * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param splits possible splits for all features
-   * @param bins possible bins for all features
+   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
    * @return array (over nodes) of splits with best split for each node at a given level.
-  protected[tree] def findBestSplits(
+  private[tree] def findBestSplits(
       input: RDD[TreePoint],
       parentImpurities: Array[Double],
       metadata: DecisionTreeMetadata,
@@ -475,14 +450,147 @@ object DecisionTree extends Serializable with Logging {
+   * Get the node index corresponding to this data point.
+   * This function mimics prediction, passing an example from the root node down to a node
+   * at the current level being trained; that node's index is returned.
+   *
+   * @param node  Node in tree from which to classify the given data point.
+   * @param binnedFeatures  Binned feature vector for data point.
+   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+   * @param unorderedFeatures  Set of indices of unordered features.
+   * @return  Leaf index if the data point reaches a leaf.
+   *          Otherwise, last node reachable in tree matching this example.
+   *          Note: This is the global node index, i.e., the index used in the tree.
+   *                This index is different from the index used during training a particular
+   *                set of nodes in a (level, group).
+   */
+  private def predictNodeIndex(
+      node: Node,
+      binnedFeatures: Array[Int],
+      bins: Array[Array[Bin]],
+      unorderedFeatures: Set[Int]): Int = {
+    if (node.isLeaf) {
+    } else {
+      val featureIndex = node.split.get.feature
+      val splitLeft = node.split.get.featureType match {
+        case Continuous => {
+          val binIndex = binnedFeatures(featureIndex)
+          val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+          // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
+          // We do not need to check lowSplit since bins are separated by splits.
+          featureValueUpperBound <= node.split.get.threshold
+        }
+        case Categorical => {
+          val featureValue = binnedFeatures(featureIndex)
+          node.split.get.categories.contains(featureValue)
+        }
+        case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
+      }
+      if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
+        // Return index from next layer of nodes to train
+        if (splitLeft) {
+          Node.leftChildIndex(
+        } else {
+          Node.rightChildIndex(
+        }
+      } else {
+        if (splitLeft) {
+          predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
+        } else {
+          predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
+        }
+      }
+    }
+  }
+  /**
+   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+   *
+   * For ordered features, a single bin is updated.
+   * For unordered features, bins correspond to subsets of categories; either the left or right bin
+   * for each subset is updated.
+   *
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (node, feature, bin).
+   * @param treePoint  Data point being aggregated.
+   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+   * @param unorderedFeatures  Set of indices of unordered features.
+   */
+  private def mixedBinSeqOp(
+      agg: DTStatsAggregator,
+      treePoint: TreePoint,
+      nodeIndex: Int,
+      bins: Array[Array[Bin]],
+      unorderedFeatures: Set[Int]): Unit = {
+    // Iterate over all features.
+    val numFeatures = treePoint.binnedFeatures.size
+    val nodeOffset = agg.getNodeOffset(nodeIndex)
+    var featureIndex = 0
+    while (featureIndex < numFeatures) {
+      if (unorderedFeatures.contains(featureIndex)) {
+        // Unordered feature
+        val featureValue = treePoint.binnedFeatures(featureIndex)
+        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
+          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+        // Update the left or right bin for each split.
+        val numSplits = agg.numSplits(featureIndex)
+        var splitIndex = 0
+        while (splitIndex < numSplits) {
+          if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
+            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label)
+          } else {
+            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label)
+          }
+          splitIndex += 1
+        }
+      } else {
+        // Ordered feature
+        val binIndex = treePoint.binnedFeatures(featureIndex)
+        agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label)
+      }
+      featureIndex += 1
+    }
+  }
+  /**
+   * Helper for binSeqOp, for regression and for classification with only ordered features.
+   *
+   * For each feature, the sufficient statistics of one bin are updated.
+   *
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (node, feature, bin).
+   * @param treePoint  Data point being aggregated.
+   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+   * @return agg
+   */
+  private def orderedBinSeqOp(
+      agg: DTStatsAggregator,
+      treePoint: TreePoint,
+      nodeIndex: Int): Unit = {
+    val label = treePoint.label
+    val nodeOffset = agg.getNodeOffset(nodeIndex)
+    // Iterate over all features.
+    val numFeatures = agg.numFeatures
+    var featureIndex = 0
+    while (featureIndex < numFeatures) {
+      val binIndex = treePoint.binnedFeatures(featureIndex)
+      agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label)
+      featureIndex += 1
+    }
+  }
+  /**
    * Returns an array of optimal splits for a group of nodes at a given level
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
    * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param splits possible splits for all features
-   * @param bins possible bins for all features, indexed as (numFeatures)(numBins)
+   * @param nodes Array of all nodes in the tree.  Used for matching data points to nodes.
+   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param numGroups total number of node groups at the current level. Default value is set to 1.
    * @param groupIndex index of the node group being processed. Default value is set to 0.
    * @return array of splits with best splits for all nodes at a given level.
@@ -527,88 +635,22 @@ object DecisionTree extends Serializable with Logging {
     // numNodes:  Number of nodes in this (level of tree, group),
     //            where nodes at deeper (larger) levels may be divided into groups.
-    val numNodes = (1 << level) / numGroups
+    val numNodes = Node.maxNodesInLevel(level) / numGroups
     logDebug("numNodes = " + numNodes)
-    // Find the number of features by looking at the first sample.
-    val numFeatures = metadata.numFeatures
-    logDebug("numFeatures = " + numFeatures)
-    // numBins:  Number of bins = 1 + number of possible splits
-    val numBins = bins(0).length
-    logDebug("numBins = " + numBins)
-    val numClasses = metadata.numClasses
-    logDebug("numClasses = " + numClasses)
-    val isMulticlass = metadata.isMulticlass
-    logDebug("isMulticlass = " + isMulticlass)
-    val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures
-    logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
+    logDebug("numFeatures = " + metadata.numFeatures)
+    logDebug("numClasses = " + metadata.numClasses)
+    logDebug("isMulticlass = " + metadata.isMulticlass)
+    logDebug("isMulticlassWithCategoricalFeatures = " +
+      metadata.isMulticlassWithCategoricalFeatures)
     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
-    /**
-     * Get the node index corresponding to this data point.
-     * This function mimics prediction, passing an example from the root node down to a node
-     * at the current level being trained; that node's index is returned.
-     *
-     * @return  Leaf index if the data point reaches a leaf.
-     *          Otherwise, last node reachable in tree matching this example.
-     */
-    def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
-      if (node.isLeaf) {
-      } else {
-        val featureIndex = node.split.get.feature
-        val splitLeft = node.split.get.featureType match {
-          case Continuous => {
-            val binIndex = binnedFeatures(featureIndex)
-            val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
-            // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
-            // We do not need to check lowSplit since bins are separated by splits.
-            featureValueUpperBound <= node.split.get.threshold
-          }
-          case Categorical => {
-            val featureValue = if (metadata.isUnordered(featureIndex)) {
-                binnedFeatures(featureIndex)
-              } else {
-                val binIndex = binnedFeatures(featureIndex)
-                bins(featureIndex)(binIndex).category
-              }
-            node.split.get.categories.contains(featureValue)
-          }
-          case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
-        }
-        if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
-          // Return index from next layer of nodes to train
-          if (splitLeft) {
-   * 2 + 1 // left
-          } else {
-   * 2 + 2 // right
-          }
-        } else {
-          if (splitLeft) {
-            predictNodeIndex(node.leftNode.get, binnedFeatures)
-          } else {
-            predictNodeIndex(node.rightNode.get, binnedFeatures)
-          }
-        }
-      }
-    }
-    def nodeIndexToLevel(idx: Int): Int = {
-      if (idx == 0) {
-        0
-      } else {
-        math.floor(math.log(idx) / math.log(2)).toInt
-      }
-    }
-    // Used for treePointToNodeIndex
-    val levelOffset = (1 << level) - 1
+    // Used for treePointToNodeIndex to get an index for this (level, group).
+    // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level.
+    // - groupShift corrects for groups in this level before the current group.
+    val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift
      * Find the node index for the given example.
@@ -619,661 +661,254 @@ object DecisionTree extends Serializable with Logging {
       if (level == 0) {
       } else {
-        val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures)
-        // Get index for this (level, group).
-        globalNodeIndex - levelOffset - groupShift
-      }
-    }
-    /**
-     * Increment aggregate in location for (node, feature, bin, label).
-     *
-     * @param treePoint  Data point being aggregated.
-     * @param agg  Array storing aggregate calculation, of size:
-     *             numClasses * numBins * numFeatures * numNodes.
-     *             Indexed by (node, feature, bin, label) where label is the least significant bit.
-     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-     */
-    def updateBinForOrderedFeature(
-        treePoint: TreePoint,
-        agg: Array[Double],
-        nodeIndex: Int,
-        featureIndex: Int): Unit = {
-      // Update the left or right count for one bin.
-      val aggIndex =
-        numClasses * numBins * numFeatures * nodeIndex +
-        numClasses * numBins * featureIndex +
-        numClasses * treePoint.binnedFeatures(featureIndex) +
-        treePoint.label.toInt
-      agg(aggIndex) += 1
-    }
-    /**
-     * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label),
-     * where [bins] ranges over all bins.
-     * Updates left or right side of aggregate depending on split.
-     *
-     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-     * @param treePoint  Data point being aggregated.
-     * @param agg  Indexed by (left/right, node, feature, bin, label)
-     *             where label is the least significant bit.
-     *             The left/right specifier is a 0/1 index indicating left/right child info.
-     * @param rightChildShift Offset for right side of agg.
-     */
-    def updateBinForUnorderedFeature(
-        nodeIndex: Int,
-        featureIndex: Int,
-        treePoint: TreePoint,
-        agg: Array[Double],
-        rightChildShift: Int): Unit = {
-      val featureValue = treePoint.binnedFeatures(featureIndex)
-      // Update the left or right count for one bin.
-      val aggShift =
-        numClasses * numBins * numFeatures * nodeIndex +
-        numClasses * numBins * featureIndex +
-        treePoint.label.toInt
-      // Find all matching bins and increment their values
-      val featureCategories = metadata.featureArity(featureIndex)
-      val numCategoricalBins = (1 << featureCategories - 1) - 1
-      var binIndex = 0
-      while (binIndex < numCategoricalBins) {
-        val aggIndex = aggShift + binIndex * numClasses
-        if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
-          agg(aggIndex) += 1
-        } else {
-          agg(rightChildShift + aggIndex) += 1
-        }
-        binIndex += 1
-      }
-    }
-    /**
-     * Helper for binSeqOp.
-     *
-     * @param agg  Array storing aggregate calculation, of size:
-     *             numClasses * numBins * numFeatures * numNodes.
-     *             Indexed by (node, feature, bin, label) where label is the least significant bit.
-     * @param treePoint  Data point being aggregated.
-     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-     */
-    def binaryOrNotCategoricalBinSeqOp(
-        agg: Array[Double],
-        treePoint: TreePoint,
-        nodeIndex: Int): Unit = {
-      // Iterate over all features.
-      var featureIndex = 0
-      while (featureIndex < numFeatures) {
-        updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
-        featureIndex += 1
-      }
-    }
-    val rightChildShift = numClasses * numBins * numFeatures * numNodes
-    /**
-     * Helper for binSeqOp.
-     *
-     * @param agg  Array storing aggregate calculation.
-     *             For ordered features, this is of size:
-     *               numClasses * numBins * numFeatures * numNodes.
-     *             For unordered features, this is of size:
-     *               2 * numClasses * numBins * numFeatures * numNodes.
-     * @param treePoint   Data point being aggregated.
-     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-     */
-    def multiclassWithCategoricalBinSeqOp(
-        agg: Array[Double],
-        treePoint: TreePoint,
-        nodeIndex: Int): Unit = {
-      val label = treePoint.label
-      // Iterate over all features.
-      var featureIndex = 0
-      while (featureIndex < numFeatures) {
-        if (metadata.isUnordered(featureIndex)) {
-          updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift)
-        } else {
-          updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
-        }
-        featureIndex += 1
-      }
-    }
-    /**
-     * Performs a sequential aggregation over a partition for regression.
-     * For l nodes, k features,
-     * the count, sum, sum of squares of one of the p bins is incremented.
-     *
-     * @param agg Array storing aggregate calculation, updated by this function.
-     *            Size: 3 * numBins * numFeatures * numNodes
-     * @param treePoint   Data point being aggregated.
-     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-     * @return agg
-     */
-    def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
-      val label = treePoint.label
-      // Iterate over all features.
-      var featureIndex = 0
-      while (featureIndex < numFeatures) {
-        // Update count, sum, and sum^2 for one bin.
-        val binIndex = treePoint.binnedFeatures(featureIndex)
-        val aggIndex =
-          3 * numBins * numFeatures * nodeIndex +
-          3 * numBins * featureIndex +
-          3 * binIndex
-        agg(aggIndex) += 1
-        agg(aggIndex + 1) += label
-        agg(aggIndex + 2) += label * label
-        featureIndex += 1
+        val globalNodeIndex =
+          predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
+        globalNodeIndex - globalNodeIndexOffset
      * Performs a sequential aggregation over a partition.
-     * For l nodes, k features,
-     *   For classification:
-     *     Either the left count or the right count of one of the bins is
-     *     incremented based upon whether the feature is classified as 0 or 1.
-     *   For regression:
-     *     The count, sum, sum of squares of one of the bins is incremented.
-     * @param agg Array storing aggregate calculation, updated by this function.
-     *            Size for classification:
-     *              numClasses * numBins * numFeatures * numNodes for ordered features, or
-     *              2 * numClasses * numBins * numFeatures * numNodes for unordered features.
-     *            Size for regression:
-     *              3 * numBins * numFeatures * numNodes.
+     * Each data point contributes to one node. For each feature,
+     * the aggregate sufficient statistics are updated for the relevant bins.
+     *
+     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+     *             each (node, feature, bin).
      * @param treePoint   Data point being aggregated.
      * @return  agg
-    def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+    def binSeqOp(
+        agg: DTStatsAggregator,
+        treePoint: TreePoint): DTStatsAggregator = {
       val nodeIndex = treePointToNodeIndex(treePoint)
       // If the example does not reach this level, then nodeIndex < 0.
       // If the example reaches this level but is handled in a different group,
       //  then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
       if (nodeIndex >= 0 && nodeIndex < numNodes) {
-        if (metadata.isClassification) {
-          if (isMulticlassWithCategoricalFeatures) {
-            multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
-          } else {
-            binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
-          }
+        if (metadata.unorderedFeatures.isEmpty) {
+          orderedBinSeqOp(agg, treePoint, nodeIndex)
         } else {
-          regressionBinSeqOp(agg, treePoint, nodeIndex)
+          mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures)
-    // Calculate bin aggregate length for classification or regression.
-    val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins)
-    logDebug("binAggregateLength = " + binAggregateLength)
-    /**
-     * Combines the aggregates from partitions.
-     * @param agg1 Array containing aggregates from one or more partitions
-     * @param agg2 Array containing aggregates from one or more partitions
-     * @return Combined aggregate from agg1 and agg2
-     */
-    def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
-      var index = 0
-      val combinedAggregate = new Array[Double](binAggregateLength)
-      while (index < binAggregateLength) {
-        combinedAggregate(index) = agg1(index) + agg2(index)
-        index += 1
-      }
-      combinedAggregate
-    }
     // Calculate bin aggregates.
-    val binAggregates = {
-      input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
+    val binAggregates: DTStatsAggregator = {
+      val initAgg = new DTStatsAggregator(metadata, numNodes)
+      input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
-    logDebug("binAggregates.length = " + binAggregates.length)
-    /**
-     * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
-     * @param leftNodeAgg left node aggregates for this (feature, split)
-     * @param rightNodeAgg right node aggregate for this (feature, split)
-     * @param topImpurity impurity of the parent node
-     * @return information gain and statistics for all splits
-     */
-    def calculateGainForSplit(
-        leftNodeAgg: Array[Double],
-        rightNodeAgg: Array[Double],
-        topImpurity: Double): InformationGainStats = {
-      if (metadata.isClassification) {
-        val leftTotalCount = leftNodeAgg.sum
-        val rightTotalCount = rightNodeAgg.sum
-        val impurity = {
-          if (level > 0) {
-            topImpurity
-          } else {
-            // Calculate impurity for root node.
-            val rootNodeCounts = new Array[Double](numClasses)
-            var classIndex = 0
-            while (classIndex < numClasses) {
-              rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
-              classIndex += 1
-            }
-            metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
-          }
-        }
-        val totalCount = leftTotalCount + rightTotalCount
-        if (totalCount == 0) {
-          // Return arbitrary prediction.
-          return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
-        }
-        // Sum of count for each label
-        val leftrightNodeAgg: Array[Double] =
- { case (leftCount, rightCount) =>
-            leftCount + rightCount
-          }
-        def indexOfLargestArrayElement(array: Array[Double]): Int = {
-          val result = array.foldLeft(-1, Double.MinValue, 0) {
-            case ((maxIndex, maxValue, currentIndex), currentValue) =>
-              if (currentValue > maxValue) {
-                (currentIndex, currentValue, currentIndex + 1)
-              } else {
-                (maxIndex, maxValue, currentIndex + 1)
-              }
-          }
-          if (result._1 < 0) {
-            throw new RuntimeException("DecisionTree internal error:" +
-              " calculateGainForSplit failed in indexOfLargestArrayElement")
-          }
-          result._1
-        }
-        val predict = indexOfLargestArrayElement(leftrightNodeAgg)
-        val prob = leftrightNodeAgg(predict) / totalCount
-        val leftImpurity = if (leftTotalCount == 0) {
-          topImpurity
-        } else {
-          metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
-        }
-        val rightImpurity = if (rightTotalCount == 0) {
-          topImpurity
-        } else {
-          metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
-        }
-        val leftWeight = leftTotalCount / totalCount
-        val rightWeight = rightTotalCount / totalCount
-        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
-      } else {
-        // Regression
-        val leftCount = leftNodeAgg(0)
-        val leftSum = leftNodeAgg(1)
-        val leftSumSquares = leftNodeAgg(2)
+    // Calculate best splits for all nodes at a given level
+    timer.start("chooseSplits")
+    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+    // Iterating over all nodes at this level
+    var nodeIndex = 0
+    while (nodeIndex < numNodes) {
+      val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
+      logDebug("node impurity = " + nodeImpurity)
+      bestSplits(nodeIndex) =
+        binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
+      logDebug("best split = " + bestSplits(nodeIndex)._1)
+      nodeIndex += 1
+    }
+    timer.stop("chooseSplits")
-        val rightCount = rightNodeAgg(0)
-        val rightSum = rightNodeAgg(1)
-        val rightSumSquares = rightNodeAgg(2)
+    bestSplits
+  }
-        val impurity = {
-          if (level > 0) {
-            topImpurity
-          } else {
-            // Calculate impurity for root node.
-            val count = leftCount + rightCount
-            val sum = leftSum + rightSum
-            val sumSquares = leftSumSquares + rightSumSquares
-            metadata.impurity.calculate(count, sum, sumSquares)
-          }
-        }
+  /**
+   * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+   * @param leftImpurityCalculator left node aggregates for this (feature, split)
+   * @param rightImpurityCalculator right node aggregate for this (feature, split)
+   * @param topImpurity impurity of the parent node
+   * @return information gain and statistics for all splits
+   */
+  private def calculateGainForSplit(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator,
+      topImpurity: Double,
+      level: Int,
+      metadata: DecisionTreeMetadata): InformationGainStats = {
-        if (leftCount == 0) {
-          return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
-            rightSum / rightCount)
-        }
-        if (rightCount == 0) {
-          return new InformationGainStats(0, topImpurity, topImpurity,
-            Double.MinValue, leftSum / leftCount)
-        }
+    val leftCount = leftImpurityCalculator.count
+    val rightCount = rightImpurityCalculator.count
-        val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
-        val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
+    val totalCount = leftCount + rightCount
+    if (totalCount == 0) {
+      // Return arbitrary prediction.
+      return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+    }
-        val leftWeight = leftCount.toDouble / (leftCount + rightCount)
-        val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+    // impurity of parent node
+    val impurity = if (level > 0) {
+      topImpurity
+    } else {
+      parentNodeAgg.calculate()
+    }
-        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+    val predict = parentNodeAgg.predict
+    val prob = parentNodeAgg.prob(predict)
-        val predict = (leftSum + rightSum) / (leftCount + rightCount)
-        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
-      }
-    }
+    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+    val rightImpurity = rightImpurityCalculator.calculate()
-    /**
-     * Extracts left and right split aggregates.
-     * @param binData Aggregate array slice from getBinDataForNode.
-     *                For classification:
-     *                  For unordered features, this is leftChildData ++ rightChildData,
-     *                    each of which is indexed by (feature, split/bin, class),
-     *                    with class being the least significant bit.
-     *                  For ordered features, this is of size numClasses * numBins * numFeatures.
-     *                For regression:
-     *                  This is of size 2 * numFeatures * numBins.
-     * @return (leftNodeAgg, rightNodeAgg) pair of arrays.
-     *         For classification, each array is of size (numFeatures, (numBins - 1), numClasses).
-     *         For regression, each array is of size (numFeatures, (numBins - 1), 3).
-     *
-     */
-    def extractLeftRightNodeAggregates(
-        binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
-      /**
-       * The input binData is indexed as (feature, bin, class).
-       * This computes cumulative sums over splits.
-       * Each (feature, class) pair is handled separately.
-       * Note: numSplits = numBins - 1.
-       * @param leftNodeAgg  Each (feature, class) slice is an array over splits.
-       *                     Element i (i = 0, ..., numSplits - 2) is set to be
-       *                     the cumulative sum (from left) over binData for bins 0, ..., i.
-       * @param rightNodeAgg Each (feature, class) slice is an array over splits.
-       *                     Element i (i = 1, ..., numSplits - 1) is set to be
-       *                     the cumulative sum (from right) over binData for bins
-       *                     numBins - 1, ..., numBins - 1 - i.
-       */
-      def findAggForOrderedFeatureClassification(
-          leftNodeAgg: Array[Array[Array[Double]]],
-          rightNodeAgg: Array[Array[Array[Double]]],
-          featureIndex: Int) {
-        // shift for this featureIndex
-        val shift = numClasses * featureIndex * numBins
-        var classIndex = 0
-        while (classIndex < numClasses) {
-          // left node aggregate for the lowest split
-          leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
-          // right node aggregate for the highest split
-          rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
-            = binData(shift + (numClasses * (numBins - 1)) + classIndex)
-          classIndex += 1
-        }
+    val leftWeight = leftCount / totalCount.toDouble
+    val rightWeight = rightCount / totalCount.toDouble
-        // Iterate over all splits.
-        var splitIndex = 1
-        while (splitIndex < numBins - 1) {
-          // calculating left node aggregate for a split as a sum of left node aggregate of a
-          // lower split and the left bin aggregate of a bin where the split is a high split
-          var innerClassIndex = 0
-          while (innerClassIndex < numClasses) {
-            leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
-              = binData(shift + numClasses * splitIndex + innerClassIndex) +
-                leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
-            rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
-              binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
-                rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
-            innerClassIndex += 1
-          }
-          splitIndex += 1
-        }
-      }
+    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-      /**
-       * Reshape binData for this feature.
-       * Indexes binData as (feature, split, class) with class as the least significant bit.
-       * @param leftNodeAgg   leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value
-       */
-      def findAggForUnorderedFeatureClassification(
-          leftNodeAgg: Array[Array[Array[Double]]],
-          rightNodeAgg: Array[Array[Array[Double]]],
-          featureIndex: Int) {
-        val rightChildShift = numClasses * numBins * numFeatures
-        var splitIndex = 0
-        while (splitIndex < numBins - 1) {
-          var classIndex = 0
-          while (classIndex < numClasses) {
-            // shift for this featureIndex
-            val shift = numClasses * featureIndex * numBins + splitIndex * numClasses
-            val leftBinValue = binData(shift + classIndex)
-            val rightBinValue = binData(rightChildShift + shift + classIndex)
-            leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
-            rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
-            classIndex += 1
-          }
-          splitIndex += 1
-        }
-      }
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+  }
-      def findAggForRegression(
-          leftNodeAgg: Array[Array[Array[Double]]],
-          rightNodeAgg: Array[Array[Array[Double]]],
-          featureIndex: Int) {
-        // shift for this featureIndex
-        val shift = 3 * featureIndex * numBins
-        // left node aggregate for the lowest split
-        leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
-        leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
-        leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
-        // right node aggregate for the highest split
-        rightNodeAgg(featureIndex)(numBins - 2)(0) =
-          binData(shift + (3 * (numBins - 1)))
-        rightNodeAgg(featureIndex)(numBins - 2)(1) =
-          binData(shift + (3 * (numBins - 1)) + 1)
-        rightNodeAgg(featureIndex)(numBins - 2)(2) =
-          binData(shift + (3 * (numBins - 1)) + 2)
-        // Iterate over all splits.
-        var splitIndex = 1
-        while (splitIndex < numBins - 1) {
-          var i = 0 // index for regression histograms
-          while (i < 3) { // count, sum, sum^2
-            // calculating left node aggregate for a split as a sum of left node aggregate of a
-            // lower split and the left bin aggregate of a bin where the split is a high split
-            leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
-              leftNodeAgg(featureIndex)(splitIndex - 1)(i)
-            // calculating right node aggregate for a split as a sum of right node aggregate of a
-            // higher split and the right bin aggregate of a bin where the split is a low split
-            rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
-              binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
-                rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
-            i += 1
-          }
-          splitIndex += 1
-        }
-      }
+  /**
+   * Find the best split for a node.
+   * @param binAggregates Bin statistics.
+   * @param nodeIndex Index for node to split in this (level, group).
+   * @param nodeImpurity Impurity of the node (nodeIndex).
+   * @return tuple for best split: (Split, information gain)
+   */
+  private def binsToBestSplit(
+      binAggregates: DTStatsAggregator,
+      nodeIndex: Int,
+      nodeImpurity: Double,
+      level: Int,
+      metadata: DecisionTreeMetadata,
+      splits: Array[Array[Split]]): (Split, InformationGainStats) = {
-      if (metadata.isClassification) {
-        // Initialize left and right split aggregates.
-        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
-        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
-        var featureIndex = 0
-        while (featureIndex < numFeatures) {
-          if (metadata.isUnordered(featureIndex)) {
-            findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-          } else {
-            findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-          }
-          featureIndex += 1
-        }
-        (leftNodeAgg, rightNodeAgg)
-      } else {
-        // Regression
-        // Initialize left and right split aggregates.
-        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
-        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
-        // Iterate over all features.
-        var featureIndex = 0
-        while (featureIndex < numFeatures) {
-          findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
-          featureIndex += 1
-        }
-        (leftNodeAgg, rightNodeAgg)
-      }
-    }
+    logDebug("node impurity = " + nodeImpurity)
-    /**
-     * Calculates information gain for all nodes splits.
-     */
-    def calculateGainsForAllNodeSplits(
-        leftNodeAgg: Array[Array[Array[Double]]],
-        rightNodeAgg: Array[Array[Array[Double]]],
-        nodeImpurity: Double): Array[Array[InformationGainStats]] = {
-      val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
-      var featureIndex = 0
-      while (featureIndex < numFeatures) {
-        val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+    // For each (feature, split), calculate the gain, and select the best (feature, split).
+    Range(0, metadata.numFeatures).map { featureIndex =>
+      val numSplits = metadata.numSplits(featureIndex)
+      if (metadata.isContinuous(featureIndex)) {
+        // Cumulative sum (scanLeft) of bin statistics.
+        // Afterwards, binAggregates for a bin is the sum of aggregates for
+        // that bin + all preceding bins.
+        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
         var splitIndex = 0
-        while (splitIndex < numSplitsForFeature) {
-          gains(featureIndex)(splitIndex) =
-            calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex),
-              rightNodeAgg(featureIndex)(splitIndex), nodeImpurity)
+        while (splitIndex < numSplits) {
+          binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
           splitIndex += 1
-        featureIndex += 1
-      }
-      gains
-    }
-    /**
-     * Get the number of splits for a feature.
-     */
-    def getNumSplitsForFeature(featureIndex: Int): Int = {
-      if (metadata.isContinuous(featureIndex)) {
-        numBins - 1
+        // Find best split.
+        val (bestFeatureSplitIndex, bestFeatureGainStats) =
+          Range(0, numSplits).map { case splitIdx =>
+            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+            rightChildStats.subtract(leftChildStats)
+            val gainStats =
+              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            (splitIdx, gainStats)
+          }.maxBy(_._2.gain)
+        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+      } else if (metadata.isUnordered(featureIndex)) {
+        // Unordered categorical feature
+        val (leftChildOffset, rightChildOffset) =
+          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+        val (bestFeatureSplitIndex, bestFeatureGainStats) =
+          Range(0, numSplits).map { splitIndex =>
+            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+            val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+            val gainStats =
+              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            (splitIndex, gainStats)
+          }.maxBy(_._2.gain)
+        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else {
-        // Categorical feature
-        val featureCategories = metadata.featureArity(featureIndex)
-        if (metadata.isUnordered(featureIndex)) {
-          (1 << featureCategories - 1) - 1
-        } else {
-          featureCategories
-        }
-      }
-    }
-    /**
-     * Find the best split for a node.
-     * @param binData Bin data slice for this node, given by getBinDataForNode.
-     * @param nodeImpurity impurity of the top node
-     * @return tuple of split and information gain
-     */
-    def binsToBestSplit(
-        binData: Array[Double],
-        nodeImpurity: Double): (Split, InformationGainStats) = {
-      logDebug("node impurity = " + nodeImpurity)
-      // Extract left right node aggregates.
-      val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
-      // Calculate gains for all splits.
-      val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
-      val (bestFeatureIndex, bestSplitIndex, gainStats) = {
-        // Initialize with infeasible values.
-        var bestFeatureIndex = Int.MinValue
-        var bestSplitIndex = Int.MinValue
-        var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
-        // Iterate over features.
-        var featureIndex = 0
-        while (featureIndex < numFeatures) {
-          // Iterate over all splits.
-          var splitIndex = 0
-          val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
-          while (splitIndex < numSplitsForFeature) {
-            val gainStats = gains(featureIndex)(splitIndex)
-            if (gainStats.gain > bestGainStats.gain) {
-              bestGainStats = gainStats
-              bestFeatureIndex = featureIndex
-              bestSplitIndex = splitIndex
+        // Ordered categorical feature
+        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+        val numBins = metadata.numBins(featureIndex)
+        /* Each bin is one category (feature value).
+         * The bins are ordered based on centroidForCategories, and this ordering determines which
+         * splits are considered.  (With K categories, we consider K - 1 possible splits.)
+         *
+         * centroidForCategories is a list: (category, centroid)
+         */
+        val centroidForCategories = if (metadata.isMulticlass) {
+          // For categorical variables in multiclass classification,
+          // the bins are ordered by the impurity of their corresponding labels.
+          Range(0, numBins).map { case featureValue =>
+            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val centroid = if (categoryStats.count != 0) {
+              categoryStats.calculate()
+            } else {
+              Double.MaxValue
-            splitIndex += 1
+            (featureValue, centroid)
+          }
+        } else { // regression or binary classification
+          // For categorical variables in regression and binary classification,
+          // the bins are ordered by the centroid of their corresponding labels.
+          Range(0, numBins).map { case featureValue =>
+            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val centroid = if (categoryStats.count != 0) {
+              categoryStats.predict
+            } else {
+              Double.MaxValue
+            }
+            (featureValue, centroid)
-          featureIndex += 1
-        (bestFeatureIndex, bestSplitIndex, bestGainStats)
-      }
-      logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex))
-      logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+        logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
-      (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
-    }
+        // bins sorted by centroids
+        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
-    /**
-     * Get bin data for one node.
-     */
-    def getBinDataForNode(node: Int): Array[Double] = {
-      if (metadata.isClassification) {
-        if (isMulticlassWithCategoricalFeatures) {
-          val shift = numClasses * node * numBins * numFeatures
-          val rightChildShift = numClasses * numBins * numFeatures * numNodes
-          val binsForNode = {
-            val leftChildData
-            = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
-            val rightChildData
-            = binAggregates.slice(rightChildShift + shift,
-              rightChildShift + shift + numClasses * numBins * numFeatures)
-            leftChildData ++ rightChildData
-          }
-          binsForNode
-        } else {
-          val shift = numClasses * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
-          binsForNode
+        logDebug("Sorted centroids for categorical variable = " +
+          categoriesSortedByCentroid.mkString(","))
+        // Cumulative sum (scanLeft) of bin statistics.
+        // Afterwards, binAggregates for a bin is the sum of aggregates for
+        // that bin + all preceding bins.
+        var splitIndex = 0
+        while (splitIndex < numSplits) {
+          val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+          val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+          binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+          splitIndex += 1
-      } else {
-        // Regression
-        val shift = 3 * node * numBins * numFeatures
-        val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
-        binsForNode
+        // lastCategory = index of bin with total aggregates for this (node, feature)
+        val lastCategory = categoriesSortedByCentroid.last._1
+        // Find best split.
+        val (bestFeatureSplitIndex, bestFeatureGainStats) =
+          Range(0, numSplits).map { splitIndex =>
+            val featureValue = categoriesSortedByCentroid(splitIndex)._1
+            val leftChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val rightChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+            rightChildStats.subtract(leftChildStats)
+            val gainStats =
+              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            (splitIndex, gainStats)
+          }.maxBy(_._2.gain)
+        val categoriesForSplit =
+, bestFeatureSplitIndex + 1)
+        val bestFeatureSplit =
+          new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+        (bestFeatureSplit, bestFeatureGainStats)
-    }
-    // Calculate best splits for all nodes at a given level
-    timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
-    // Iterating over all nodes at this level
-    var node = 0
-    while (node < numNodes) {
-      val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
-      val binsForNode: Array[Double] = getBinDataForNode(node)
-      logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
-      val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
-      logDebug("parent node impurity = " + parentNodeImpurity)
-      bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
-      node += 1
-    }
-    timer.stop("chooseSplits")
-    bestSplits
+    }.maxBy(_._2.gain)
    * Get the number of values to be stored per node in the bin aggregates.
-   *
-   * @param numBins  Number of bins = 1 + number of possible splits.
-  private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
+  private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
+    val totalBins = metadata.numBins.sum
     if (metadata.isClassification) {
-      if (metadata.isMulticlassWithCategoricalFeatures) {
-        2 * metadata.numClasses * numBins * metadata.numFeatures
-      } else {
-        metadata.numClasses * numBins * metadata.numFeatures
-      }
+      metadata.numClasses * totalBins
     } else {
-      3 * numBins * metadata.numFeatures
+      3 * totalBins
@@ -1284,6 +919,7 @@ object DecisionTree extends Serializable with Logging {
    * Continuous features:
    *   For each feature, there are numBins - 1 possible splits representing the possible binary
    *   decisions at each node in the tree.
+   *   This finds locations (feature values) for splits using a subsample of the data.
    * Categorical features:
    *   For each feature, there is 1 bin per split.
@@ -1292,7 +928,6 @@ object DecisionTree extends Serializable with Logging {
    *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are (1 << maxFeatureValue - 1) - 1 splits.
    *   (b) "ordered features"
    *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
@@ -1302,7 +937,7 @@ object DecisionTree extends Serializable with Logging {
    * @param metadata Learning and dataset metadata
    * @return A tuple of (splits, bins).
    *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
-   *          of size (numFeatures, numBins - 1).
+   *          of size (numFeatures, numSplits).
    *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
    *          of size (numFeatures, numBins).
@@ -1310,84 +945,80 @@ object DecisionTree extends Serializable with Logging {
       input: RDD[LabeledPoint],
       metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
-    val count = input.count()
+    logDebug("isMulticlass = " + metadata.isMulticlass)
-    // Find the number of features by looking at the first sample
-    val numFeatures = input.take(1)(0).features.size
-    val maxBins = metadata.maxBins
-    val numBins = if (maxBins <= count) maxBins else count.toInt
-    logDebug("numBins = " + numBins)
-    val isMulticlass = metadata.isMulticlass
-    logDebug("isMulticlass = " + isMulticlass)
-    /*
-     * Ensure numBins is always greater than the categories. For multiclass classification,
-     * numBins should be greater than 2^(maxCategories - 1) - 1.
-     * It's a limitation of the current implementation but a reasonable trade-off since features
-     * with large number of categories get favored over continuous features.
-     *
-     * This needs to be checked here instead of in Strategy since numBins can be determined
-     * by the number of training examples.
-     * TODO: Allow this case, where we simply will know nothing about some categories.
-     */
-    if (metadata.featureArity.size > 0) {
-      val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2
-      require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
-        "in categorical features")
-    }
-    // Calculate the number of sample for approximate quantile calculation.
-    val requiredSamples = numBins*numBins
-    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
-    logDebug("fraction of data used for calculating quantiles = " + fraction)
+    val numFeatures = metadata.numFeatures
-    // sampled input for RDD calculation
-    val sampledInput =
+    // Sample the input only if there are continuous features.
+    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
+    val sampledInput = if (hasContinuousFeatures) {
+      // Calculate the number of samples for approximate quantile calculation.
+      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+      val fraction = if (requiredSamples < metadata.numExamples) {
+        requiredSamples.toDouble / metadata.numExamples
+      } else {
+        1.0
+      }
+      logDebug("fraction of data used for calculating quantiles = " + fraction)
       input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
-    val numSamples = sampledInput.length
-    val stride: Double = numSamples.toDouble / numBins
-    logDebug("stride = " + stride)
+    } else {
+      new Array[LabeledPoint](0)
+    }
     metadata.quantileStrategy match {
       case Sort =>
-        val splits = Array.ofDim[Split](numFeatures, numBins - 1)
-        val bins = Array.ofDim[Bin](numFeatures, numBins)
+        val splits = new Array[Array[Split]](numFeatures)
+        val bins = new Array[Array[Bin]](numFeatures)
         // Find all splits.
         // Iterate over all features.
         var featureIndex = 0
         while (featureIndex < numFeatures) {
-          // Check whether the feature is continuous.
-          val isFeatureContinuous = metadata.isContinuous(featureIndex)
-          if (isFeatureContinuous) {
+          val numSplits = metadata.numSplits(featureIndex)
+          val numBins = metadata.numBins(featureIndex)
+          if (metadata.isContinuous(featureIndex)) {
+            val numSamples = sampledInput.length
+            splits(featureIndex) = new Array[Split](numSplits)
+            bins(featureIndex) = new Array[Bin](numBins)
             val featureSamples = => lp.features(featureIndex)).sorted
-            val stride: Double = numSamples.toDouble / numBins
+            val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
             logDebug("stride = " + stride)
-            for (index <- 0 until numBins - 1) {
-              val sampleIndex = index * stride.toInt
+            for (splitIndex <- 0 until numSplits) {
+              val sampleIndex = splitIndex * stride.toInt
               // Set threshold halfway in between 2 samples.
               val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
-              val split = new Split(featureIndex, threshold, Continuous, List())
-              splits(featureIndex)(index) = split
+              splits(featureIndex)(splitIndex) =
+                new Split(featureIndex, threshold, Continuous, List())
-          } else { // Categorical feature
-            val featureCategories = metadata.featureArity(featureIndex)
-            // Use different bin/split calculation strategy for categorical features in multiclass
-            // classification that satisfy the space constraint.
+            bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+              splits(featureIndex)(0), Continuous, Double.MinValue)
+            for (splitIndex <- 1 until numSplits) {
+              bins(featureIndex)(splitIndex) =
+                new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
+                  Continuous, Double.MinValue)
+            }
+            bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
+              new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
+          } else {
+            // Categorical feature
+            val featureArity = metadata.featureArity(featureIndex)
             if (metadata.isUnordered(featureIndex)) {
-              // 2^(maxFeatureValue- 1) - 1 combinations
-              var index = 0
-              while (index < (1 << featureCategories - 1) - 1) {
-                val categories: List[Double]
-                  = extractMultiClassCategories(index + 1, featureCategories)
-                splits(featureIndex)(index)
-                  = new Split(featureIndex, Double.MinValue, Categorical, categories)
-                bins(featureIndex)(index) = {
-                  if (index == 0) {
+              // TODO: The second half of the bins are unused.  Actually, we could just use
+              //       splits and not build bins for unordered features.  That should be part of
+              //       a later PR since it will require changing other code (using splits instead
+              //       of bins in a few places).
+              // Unordered features
+              //   2^(maxFeatureValue - 1) - 1 combinations
+              splits(featureIndex) = new Array[Split](numSplits)
+              bins(featureIndex) = new Array[Bin](numBins)
+              var splitIndex = 0
+              while (splitIndex < numSplits) {
+                val categories: List[Double] =
+                  extractMultiClassCategories(splitIndex + 1, featureArity)
+                splits(featureIndex)(splitIndex) =
+                  new Split(featureIndex, Double.MinValue, Categorical, categories)
+                bins(featureIndex)(splitIndex) = {
+                  if (splitIndex == 0) {
                     new Bin(
                       new DummyCategoricalSplit(featureIndex, Categorical),
@@ -1395,96 +1026,24 @@ object DecisionTree extends Serializable with Logging {
                   } else {
                     new Bin(
-                      splits(featureIndex)(index - 1),
-                      splits(featureIndex)(index),
+                      splits(featureIndex)(splitIndex - 1),
+                      splits(featureIndex)(splitIndex),
-                index += 1
-              }
-            } else { // ordered feature
-              /* For a given categorical feature, use a subsample of the data
-               * to choose how to arrange possible splits.
-               * This examines each category and computes a centroid.
-               * These centroids are later used to sort the possible splits.
-               * centroidForCategories is a mapping: category (for the given feature) --> centroid
-               */
-              val centroidForCategories = {
-                if (isMulticlass) {
-                  // For categorical variables in multiclass classification,
-                  // each bin is a category. The bins are sorted and they
-                  // are ordered by calculating the impurity of their corresponding labels.
-         => (lp.features(featureIndex), lp.label))
-                   .groupBy(_._1)
-                   .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
-                   .map(x => (x._1, x._2.values.toArray))
-                   .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum)))
-                } else { // regression or binary classification
-                  // For categorical variables in regression and binary classification,
-                  // each bin is a category. The bins are sorted and they
-                  // are ordered by calculating the centroid of their corresponding labels.
-         => (lp.features(featureIndex), lp.label))
-                    .groupBy(_._1)
-                    .mapValues(x => /
-                }
-              }
-              logDebug("centroid for categories = " + centroidForCategories.mkString(","))
-              // Check for missing categorical variables and putting them last in the sorted list.
-              val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
-              for (i <- 0 until featureCategories) {
-                if (centroidForCategories.contains(i)) {
-                  fullCentroidForCategories(i) = centroidForCategories(i)
-                } else {
-                  fullCentroidForCategories(i) = Double.MaxValue
-                }
-              }
-              // bins sorted by centroids
-              val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-              logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
-              var categoriesForSplit = List[Double]()
-              categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
-                case ((key, value), index) =>
-                  categoriesForSplit = key :: categoriesForSplit
-                  splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
-                    Categorical, categoriesForSplit)
-                  bins(featureIndex)(index) = {
-                    if (index == 0) {
-                      new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                        splits(featureIndex)(0), Categorical, key)
-                    } else {
-                      new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
-                        Categorical, key)
-                    }
-                  }
+                splitIndex += 1
+            } else {
+              // Ordered features
+              //   Bins correspond to feature values, so we do not need to compute splits or bins
+              //   beforehand.  Splits are constructed as needed during training.
+              splits(featureIndex) = new Array[Split](0)
+              bins(featureIndex) = new Array[Bin](0)
           featureIndex += 1
-        // Find all bins.
-        featureIndex = 0
-        while (featureIndex < numFeatures) {
-          val isFeatureContinuous = metadata.isContinuous(featureIndex)
-          if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
-            bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
-              splits(featureIndex)(0), Continuous, Double.MinValue)
-            for (index <- 1 until numBins - 1) {
-              val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
-                Continuous, Double.MinValue)
-              bins(featureIndex)(index) = bin
-            }
-            bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
-              new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
-          }
-          featureIndex += 1
-        }
         (splits, bins)
       case MinMax =>
         throw new UnsupportedOperationException("minmax not supported yet.")

To unsubscribe, e-mail:
For additional commands, e-mail: