You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/07/17 07:27:03 UTC

[2/2] spark git commit: [SPARK-7131] [ML] Copy Decision Tree, Random Forest impl to spark.ml

[SPARK-7131] [ML] Copy Decision Tree, Random Forest impl to spark.ml

This PR copies the RandomForest implementation from spark.mllib to spark.ml.  Note that this includes the DecisionTree implementation, but not the GradientBoostedTrees one (which will come later).

I essentially copied a minimal amount of code to spark.ml, removed the use of bins (and only used splits), and modified code only as much as necessary to get it to compile.  The spark.ml implementation still uses some spark.mllib classes (privately), which can be moved in future PRs.

This refactoring will be helpful in extending the node representation to include more information, such as class probabilities.

Specifically:
* Copied code from spark.mllib to spark.ml:
  * mllib.tree.DecisionTree, mllib.tree.RandomForest copied to ml.tree.impl.RandomForest (main implementation)
  * NodeIdCache (needed to use splits instead of bins)
  * TreePoint (use splits instead of bins)
* Added ml.tree.LearningNode used in RandomForest training (needed vars)
* Removed bins from implementation, and only used splits
* Small fix in JavaDecisionTreeRegressorSuite

CC: mengxr  manishamde  codedeft chouqin

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #7294 from jkbradley/dt-move-impl and squashes the following commits:

48749be [Joseph K. Bradley] cleanups based on code review, mostly style
bea9703 [Joseph K. Bradley] scala style fixes.  added some scala doc
4e6d2a4 [Joseph K. Bradley] removed unnecessary use of copyValues, setParent for trees
9a4d721 [Joseph K. Bradley] cleanups. removed InfoGainStats from ml, using old one for now.
836e7d4 [Joseph K. Bradley] Fixed test suite failures
bd5e063 [Joseph K. Bradley] fixed bucketizing issue
0df3759 [Joseph K. Bradley] Need to remove use of Bucketizer
d5224a9 [Joseph K. Bradley] modified tree and forest to use moved impl
cc01823 [Joseph K. Bradley] still editing RF to get it to work
19143fb [Joseph K. Bradley] More progress, but not done yet.  Rebased with master after 1.4 release.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/322d286b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/322d286b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/322d286b

Branch: refs/heads/master
Commit: 322d286bb7773389ed07df96290e427b21c775bd
Parents: f893955
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Thu Jul 16 22:26:59 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Jul 16 22:26:59 2015 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |   13 +-
 .../classification/RandomForestClassifier.scala |   16 +-
 .../ml/regression/DecisionTreeRegressor.scala   |   13 +-
 .../ml/regression/RandomForestRegressor.scala   |   15 +-
 .../scala/org/apache/spark/ml/tree/Node.scala   |  129 ++
 .../scala/org/apache/spark/ml/tree/Split.scala  |   30 +-
 .../apache/spark/ml/tree/impl/NodeIdCache.scala |  194 +++
 .../spark/ml/tree/impl/RandomForest.scala       | 1132 ++++++++++++++++++
 .../apache/spark/ml/tree/impl/TreePoint.scala   |  134 +++
 .../spark/mllib/tree/impl/BaggedPoint.scala     |   10 +-
 .../mllib/tree/impl/DTStatsAggregator.scala     |    2 +-
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |    4 +-
 .../spark/mllib/tree/impl/NodeIdCache.scala     |    4 +-
 .../spark/mllib/tree/impl/TimeTracker.scala     |    2 +-
 .../spark/mllib/tree/impl/TreePoint.scala       |    4 +-
 .../spark/mllib/tree/impurity/Impurity.scala    |    4 +-
 .../mllib/tree/model/InformationGainStats.scala |    2 +-
 .../JavaDecisionTreeRegressorSuite.java         |    2 +-
 18 files changed, 1678 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 2dc1824..36fe1bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
+import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
@@ -75,8 +75,9 @@ final class DecisionTreeClassifier(override val uid: String)
     }
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
-    val oldModel = OldDecisionTree.train(oldDataset, strategy)
-    DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -112,6 +113,12 @@ final class DecisionTreeClassificationModel private[ml] (
   require(rootNode != null,
     "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
 
+  /**
+   * Construct a decision tree classification model.
+   * The model uid is taken from `Identifiable.randomUID("dtc")`, and the call is forwarded to
+   * the primary constructor together with the given root node.
+   * @param rootNode  Root node of tree, with other nodes attached.
+   */
+  def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
+
   override protected def predict(features: Vector): Double = {
     rootNode.predict(features)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index d3c6749..490f04c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -20,13 +20,13 @@ package org.apache.spark.ml.classification
 import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
@@ -93,9 +93,10 @@ final class RandomForestClassifier(override val uid: String)
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
-    val oldModel = OldRandomForest.trainClassifier(
-      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
-    RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+    val trees =
+      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
+        .map(_.asInstanceOf[DecisionTreeClassificationModel])
+    new RandomForestClassificationModel(trees)
   }
 
   override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -128,6 +129,13 @@ final class RandomForestClassificationModel private[ml] (
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
 
+  /**
+   * Construct a random forest classification model, with all trees weighted equally.
+   * The model uid is taken from `Identifiable.randomUID("rfc")`, and the call is forwarded to
+   * the primary constructor together with the given trees.
+   * @param trees  Component trees
+   */
+  def this(trees: Array[DecisionTreeClassificationModel]) =
+    this(Identifiable.randomUID("rfc"), trees)
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   // Note: We may add support for weights (based on tree performance) later on.

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index be1f806..6f3340c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
@@ -67,8 +67,9 @@ final class DecisionTreeRegressor(override val uid: String)
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
-    val oldModel = OldDecisionTree.train(oldDataset, strategy)
-    DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -102,6 +103,12 @@ final class DecisionTreeRegressionModel private[ml] (
   require(rootNode != null,
     "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
 
+  /**
+   * Construct a decision tree regression model.
+   * The model uid is taken from `Identifiable.randomUID("dtr")`, and the call is forwarded to
+   * the primary constructor together with the given root node.
+   * @param rootNode  Root node of tree, with other nodes attached.
+   */
+  def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
+
   override protected def predict(features: Vector): Double = {
     rootNode.predict(features)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 21c5906..5fd5c7c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
@@ -82,9 +82,10 @@ final class RandomForestRegressor(override val uid: String)
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
-    val oldModel = OldRandomForest.trainRegressor(
-      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
-    RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+    val trees =
+      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
+        .map(_.asInstanceOf[DecisionTreeRegressionModel])
+    new RandomForestRegressionModel(trees)
   }
 
   override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
@@ -115,6 +116,12 @@ final class RandomForestRegressionModel private[ml] (
 
   require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
 
+  /**
+   * Construct a random forest regression model, with all trees weighted equally.
+   * The model uid is taken from `Identifiable.randomUID("rfr")`, and the call is forwarded to
+   * the primary constructor together with the given trees.
+   * @param trees  Component trees
+   */
+  def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   // Note: We may add support for weights (based on tree performance) later on.

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 4242154..bbc2427 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -209,3 +209,132 @@ private object InternalNode {
     }
   }
 }
+
+/**
+ * Mutable node representation used while a tree is being learned.  All fields are vars so that
+ * training can attach children, a split and statistics to a node after it has been created.
+ *
+ * Node IDs are used for now.  They are kept internal since we hope to remove node IDs
+ * in the future, or at least change the indexing (so that we can support much deeper trees).
+ *
+ * A node is in one of two states:
+ *  - leaf node: leftChild, rightChild and split are all empty, or
+ *  - internal node: all values are set
+ *
+ * @param id  We currently use the same indexing as the old implementation in
+ *            [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
+ * @param predictionStats  Predicted label + class probability (for classification).
+ *                         We will later modify this to store aggregate statistics for labels
+ *                         to provide all class probabilities (for classification) and maybe a
+ *                         distribution (for regression).
+ * @param impurity  Impurity value handed to the [[Node]] created by [[toNode]].
+ * @param leftChild  Left child, or None if this node has no children.
+ * @param rightChild  Right child, or None if this node has no children.
+ * @param split  Split applied at this node, or None for a leaf.
+ * @param isLeaf  Indicates whether this node will definitely be a leaf in the learned tree,
+ *                so that we do not need to consider splitting it further.
+ * @param stats  Old structure for storing stats about information gain, prediction, etc.
+ *               This is legacy and will be modified in the future.
+ */
+private[tree] class LearningNode(
+    var id: Int,
+    var predictionStats: OldPredict,
+    var impurity: Double,
+    var leftChild: Option[LearningNode],
+    var rightChild: Option[LearningNode],
+    var split: Option[Split],
+    var isLeaf: Boolean,
+    var stats: Option[OldInformationGainStats]) extends Serializable {
+
+  /**
+   * Build the immutable [[Node]] equivalent of this [[LearningNode]], converting the whole
+   * subtree below it recursively.  A node without a left child becomes a leaf; otherwise the
+   * right child, split and stats must all be present.
+   */
+  def toNode: Node = leftChild match {
+    case None =>
+      new LeafNode(predictionStats.predict, impurity)
+    case Some(left) =>
+      assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
+        "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
+      new InternalNode(predictionStats.predict, impurity, stats.get.gain,
+        left.toNode, rightChild.get.toNode, split.get)
+  }
+
+}
+
+private[tree] object LearningNode {
+
+  /**
+   * Create a node with some of its fields set.  The node starts without children, split or
+   * information gain stats.
+   * @param id  Node index.
+   * @param predictionStats  Prediction stored at the node.
+   * @param impurity  Impurity stored at the node.
+   * @param isLeaf  Whether the node is already known to be a leaf.
+   */
+  def apply(
+      id: Int,
+      predictionStats: OldPredict,
+      impurity: Double,
+      isLeaf: Boolean): LearningNode = {
+    // Pass the caller's isLeaf through; it used to be dropped in favor of a hard-coded false.
+    new LearningNode(id, predictionStats, impurity, None, None, None, isLeaf, None)
+  }
+
+  /** Create an empty node with the given node index.  Values must be set later on. */
+  def emptyNode(nodeIndex: Int): LearningNode = {
+    new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
+      None, None, None, false, None)
+  }
+
+  // The below indexing methods were copied from spark.mllib.tree.model.Node
+
+  /**
+   * Return the index of the left child of this node.
+   */
+  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+  /**
+   * Return the index of the right child of this node.
+   */
+  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+  /**
+   * Get the parent index of the given node, or 0 if it is the root.
+   */
+  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+  /**
+   * Return the level of a tree which the given node is in.
+   * Throws IllegalArgumentException for index 0, which is not a valid node index.
+   */
+  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+    throw new IllegalArgumentException("0 is not a valid node index.")
+  } else {
+    java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+  }
+
+  /**
+   * Returns true if this is a left child.
+   * Note: Returns false for the root.
+   */
+  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+  /**
+   * Return the maximum number of nodes which can be in the given level of the tree.
+   * @param level  Level of tree (0 = root).
+   */
+  def maxNodesInLevel(level: Int): Int = 1 << level
+
+  /**
+   * Return the index of the first node in the given level.
+   * @param level  Level of tree (0 = root).
+   */
+  def startIndexInLevel(level: Int): Int = 1 << level
+
+  /**
+   * Traces down from a root node to get the node with the given node index.
+   * The bits of nodeIndex below its highest set bit are read from most to least significant;
+   * a 0 bit follows the left child and a 1 bit follows the right child.
+   * This assumes the node exists; a NoSuchElementException is thrown if a child on the path
+   * is missing.
+   */
+  def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
+    var tmpNode: LearningNode = rootNode
+    var levelsToGo = indexToLevel(nodeIndex)
+    while (levelsToGo > 0) {
+      val goLeft = (nodeIndex & (1 << (levelsToGo - 1))) == 0
+      // Children are Option[LearningNode]: unwrap them.  Casting the Option itself to
+      // LearningNode fails with a ClassCastException.
+      val child = if (goLeft) tmpNode.leftChild else tmpNode.rightChild
+      tmpNode = child.getOrElse(throw new NoSuchElementException(
+        s"Node index $nodeIndex does not exist in the tree rooted at node ${rootNode.id}."))
+      levelsToGo -= 1
+    }
+    tmpNode
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 7acdeee..78199cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -34,9 +34,19 @@ sealed trait Split extends Serializable {
   /** Index of feature which this split tests */
   def featureIndex: Int
 
-  /** Return true (split to left) or false (split to right) */
+  /**
+   * Return true (split to left) or false (split to right).
+   * @param features  Vector of features (original values, not binned).
+   */
   private[ml] def shouldGoLeft(features: Vector): Boolean
 
+  /**
+   * Return true (split to left) or false (split to right).
+   * This variant takes a binned feature value instead of the original feature vector.
+   * @param binnedFeature Binned feature value.
+   * @param splits All splits for the given feature.
+   */
+  private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
+
   /** Convert to old Split format */
   private[tree] def toOld: OldSplit
 }
@@ -94,6 +104,14 @@ final class CategoricalSplit private[ml] (
     }
   }
 
+  /**
+   * Decide the direction for a binned feature value by testing whether the value, converted
+   * to Double, is contained in `categories`.  The `splits` argument is not used here.
+   */
+  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
+    val inCategories = categories.contains(binnedFeature.toDouble)
+    // Go left when membership agrees with isLeft.
+    inCategories == isLeft
+  }
+
   override def equals(o: Any): Boolean = {
     o match {
       case other: CategoricalSplit => featureIndex == other.featureIndex &&
@@ -144,6 +162,16 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
     features(featureIndex) <= threshold
   }
 
+  /**
+   * Decide the direction for a binned feature value by comparing the threshold of the split
+   * at index `binnedFeature` with this split's threshold.
+   */
+  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
+    // A bin index equal to splits.length is > last split, so split right.
+    binnedFeature != splits.length && {
+      val binUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold
+      binUpperBound <= threshold
+    }
+  }
+
   override def equals(o: Any): Boolean = {
     o match {
       case other: ContinuousSplit =>

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000..488e8e4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import java.io.IOException
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.tree.{LearningNode, Split}
+import org.apache.spark.mllib.tree.impl.BaggedPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * Used by the node id cache to move a data point from a node to one of that node's children.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {
+
+  /**
+   * Pick the child node index for a data point from its binned feature value and the split.
+   * @param binnedFeature Binned feature value.
+   * @param splits Split information to convert the bin indices to approximate feature values.
+   * @return Left child index if the split sends the point left, right child index otherwise.
+   */
+  def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
+    val goesLeft = split.shouldGoLeft(binnedFeature, splits)
+    if (goesLeft) LearningNode.leftChildIndex(nodeIndex)
+    else LearningNode.rightChildIndex(nodeIndex)
+  }
+}
+
+/**
+ * Each TreePoint belongs to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, values should all be 1 for root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ *                           (should be an Array of all 1's (meaning the root nodes)).
+ * @param checkpointInterval The checkpointing interval
+ *                           (how often should the cache be checkpointed.).
+ */
+private[spark] class NodeIdCache(
+  var nodeIdsForInstances: RDD[Array[Int]],
+  val checkpointInterval: Int) extends Logging {
+
+  // Reference to the previous RDD of node Ids for instances.
+  // Because we keep on re-persisting updated node Ids,
+  // we want to unpersist the previous RDD.
+  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+  // Past checkpointed RDDs, oldest first.
+  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+  private var rddUpdateCount = 0
+
+  // Indicates whether we can checkpoint
+  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty
+
+  // FileSystem instance for deleting checkpoints as needed
+  private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration)
+
+  /**
+   * Delete the checkpoint file of the given RDD, if it has one.
+   * An IOException during deletion is logged and otherwise ignored (best effort).
+   */
+  private def deleteCheckpointFile(rdd: RDD[Array[Int]]): Unit = {
+    rdd.getCheckpointFile.foreach { checkpointFile =>
+      try {
+        fs.delete(new Path(checkpointFile), true)
+      } catch {
+        case e: IOException =>
+          logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
+            s" file: $checkpointFile", e)
+      }
+    }
+  }
+
+  /**
+   * Update the node index values in the cache.
+   * This updates the RDD and its lineage.
+   * TODO: Passing bin information to executors seems unnecessary and costly.
+   * @param data The RDD of training rows.
+   * @param nodeIdUpdaters A map of node index updaters.
+   *                       The key is the indices of nodes that we want to update.
+   * @param splits  Split information needed to find child node indices.
+   */
+  def updateNodeIndices(
+      data: RDD[BaggedPoint[TreePoint]],
+      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+      splits: Array[Array[Split]]): Unit = {
+    if (prevNodeIdsForInstances != null) {
+      // Unpersist the previous one if one exists.
+      prevNodeIdsForInstances.unpersist()
+    }
+
+    prevNodeIdsForInstances = nodeIdsForInstances
+    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
+      var treeId = 0
+      while (treeId < nodeIdUpdaters.length) {
+        // Only nodes that have an updater registered for this tree are moved to a child.
+        val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
+        if (nodeIdUpdater != null) {
+          val featureIndex = nodeIdUpdater.split.featureIndex
+          val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+            binnedFeature = point.datum.binnedFeatures(featureIndex),
+            splits = splits(featureIndex))
+          ids(treeId) = newNodeIndex
+        }
+        treeId += 1
+      }
+      ids
+    }
+
+    // Keep on persisting new ones.
+    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+    rddUpdateCount += 1
+
+    // Handle checkpointing if the directory is not None.
+    if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
+      // Let's see if we can delete previous checkpoints.
+      var canDelete = true
+      while (checkpointQueue.size > 1 && canDelete) {
+        // We can delete the oldest checkpoint iff
+        // the next checkpoint actually exists in the file system.
+        if (checkpointQueue(1).getCheckpointFile.isDefined) {
+          // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
+          deleteCheckpointFile(checkpointQueue.dequeue())
+        } else {
+          canDelete = false
+        }
+      }
+
+      nodeIdsForInstances.checkpoint()
+      checkpointQueue.enqueue(nodeIdsForInstances)
+    }
+  }
+
+  /**
+   * Call this after training is finished to delete any remaining checkpoints
+   * and to unpersist the previous node Id RDD, if there is one.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.nonEmpty) {
+      deleteCheckpointFile(checkpointQueue.dequeue())
+    }
+    // This cleanup must live inside this method.  Placed in the class body it runs in the
+    // initializer, where prevNodeIdsForInstances is always null, so it never unpersists.
+    if (prevNodeIdsForInstances != null) {
+      // Unpersist the previous one if one exists.
+      prevNodeIdsForInstances.unpersist()
+    }
+  }
+}
+
+@DeveloperApi
+private[spark] object NodeIdCache {
+  /**
+   * Build a node Id cache in which every training row starts at the same node in every tree.
+   * @param data The RDD of training rows.
+   * @param numTrees The number of trees that we want to create cache for.
+   * @param checkpointInterval The checkpointing interval
+   *                           (how often should the cache be checkpointed.).
+   * @param initVal The initial values in the cache.
+   * @return A node Id cache containing an RDD of initial root node Indices.
+   */
+  def init(
+      data: RDD[BaggedPoint[TreePoint]],
+      numTrees: Int,
+      checkpointInterval: Int,
+      initVal: Int = 1): NodeIdCache = {
+    // One array per row, holding initVal once for each tree.
+    val initialNodeIds: RDD[Array[Int]] = data.map { _ =>
+      Array.fill[Int](numTrees)(initVal)
+    }
+    new NodeIdCache(initialNodeIds, checkpointInterval)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
new file mode 100644
index 0000000..15b56bd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -0,0 +1,1132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import java.io.IOException
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
+  TimeTracker}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+private[ml] object RandomForest extends Logging {
+
+  /**
+   * Train a random forest.
+   *
+   * Steps visible in this method: build the metadata, find candidate splits, convert the input
+   * to bagged [[TreePoint]]s, then repeatedly take a group of nodes from a FIFO queue and choose
+   * their best splits until the queue is empty.  Finally each root node is wrapped in a model.
+   *
+   * Requires `strategy.maxDepth <= 30`, and `strategy.maxMemoryInMB` large enough to hold the
+   * aggregates of a single node.
+   *
+   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param strategy Training parameters; this method reads algo, subsamplingRate, maxDepth,
+   *                 maxMemoryInMB, useNodeIdCache and checkpointInterval from it.
+   * @param numTrees Number of trees to train; one root node is allocated per tree.
+   * @param featureSubsetStrategy Passed through to `DecisionTreeMetadata.buildMetadata`.
+   * @param seed Seed used when bagging the input and for the random number generator handed
+   *             to node selection.
+   * @param parentUID If defined, used as the uid of every returned model.
+   * @return an unweighted set of trees
+   */
+  def run(
+      input: RDD[LabeledPoint],
+      strategy: OldStrategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Long,
+      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+
+    val timer = new TimeTracker()
+
+    timer.start("total")
+
+    timer.start("init")
+
+    val retaggedInput = input.retag(classOf[LabeledPoint])
+    val metadata =
+      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+    logDebug("algo = " + strategy.algo)
+    logDebug("numTrees = " + numTrees)
+    logDebug("seed = " + seed)
+    logDebug("maxBins = " + metadata.maxBins)
+    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
+    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+    logDebug("subsamplingRate = " + strategy.subsamplingRate)
+
+    // Find the splits and the corresponding bins (interval between the splits) using a sample
+    // of the input data.
+    timer.start("findSplitsBins")
+    val splits = findSplits(retaggedInput, metadata)
+    timer.stop("findSplitsBins")
+    logDebug("numBins: feature: number of bins")
+    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+      s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+    }.mkString("\n"))
+
+    // Bin feature values (TreePoint representation).
+    // Cache input RDD for speedup during multiple passes.
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)
+
+    // Sample with replacement only when more than one tree is trained.
+    val withReplacement = numTrees > 1
+
+    val baggedInput = BaggedPoint
+      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+      .persist(StorageLevel.MEMORY_AND_DISK)
+
+    // depth of the decision tree
+    val maxDepth = strategy.maxDepth
+    require(maxDepth <= 30,
+      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+    // Max memory usage for aggregates
+    // TODO: Calculate memory usage more precisely.
+    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+    val maxMemoryPerNode = {
+      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
+        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
+          .take(metadata.numFeaturesPerNode).map(_._2))
+      } else {
+        None
+      }
+      // NOTE(review): the factor 8 presumably converts a count of Double-valued statistics
+      // into bytes -- confirm against aggregateSizeForNode.
+      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+    }
+    // Fail fast if even a single node's aggregates would not fit in the memory budget.
+    require(maxMemoryPerNode <= maxMemoryUsage,
+      s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
+        " which is too small for the given features." +
+        s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
+
+    timer.stop("init")
+
+    /*
+     * The main idea here is to perform group-wise training of the decision tree nodes thus
+     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+     * in lower levels).
+     */
+
+    // Create an RDD of node Id cache.
+    // At first, all the rows belong to the root nodes (node Id == 1).
+    val nodeIdCache = if (strategy.useNodeIdCache) {
+      Some(NodeIdCache.init(
+        data = baggedInput,
+        numTrees = numTrees,
+        checkpointInterval = strategy.checkpointInterval,
+        initVal = 1))
+    } else {
+      None
+    }
+
+    // FIFO queue of nodes to train: (treeIndex, node)
+    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+
+    val rng = new Random()
+    rng.setSeed(seed)
+
+    // Allocate and queue root nodes.
+    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+
+    // findBestSplits enqueues newly created non-leaf children, so this loop ends once no
+    // node is left to split.
+    while (nodeQueue.nonEmpty) {
+      // Collect some nodes to split, and choose features for each node (if subsampling).
+      // Each group of nodes may come from one or multiple trees, and at multiple levels.
+      val (nodesForGroup, treeToNodeToIndexInfo) =
+        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+      // Sanity check (should never occur):
+      assert(nodesForGroup.nonEmpty,
+        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
+
+      // Choose node splits, and enqueue new nodes as needed.
+      timer.start("findBestSplits")
+      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
+      timer.stop("findBestSplits")
+    }
+
+    baggedInput.unpersist()
+
+    timer.stop("total")
+
+    logInfo("Internal timing for DecisionTree:")
+    logInfo(s"$timer")
+
+    // Delete any remaining checkpoints used for node Id cache.
+    // A failure here is only logged as a warning; the trained trees are still returned.
+    if (nodeIdCache.nonEmpty) {
+      try {
+        nodeIdCache.get.deleteAllCheckpoints()
+      } catch {
+        case e: IOException =>
+          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
+      }
+    }
+
+    // Wrap each trained root node in a classification or regression model depending on
+    // strategy.algo, passing the parent uid to the constructor when one was given.
+    parentUID match {
+      case Some(uid) =>
+        if (strategy.algo == OldAlgo.Classification) {
+          topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
+        } else {
+          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
+        }
+      case None =>
+        if (strategy.algo == OldAlgo.Classification) {
+          topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
+        } else {
+          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
+        }
+    }
+  }
+
+  /**
+   * Find the index of the node which the given data point currently belongs to.
+   * The point is passed down from `node` the same way prediction would, stopping at a leaf or
+   * at a node without a split, or just below a node whose children have not been created yet.
+   *
+   * @param node  Node in tree from which to classify the given data point.
+   * @param binnedFeatures  Binned feature vector for data point.
+   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+   * @return  Leaf index if the data point reaches a leaf.
+   *          Otherwise, last node reachable in tree matching this example.
+   *          Note: This is the global node index, i.e., the index used in the tree.
+   *                This index is different from the index used during training a particular
+   *                group of nodes on one call to [[findBestSplits()]].
+   */
+  private def predictNodeIndex(
+      node: LearningNode,
+      binnedFeatures: Array[Int],
+      splits: Array[Array[Split]]): Int = {
+    node.split match {
+      case Some(nodeSplit) if !node.isLeaf =>
+        val feature = nodeSplit.featureIndex
+        val goesLeft = nodeSplit.shouldGoLeft(binnedFeatures(feature), splits(feature))
+        if (node.leftChild.nonEmpty) {
+          // Children exist: keep descending on the chosen side.
+          val child = if (goesLeft) node.leftChild.get else node.rightChild.get
+          predictNodeIndex(child, binnedFeatures, splits)
+        } else if (goesLeft) {
+          // Not yet split. Return index from next layer of nodes to train
+          LearningNode.leftChildIndex(node.id)
+        } else {
+          LearningNode.rightChildIndex(node.id)
+        }
+      case _ =>
+        // Leaf, or a node for which no split has been chosen.
+        node.id
+    }
+  }
+
+  /**
+   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+   *
+   * An ordered feature updates exactly one bin.  An unordered feature updates, for each of its
+   * splits, either the left or the right statistics, depending on which side of that split the
+   * feature value falls.
+   *
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (feature, bin).
+   * @param treePoint  Data point being aggregated.
+   * @param splits possible splits indexed (numFeatures)(numSplits)
+   * @param unorderedFeatures  Set of indices of unordered features.
+   * @param instanceWeight  Weight (importance) of instance in dataset.
+   * @param featuresForNode  Feature subset for the node, if any.  When defined, position i in
+   *                         the aggregator corresponds to feature featuresForNode(i); otherwise
+   *                         all features are used and position equals feature index.
+   */
+  private def mixedBinSeqOp(
+      agg: DTStatsAggregator,
+      treePoint: TreePoint,
+      splits: Array[Array[Split]],
+      unorderedFeatures: Set[Int],
+      instanceWeight: Double,
+      featuresForNode: Option[Array[Int]]): Unit = {
+    val featureCount = featuresForNode match {
+      case Some(subset) => subset.length
+      case None => agg.metadata.numFeatures
+    }
+    var position = 0
+    while (position < featureCount) {
+      // Translate the position within the aggregator into the global feature index.
+      val featureIndex = featuresForNode match {
+        case Some(subset) => subset(position)
+        case None => position
+      }
+      if (!unorderedFeatures.contains(featureIndex)) {
+        // Ordered feature: the binned value is the bin to update.
+        val bin = treePoint.binnedFeatures(featureIndex)
+        agg.update(position, bin, treePoint.label, instanceWeight)
+      } else {
+        // Unordered feature: every split receives the point on its left or right side.
+        val category = treePoint.binnedFeatures(featureIndex)
+        val (leftOffset, rightOffset) = agg.getLeftRightFeatureOffsets(position)
+        val splitCount = agg.metadata.numSplits(featureIndex)
+        val candidateSplits = splits(featureIndex)
+        var s = 0
+        while (s < splitCount) {
+          val sideOffset =
+            if (candidateSplits(s).shouldGoLeft(category, candidateSplits)) leftOffset
+            else rightOffset
+          agg.featureUpdate(sideOffset, s, treePoint.label, instanceWeight)
+          s += 1
+        }
+      }
+      position += 1
+    }
+  }
+
+  /**
+   * Helper for binSeqOp, for regression and for classification with only ordered features.
+   *
+   * Every feature considered for the node contributes to exactly one bin.
+   *
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (feature, bin).
+   * @param treePoint  Data point being aggregated.
+   * @param instanceWeight  Weight (importance) of instance in dataset.
+   * @param featuresForNode  Feature subset for the node, if any.  When defined, position i in
+   *                         the aggregator corresponds to feature featuresForNode(i); otherwise
+   *                         all features are used.
+   */
+  private def orderedBinSeqOp(
+      agg: DTStatsAggregator,
+      treePoint: TreePoint,
+      instanceWeight: Double,
+      featuresForNode: Option[Array[Int]]): Unit = {
+    val label = treePoint.label
+
+    featuresForNode match {
+      case Some(subset) =>
+        // Subsampled features: the aggregator is indexed by position within the subset.
+        var position = 0
+        while (position < subset.length) {
+          val bin = treePoint.binnedFeatures(subset(position))
+          agg.update(position, bin, label, instanceWeight)
+          position += 1
+        }
+      case None =>
+        // All features: the aggregator is indexed by the feature index itself.
+        val featureCount = agg.metadata.numFeatures
+        var feature = 0
+        while (feature < featureCount) {
+          val bin = treePoint.binnedFeatures(feature)
+          agg.update(feature, bin, label, instanceWeight)
+          feature += 1
+        }
+    }
+  }
+
+  /**
+   * Given a group of nodes, this finds the best split for each node.
+   *
+   * Per-node statistics are aggregated over the input, the best split of each node is chosen
+   * from those statistics, and the nodes passed in are then updated in place: prediction,
+   * impurity and leaf flag are set, and children are created for non-leaf nodes.
+   *
+   * @param input Training data: RDD of bagged [[TreePoint]]s
+   * @param metadata Learning and dataset metadata
+   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
+   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+   *                              where nodeIndexInfo stores the index in the group and the
+   *                              feature subsets (if using feature subsets).
+   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
+   *                   Updated with new non-leaf nodes which are created.
+   * @param timer Time tracker; the "chooseSplits" phase is timed with it.
+   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+   *                    each value in the array is the data point's node Id
+   *                    for a corresponding tree. This is used to prevent the need
+   *                    to pass the entire tree to the executors during
+   *                    the node stat aggregation phase.
+   */
+  private[tree] def findBestSplits(
+      input: RDD[BaggedPoint[TreePoint]],
+      metadata: DecisionTreeMetadata,
+      topNodes: Array[LearningNode],
+      nodesForGroup: Map[Int, Array[LearningNode]],
+      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+      splits: Array[Array[Split]],
+      nodeQueue: mutable.Queue[(Int, LearningNode)],
+      timer: TimeTracker = new TimeTracker,
+      nodeIdCache: Option[NodeIdCache] = None): Unit = {
+
+    /*
+     * The high-level descriptions of the best split optimizations are noted here.
+     *
+     * *Group-wise training*
+     * We perform bin calculations for groups of nodes to reduce the number of
+     * passes over the data.  Each iteration requires more computation and storage,
+     * but saves several iterations over the data.
+     *
+     * *Bin-wise computation*
+     * We use a bin-wise best split computation strategy instead of a straightforward best split
+     * computation strategy. Instead of analyzing each sample for contribution to the left/right
+     * child node impurity of every split, we first categorize each feature of a sample into a
+     * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+     * to calculate information gain for each split.
+     *
+     * *Aggregation over partitions*
+     * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+     * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+     * indices) in a single array for all bins and rely upon the RDD aggregate method to
+     * drastically reduce the communication overhead.
+     */
+
+    // numNodes:  Number of nodes in this group
+    val numNodes = nodesForGroup.values.map(_.length).sum
+    logDebug("numNodes = " + numNodes)
+    logDebug("numFeatures = " + metadata.numFeatures)
+    logDebug("numClasses = " + metadata.numClasses)
+    logDebug("isMulticlass = " + metadata.isMulticlass)
+    logDebug("isMulticlassWithCategoricalFeatures = " +
+      metadata.isMulticlassWithCategoricalFeatures)
+    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+    /**
+     * Performs a sequential aggregation over a partition for a particular tree and node.
+     *
+     * For each feature, the aggregate sufficient statistics are updated for the relevant
+     * bins.
+     *
+     * @param treeIndex Index of the tree that we want to perform aggregation for.
+     * @param nodeInfo The node info for the tree node, or null if the data point's node is
+     *                 not part of the current group (in which case nothing is updated).
+     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+     *            for each (node, feature, bin).
+     * @param baggedPoint Data point being aggregated.
+     */
+    def nodeBinSeqOp(
+        treeIndex: Int,
+        nodeInfo: NodeIndexInfo,
+        agg: Array[DTStatsAggregator],
+        baggedPoint: BaggedPoint[TreePoint]): Unit = {
+      // Callers pass null when the point's node has no entry in this group's index map.
+      if (nodeInfo != null) {
+        val aggNodeIndex = nodeInfo.nodeIndexInGroup
+        val featuresForNode = nodeInfo.featureSubset
+        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+        if (metadata.unorderedFeatures.isEmpty) {
+          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+        } else {
+          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+            metadata.unorderedFeatures, instanceWeight, featuresForNode)
+        }
+      }
+    }
+
+    /**
+     * Performs a sequential aggregation over a partition.
+     *
+     * Each data point contributes to one node. For each feature,
+     * the aggregate sufficient statistics are updated for the relevant bins.
+     *
+     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+     *             each (node, feature, bin).
+     * @param baggedPoint   Data point being aggregated.
+     * @return  agg
+     */
+    def binSeqOp(
+        agg: Array[DTStatsAggregator],
+        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
+      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+        // Pass the point down the tree built so far to find the node it belongs to.
+        val nodeIndex =
+          predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
+        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+      }
+      agg
+    }
+
+    /**
+     * Do the same thing as binSeqOp, but with nodeIdCache.
+     *
+     * The node index for each tree is read from the array paired with the data point instead
+     * of being computed with predictNodeIndex.
+     *
+     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+     *             each (node, feature, bin).
+     * @param dataPoint  Pair of (data point being aggregated, its node Id for each tree).
+     * @return  agg
+     */
+    def binSeqOpWithNodeIdCache(
+        agg: Array[DTStatsAggregator],
+        dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+        val baggedPoint = dataPoint._1
+        val nodeIdCache = dataPoint._2
+        val nodeIndex = nodeIdCache(treeIndex)
+        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+      }
+
+      agg
+    }
+
+    /**
+     * Get node index in group --> features indices map,
+     * which is a short cut to find feature indices for a node given node index in group.
+     *
+     * Returns None when features are not being subsampled.
+     */
+    def getNodeToFeatures(
+        treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
+      if (!metadata.subsamplingFeatures) {
+        None
+      } else {
+        val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+        treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+          nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+            assert(nodeIndexInfo.featureSubset.isDefined)
+            mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+          }
+        }
+        Some(mutableNodeToFeatures.toMap)
+      }
+    }
+
+    // array of nodes to train indexed by node index in group
+    val nodes = new Array[LearningNode](numNodes)
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+      }
+    }
+
+    // Calculate best splits for all nodes in the group
+    timer.start("chooseSplits")
+
+    // In each partition, iterate all instances and compute aggregate stats for each node,
+    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
+    // After a `reduceByKey` operation,
+    // stats of a node will be shuffled to a particular partition and be combined together,
+    // then best splits for nodes are found there.
+    // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+
+    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+      // Pair every data point with its cached node Ids so that no tree traversal is needed.
+      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          // Look up this node's feature subset, if features are being subsampled.
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterate over all instances in current partition and update aggregate stats
+        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partition using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }
+    } else {
+      input.mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          // Look up this node's feature subset, if features are being subsampled.
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterate over all instances in current partition and update aggregate stats
+        points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partition using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }
+    }
+
+    // Merge the per-partition aggregates of each node, pick the best split per node, and
+    // collect only the chosen (split, stats, predict) triples, keyed by node index in group.
+    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
+      case (nodeIndex, aggStats) =>
+        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+          Some(nodeToFeatures(nodeIndex))
+        }
+
+        // find best split for each node
+        val (split: Split, stats: InformationGainStats, predict: Predict) =
+          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
+        (nodeIndex, (split, stats, predict))
+    }.collectAsMap()
+
+    timer.stop("chooseSplits")
+
+    // One map per tree of nodeIndex --> updater; only allocated when the node Id cache is in
+    // use, and null otherwise.
+    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+    } else {
+      null
+    }
+    // Iterate over all nodes in this group.
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        val nodeIndex = node.id
+        val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+        val aggNodeIndex = nodeInfo.nodeIndexInGroup
+        val (split: Split, stats: InformationGainStats, predict: Predict) =
+          nodeToBestSplits(aggNodeIndex)
+        logDebug("best split = " + split)
+
+        // Extract info for this node.  Create children if not leaf.
+        // The node is a leaf if the best split has no positive gain or the node's level
+        // equals maxDepth.
+        val isLeaf =
+          (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+        node.predictionStats = predict
+        node.isLeaf = isLeaf
+        node.stats = Some(stats)
+        node.impurity = stats.impurity
+        logDebug("Node = " + node)
+
+        if (!isLeaf) {
+          node.split = Some(split)
+          // A child is created as a leaf if it sits at maxDepth or its impurity is exactly 0.
+          val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+          node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+          node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+          if (nodeIdCache.nonEmpty) {
+            // Record the split so that the cached node Ids can be updated for this node.
+            val nodeIndexUpdater = NodeIndexUpdater(
+              split = split,
+              nodeIndex = nodeIndex)
+            nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+          }
+
+          // enqueue left child and right child if they are not leaves
+          if (!leftChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.leftChild.get))
+          }
+          if (!rightChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.rightChild.get))
+          }
+
+          logDebug("leftChildIndex = " + node.leftChild.get.id +
+            ", impurity = " + stats.leftImpurity)
+          logDebug("rightChildIndex = " + node.rightChild.get.id +
+            ", impurity = " + stats.rightImpurity)
+        }
+      }
+    }
+
+    if (nodeIdCache.nonEmpty) {
+      // Update the cache if needed.
+      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits)
+    }
+  }
+
+  /**
+   * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+   *
+   * The split is reported as invalid when either side has fewer instances than
+   * `metadata.minInstancesPerNode`, or when the gain is below `metadata.minInfoGain`.
+   *
+   * @param leftImpurityCalculator left node aggregates for this (feature, split)
+   * @param rightImpurityCalculator right node aggregate for this (feature, split)
+   * @param metadata Learning metadata providing minInstancesPerNode and minInfoGain
+   * @param impurity Impurity of the node being split
+   * @return information gain and statistics for split
+   */
+  private def calculateGainForSplit(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator,
+      metadata: DecisionTreeMetadata,
+      impurity: Double): InformationGainStats = {
+    val leftCount = leftImpurityCalculator.count
+    val rightCount = rightImpurityCalculator.count
+    val minInstances = metadata.minInstancesPerNode
+
+    if (leftCount < minInstances || rightCount < minInstances) {
+      // A child would be too small, so this split is not allowed.
+      InformationGainStats.invalidInformationGainStats
+    } else {
+      val totalCount = leftCount + rightCount
+
+      val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+      val rightImpurity = rightImpurityCalculator.calculate()
+
+      // Each child's impurity is weighted by its share of the instances.
+      val leftWeight = leftCount / totalCount.toDouble
+      val rightWeight = rightCount / totalCount.toDouble
+
+      val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+      if (gain < metadata.minInfoGain) {
+        // The gain is too small, so this split is not allowed.
+        InformationGainStats.invalidInformationGainStats
+      } else {
+        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+          calculatePredict(leftImpurityCalculator), calculatePredict(rightImpurityCalculator))
+      }
+    }
+  }
+
+  /**
+   * Build a [[Predict]] from the given aggregates: the calculator's predicted value together
+   * with the value the calculator's `prob` returns for that prediction.
+   */
+  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+    val predictedValue = impurityCalculator.predict
+    new Predict(predictedValue, impurityCalculator.prob(predictedValue))
+  }
+
+  /**
+   * Calculate predict value for current node, given stats of any split.
+   * The left and right aggregates are added together (on a copy, so the left calculator is
+   * not modified) to obtain the aggregates of the node itself.
+   * Note that this function is called only once for each node.
+   * @param leftImpurityCalculator left node aggregates for a split
+   * @param rightImpurityCalculator right node aggregates for a split
+   * @return predict value and impurity for current node
+   */
+  private def calculatePredictImpurity(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
+    val nodeAggregates = leftImpurityCalculator.copy
+    nodeAggregates.add(rightImpurityCalculator)
+    (calculatePredict(nodeAggregates), nodeAggregates.calculate())
+  }
+
+  /**
+   * Find the best split for a node.
+   * @param binAggregates Bin statistics.
+   * @return tuple for best split: (Split, information gain, prediction at node)
+   */
+  private def binsToBestSplit(
+      binAggregates: DTStatsAggregator,
+      splits: Array[Array[Split]],
+      featuresForNode: Option[Array[Int]],
+      node: LearningNode): (Split, InformationGainStats, Predict) = {
+
+    // Calculate prediction and impurity if current node is top node
+    val level = LearningNode.indexToLevel(node.id)
+    var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
+      None
+    } else {
+      Some((node.predictionStats, node.impurity))
+    }
+
+    // For each (feature, split), calculate the gain, and select the best (feature, split).
+    val (bestSplit, bestSplitStats) =
+      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
+        val featureIndex = if (featuresForNode.nonEmpty) {
+          featuresForNode.get.apply(featureIndexIdx)
+        } else {
+          featureIndexIdx
+        }
+        val numSplits = binAggregates.metadata.numSplits(featureIndex)
+        if (binAggregates.metadata.isContinuous(featureIndex)) {
+          // Cumulative sum (scanLeft) of bin statistics.
+          // Afterwards, binAggregates for a bin is the sum of aggregates for
+          // that bin + all preceding bins.
+          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+            splitIndex += 1
+          }
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { case splitIdx =>
+              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+              rightChildStats.subtract(leftChildStats)
+              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+              (splitIdx, gainStats)
+            }.maxBy(_._2.gain)
+          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+          // Unordered categorical feature
+          val (leftChildOffset, rightChildOffset) =
+            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+        } else {
+          // Ordered categorical feature
+          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+          val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+          /* Each bin is one category (feature value).
+           * The bins are ordered based on centroidForCategories, and this ordering determines which
+           * splits are considered.  (With K categories, we consider K - 1 possible splits.)
+           *
+           * centroidForCategories is a list: (category, centroid)
+           */
+          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
+            // For categorical variables in multiclass classification,
+            // the bins are ordered by the impurity of their corresponding labels.
+            Range(0, numCategories).map { case featureValue =>
+              val categoryStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val centroid = if (categoryStats.count != 0) {
+                categoryStats.calculate()
+              } else {
+                Double.MaxValue
+              }
+              (featureValue, centroid)
+            }
+          } else { // regression or binary classification
+            // For categorical variables in regression and binary classification,
+            // the bins are ordered by the centroid of their corresponding labels.
+            Range(0, numCategories).map { case featureValue =>
+              val categoryStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val centroid = if (categoryStats.count != 0) {
+                categoryStats.predict
+              } else {
+                Double.MaxValue
+              }
+              (featureValue, centroid)
+            }
+          }
+
+          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+
+          // bins sorted by centroids
+          val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+          logDebug("Sorted centroids for categorical variable = " +
+            categoriesSortedByCentroid.mkString(","))
+
+          // Cumulative sum (scanLeft) of bin statistics.
+          // Afterwards, binAggregates for a bin is the sum of aggregates for
+          // that bin + all preceding bins.
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+            splitIndex += 1
+          }
+          // lastCategory = index of bin with total aggregates for this (node, feature)
+          val lastCategory = categoriesSortedByCentroid.last._1
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val featureValue = categoriesSortedByCentroid(splitIndex)._1
+              val leftChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+              rightChildStats.subtract(leftChildStats)
+              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          val categoriesForSplit =
+            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+          val bestFeatureSplit =
+            new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+          (bestFeatureSplit, bestFeatureGainStats)
+        }
+      }.maxBy(_._2.gain)
+
+    (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
+  }
+
+  /**
+   * Returns splits for decision tree calculation.
+   * Continuous and categorical features are handled differently.
+   *
+   * Continuous features:
+   *   For each feature, there are numBins - 1 possible splits representing the possible binary
+   *   decisions at each node in the tree.
+   *   This finds locations (feature values) for splits using a subsample of the data.
+   *
+   * Categorical features:
+   *   For each feature, there is 1 bin per split.
+   *   Splits and bins are handled in 2 ways:
+   *   (a) "unordered features"
+   *       For multiclass classification with a low-arity feature
+   *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+   *       the feature is split based on subsets of categories.
+   *   (b) "ordered features"
+   *       For regression and binary classification,
+   *       and for multiclass classification with a high-arity feature,
+   *       there is one bin per category.
+   *
+   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param metadata Learning and dataset metadata
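+   * @return Splits: an Array of [[Split]] arrays of size (numFeatures, numSplits).
+   *         Continuous features get one [[ContinuousSplit]] per threshold found.
+   *         Unordered categorical features get one [[CategoricalSplit]] per category subset.
+   *         Ordered categorical features get an empty array, since their splits are
+   *          constructed as needed during training.  (No bins are returned.)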
+   */
+  protected[tree] def findSplits(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+
+    logDebug("isMulticlass = " + metadata.isMulticlass)
+
+    val numFeatures = metadata.numFeatures
+
+    // Sample the input only if there are continuous features.
+    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
+    val sampledInput = if (hasContinuousFeatures) {
+      // Calculate the number of samples for approximate quantile calculation.
+      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+      val fraction = if (requiredSamples < metadata.numExamples) {
+        requiredSamples.toDouble / metadata.numExamples
+      } else {
+        1.0
+      }
+      logDebug("fraction of data used for calculating quantiles = " + fraction)
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+    } else {
+      new Array[LabeledPoint](0)
+    }
+
+    val splits = new Array[Array[Split]](numFeatures)
+
+    // Find all splits.
+    // Iterate over all features.
+    var featureIndex = 0
+    while (featureIndex < numFeatures) {
+      if (metadata.isContinuous(featureIndex)) {
+        val featureSamples = sampledInput.map(_.features(featureIndex))
+        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+
+        val numSplits = featureSplits.length
+        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
+        splits(featureIndex) = new Array[Split](numSplits)
+
+        var splitIndex = 0
+        while (splitIndex < numSplits) {
+          val threshold = featureSplits(splitIndex)
+          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
+          splitIndex += 1
+        }
+      } else {
+        // Categorical feature
+        if (metadata.isUnordered(featureIndex)) {
+          val numSplits = metadata.numSplits(featureIndex)
+          val featureArity = metadata.featureArity(featureIndex)
+          // TODO: Use an implicit representation mapping each category to a subset of indices.
+          //       I.e., track indices such that we can calculate the set of bins for which
+          //       feature value x splits to the left.
+          // Unordered features
+          // 2^(maxFeatureValue - 1) - 1 combinations
+          splits(featureIndex) = new Array[Split](numSplits)
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            val categories: List[Double] =
+              extractMultiClassCategories(splitIndex + 1, featureArity)
+            splits(featureIndex)(splitIndex) =
+              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
+            splitIndex += 1
+          }
+        } else {
+          // Ordered features
+          //   Bins correspond to feature values, so we do not need to compute splits or bins
+          //   beforehand.  Splits are constructed as needed during training.
+          splits(featureIndex) = new Array[Split](0)
+        }
+      }
+      featureIndex += 1
+    }
+    splits
+  }
+
+  /**
+   * Extracts the list of eligible categories given an index. It extracts the
+   * positions of ones in the binary representation of the input. If the binary
+   * representation of a number is 01101 (13), the output list should be (3.0, 2.0,
+   * 0.0). maxFeatureValue is the number of rightmost digits that will be tested for ones.
+   */
+  private[tree] def extractMultiClassCategories(
+      input: Int,
+      maxFeatureValue: Int): List[Double] = {
+    var categories = List[Double]()
+    var j = 0
+    var bitShiftedInput = input
+    while (j < maxFeatureValue) {
+      if (bitShiftedInput % 2 != 0) {
+        // updating the list of categories.
+        categories = j.toDouble :: categories
+      }
+      // Right shift by one
+      bitShiftedInput = bitShiftedInput >> 1
+      j += 1
+    }
+    categories
+  }
+
+  /**
+   * Find splits for a continuous feature
+   * NOTE: Returned number of splits is set based on `featureSamples` and
+   *       could be different from the specified `numSplits`.
+   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+   * @param featureSamples feature values of each sample
+   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numbins` will be changed accordingly
+   *                       if there are not enough splits to be found
+   * @param featureIndex feature index to find splits
+   * @return array of splits
+   */
+  private[tree] def findSplitsForContinuousFeature(
+      featureSamples: Array[Double],
+      metadata: DecisionTreeMetadata,
+      featureIndex: Int): Array[Double] = {
+    require(metadata.isContinuous(featureIndex),
+      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+    val splits = {
+      val numSplits = metadata.numSplits(featureIndex)
+
+      // get count for each distinct value
+      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+        m + ((x, m.getOrElse(x, 0) + 1))
+      }
+      // sort distinct values
+      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+      // if possible splits is not enough or just enough, just return all possible splits
+      val possibleSplits = valueCounts.length
+      if (possibleSplits <= numSplits) {
+        valueCounts.map(_._1)
+      } else {
+        // stride between splits
+        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        logDebug("stride = " + stride)
+
+        // iterate `valueCount` to find splits
+        val splitsBuilder = mutable.ArrayBuilder.make[Double]
+        var index = 1
+        // currentCount: sum of counts of values that have been visited
+        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // If `currentCount` is closest value to `targetCount`,
+        // then current value is a split threshold.
+        // After finding a split threshold, `targetCount` is added by stride.
+        var targetCount = stride
+        while (index < valueCounts.length) {
+          val previousCount = currentCount
+          currentCount += valueCounts(index)._2
+          val previousGap = math.abs(previousCount - targetCount)
+          val currentGap = math.abs(currentCount - targetCount)
+          // If adding count of current value to currentCount
+          // makes the gap between currentCount and targetCount smaller,
+          // previous value is a split threshold.
+          if (previousGap < currentGap) {
+            splitsBuilder += valueCounts(index - 1)._1
+            targetCount += stride
+          }
+          index += 1
+        }
+
+        splitsBuilder.result()
+      }
+    }
+
+    // TODO: Do not fail; just ignore the useless feature.
+    assert(splits.length > 0,
+      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+        "  Please remove this feature and then try again.")
+    // set number of splits accordingly
+    metadata.setNumSplits(featureIndex, splits.length)
+
+    splits
+  }
+
+  private[tree] class NodeIndexInfo(
+      val nodeIndexInGroup: Int,
+      val featureSubset: Option[Array[Int]]) extends Serializable
+
+  /**
+   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+   * will be needed; this allows an adaptive number of nodes since different nodes may require
+   * different amounts of memory (if featureSubsetStrategy is not "all").
+   *
+   * @param nodeQueue  Queue of nodes to split.
+   * @param maxMemoryUsage  Bound on size of aggregate statistics.
+   * @return  (nodesForGroup, treeToNodeToIndexInfo).
+   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+   *
+   *          treeToNodeToIndexInfo holds indices of selected features for each node:
+   *            treeIndex --> (global) node index --> (node index in group, feature indices).
+   *          The (global) node index is the index in the tree; the node index in group is the
+   *           index in [0, numNodesInGroup) of the node in this group.
+   *          The feature indices are None if not subsampling features.
+   */
+  private[tree] def selectNodesToSplit(
+      nodeQueue: mutable.Queue[(Int, LearningNode)],
+      maxMemoryUsage: Long,
+      metadata: DecisionTreeMetadata,
+      rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+    // Collect some nodes to split:
+    //  nodesForGroup(treeIndex) = nodes to split
+    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+    val mutableTreeToNodeToIndexInfo =
+      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+    var memUsage: Long = 0L
+    var numNodesInGroup = 0
+    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
+      val (treeIndex, node) = nodeQueue.head
+      // Choose subset of features for node (if subsampling).
+      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+      } else {
+        None
+      }
+      // Check if enough memory remains to add this node to the group.
+      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
+        nodeQueue.dequeue()
+        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+          node
+        mutableTreeToNodeToIndexInfo
+          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+          = new NodeIndexInfo(numNodesInGroup, featureSubset)
+      }
+      numNodesInGroup += 1
+      memUsage += nodeMemUsage
+    }
+    // Convert mutable maps to immutable ones.
+    val nodesForGroup: Map[Int, Array[LearningNode]] =
+      mutableNodesForGroup.mapValues(_.toArray).toMap
+    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+    (nodesForGroup, treeToNodeToIndexInfo)
+  }
+
+  /**
+   * Get the number of values to be stored for this node in the bin aggregates.
+   * @param featureSubset  Indices of features which may be split at this node.
+   *                       If None, then use all features.
+   */
+  private def aggregateSizeForNode(
+      metadata: DecisionTreeMetadata,
+      featureSubset: Option[Array[Int]]): Long = {
+    val totalBins = if (featureSubset.nonEmpty) {
+      featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+    } else {
+      metadata.numBins.map(_.toLong).sum
+    }
+    if (metadata.isClassification) {
+      metadata.numClasses * totalBins
+    } else {
+      3 * totalBins
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
new file mode 100644
index 0000000..9fa27e5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.ml.tree.{ContinuousSplit, Split}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsample of data as follows:
+ *  (a) Continuous features are binned into ranges.
+ *  (b) Unordered categorical features are binned based on subsets of feature values.
+ *      "Unordered categorical features" are categorical features with low arity used in
+ *      multiclass classification.
+ *  (c) Ordered categorical features are binned based on feature values.
+ *      "Ordered categorical features" are categorical features with high arity,
+ *      or any categorical feature used in regression or binary classification.
+ *
+ * @param label  Label from LabeledPoint
+ * @param binnedFeatures  Binned feature values.
+ *                        Same length as LabeledPoint.features, but values are bin indices.
+ */
+private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
+  extends Serializable {
+}
+
+private[spark] object TreePoint {
+
+  /**
+   * Convert an input dataset into its TreePoint representation,
+   * binning feature values in preparation for DecisionTree training.
+   * @param input     Input dataset.
+   * @param splits    Splits for features, of size (numFeatures, numSplits).
+   * @param metadata  Learning and dataset metadata
+   * @return  TreePoint dataset representation
+   */
+  def convertToTreeRDD(
+      input: RDD[LabeledPoint],
+      splits: Array[Array[Split]],
+      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+    // Construct arrays for featureArity for efficiency in the inner loop.
+    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+    var featureIndex = 0
+    while (featureIndex < metadata.numFeatures) {
+      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+      featureIndex += 1
+    }
+    val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
+      if (arity == 0) {
+        splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
+      } else {
+        Array.empty[Double]
+      }
+    }
+    input.map { x =>
+      TreePoint.labeledPointToTreePoint(x, thresholds, featureArity)
+    }
+  }
+
+  /**
+   * Convert one LabeledPoint into its TreePoint representation.
+   * @param thresholds  For each feature, split thresholds for continuous features,
+   *                    empty for categorical features.
+   * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
+   *                      for categorical features.
+   */
+  private def labeledPointToTreePoint(
+      labeledPoint: LabeledPoint,
+      thresholds: Array[Array[Double]],
+      featureArity: Array[Int]): TreePoint = {
+    val numFeatures = labeledPoint.features.size
+    val arr = new Array[Int](numFeatures)
+    var featureIndex = 0
+    while (featureIndex < numFeatures) {
+      arr(featureIndex) =
+        findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
+      featureIndex += 1
+    }
+    new TreePoint(labeledPoint.label, arr)
+  }
+
+  /**
+   * Find discretized value for one (labeledPoint, feature).
+   *
+   * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
+   *       (mllib) tree API.  We want to maintain the same behavior as the old tree API.
+   *
+   * @param featureArity  0 for continuous features; number of categories for categorical features.
+   * @param thresholds  Ascending split thresholds (binarySearch needs them sorted); continuous only.
+   * @return  Bin index.  Continuous: index of the threshold equal to the feature value, or else
+   *          the insertion point among thresholds.  Categorical: the feature value as an Int.
+   */
+  private def findBin(
+      featureIndex: Int,
+      labeledPoint: LabeledPoint,
+      featureArity: Int,
+      thresholds: Array[Double]): Int = {
+    val featureValue = labeledPoint.features(featureIndex)
+
+    if (featureArity == 0) {
+      val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
+      // A negative result encodes the insertion point as (-(insertion point) - 1).
+      if (idx >= 0) idx else -idx - 1
+    } else {
+      // Categorical feature bins are indexed by feature values.
+      if (featureValue < 0 || featureValue >= featureArity) {
+        throw new IllegalArgumentException(
+          s"DecisionTree given invalid data:" +
+            s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}}," +
+            s" but a data point gives it value $featureValue.\n" +
+            "  Bad data point: " + labeledPoint.toString)
+      }
+      featureValue.toInt
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/322d286b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index 089010c..572815d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom
  * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
  *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
  */
-private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
+private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
   extends Serializable
 
-private[tree] object BaggedPoint {
+private[spark] object BaggedPoint {
 
   /**
    * Convert an input dataset into its BaggedPoint representation,
@@ -60,7 +60,7 @@ private[tree] object BaggedPoint {
       subsamplingRate: Double,
       numSubsamples: Int,
       withReplacement: Boolean,
-      seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
     if (withReplacement) {
       convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
     } else {
@@ -76,7 +76,7 @@ private[tree] object BaggedPoint {
       input: RDD[Datum],
       subsamplingRate: Double,
       numSubsamples: Int,
-      seed: Int): RDD[BaggedPoint[Datum]] = {
+      seed: Long): RDD[BaggedPoint[Datum]] = {
     input.mapPartitionsWithIndex { (partitionIndex, instances) =>
       // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
       val rng = new XORShiftRandom
@@ -100,7 +100,7 @@ private[tree] object BaggedPoint {
       input: RDD[Datum],
       subsample: Double,
       numSubsamples: Int,
-      seed: Int): RDD[BaggedPoint[Datum]] = {
+      seed: Long): RDD[BaggedPoint[Datum]] = {
     input.mapPartitionsWithIndex { (partitionIndex, instances) =>
       // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
       val poisson = new PoissonDistribution(subsample)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org