You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/09 10:37:29 UTC

git commit: [SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training

Repository: spark
Updated Branches:
  refs/heads/master 13cab5ba4 -> 14f222f7f


[SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training

Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes).

### Implementation Details

Each node now has an `impurity` field, and the `predict` field is changed from type `Double` to type `Predict` (which can be used to compute the prediction probability in the future). When computing the best split for each node, we also compute the impurity and prediction for its child nodes, which are used to construct the newly allocated child nodes. So at level L, we have already set the impurity and prediction for the nodes at level L+1.
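If level L+1 is the last level, then we can avoid aggregation for it. What's more, the calculation of the parent impurity in `calculateGainForSplit` is no longer repeated for every candidate split: the parent impurity is now computed once per node (or taken from the node itself) and passed in as a parameter.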

The top node of each tree needs to be treated differently because we have to compute its impurity and predict first. In `binsToBestSplit`, if the current node is a top node (level == 0), we calculate its impurity and predict first.
After finding the best split, the top node's predict and impurity are set to the calculated values. Non-top nodes' impurity and predict have already been calculated and do not need to be recalculated. I considered adding an initialization step to set the top nodes' impurity and predict, so that all nodes could be treated the same way. However, this would require a lot of code duplication (all the code for the seq operation (`BinSeqOp`) would need to be duplicated), so I chose the current approach.

 CC mengxr manishamde jkbradley, please help me review this, thanks.

Author: Qiping Li <li...@gmail.com>

Closes #2708 from chouqin/avoid-agg and squashes the following commits:

8e269ea [Qiping Li] adjust code and comments
eefeef1 [Qiping Li] adjust comments and check child nodes' impurity
c41b1b6 [Qiping Li] fix pyspark unit test
7ad7a71 [Qiping Li] fix unit test
822c912 [Qiping Li] add comments and unit test
e41d715 [Qiping Li] fix bug in test suite
6cc0333 [Qiping Li] SPARK-3158: Avoid 1 extra aggregation for DecisionTree training


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/14f222f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/14f222f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/14f222f7

Branch: refs/heads/master
Commit: 14f222f7f76cc93633aae27a94c0e556e289ec56
Parents: 13cab5b
Author: Qiping Li <li...@gmail.com>
Authored: Thu Oct 9 01:36:58 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Oct 9 01:36:58 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  |  97 ++++++++++++------
 .../mllib/tree/model/InformationGainStats.scala |   9 +-
 .../apache/spark/mllib/tree/model/Node.scala    |  37 +++++--
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 102 +++++++++++++++++--
 4 files changed, 197 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/14f222f7/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b311d10..03eeaa7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging {
       Some(mutableNodeToFeatures.toMap)
     }
 
+    // array of nodes to train indexed by node index in group
+    val nodes = new Array[Node](numNodes)
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+      }
+    }
+
     // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")
 
@@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging {
 
           // find best split for each node
           val (split: Split, stats: InformationGainStats, predict: Predict) =
-            binsToBestSplit(aggStats, splits, featuresForNode)
+            binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
           (nodeIndex, (split, stats, predict))
         }.collectAsMap()
 
@@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging {
         // Extract info for this node.  Create children if not leaf.
         val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
         assert(node.id == nodeIndex)
-        node.predict = predict.predict
+        node.predict = predict
         node.isLeaf = isLeaf
         node.stats = Some(stats)
+        node.impurity = stats.impurity
         logDebug("Node = " + node)
 
         if (!isLeaf) {
           node.split = Some(split)
-          node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
-          node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
-          nodeQueue.enqueue((treeIndex, node.leftNode.get))
-          nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+          node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
+            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+          node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
+            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+          // enqueue left child and right child if they are not leaves
+          if (!leftChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.leftNode.get))
+          }
+          if (!rightChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          }
+
           logDebug("leftChildIndex = " + node.leftNode.get.id +
             ", impurity = " + stats.leftImpurity)
           logDebug("rightChildIndex = " + node.rightNode.get.id +
@@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging {
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      metadata: DecisionTreeMetadata): InformationGainStats = {
+      metadata: DecisionTreeMetadata,
+      impurity: Double): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
 
@@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging {
 
     val totalCount = leftCount + rightCount
 
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
-
-    val impurity = parentNodeAgg.calculate()
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
@@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging {
       return InformationGainStats.invalidInformationGainStats
     }
 
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+    // calculate left and right predict
+    val leftPredict = calculatePredict(leftImpurityCalculator)
+    val rightPredict = calculatePredict(rightImpurityCalculator)
+
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+      leftPredict, rightPredict)
+  }
+
+  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+    val predict = impurityCalculator.predict
+    val prob = impurityCalculator.prob(predict)
+    new Predict(predict, prob)
   }
 
   /**
@@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging {
    * Note that this function is called only once for each node.
    * @param leftImpurityCalculator left node aggregates for a split
    * @param rightImpurityCalculator right node aggregates for a split
-   * @return predict value for current node
+   * @return predict value and impurity for current node
    */
-  private def calculatePredict(
+  private def calculatePredictImpurity(
       leftImpurityCalculator: ImpurityCalculator,
-      rightImpurityCalculator: ImpurityCalculator): Predict =  {
+      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) =  {
     val parentNodeAgg = leftImpurityCalculator.copy
     parentNodeAgg.add(rightImpurityCalculator)
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
+    val predict = calculatePredict(parentNodeAgg)
+    val impurity = parentNodeAgg.calculate()
 
-    new Predict(predict, prob)
+    (predict, impurity)
   }
 
   /**
@@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging {
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
-      featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+      featuresForNode: Option[Array[Int]],
+      node: Node): (Split, InformationGainStats, Predict) = {
 
-    // calculate predict only once
-    var predict: Option[Predict] = None
+    // calculate predict and impurity if current node is top node
+    val level = Node.indexToLevel(node.id)
+    var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
+      None
+    } else {
+      Some((node.predict, node.impurity))
+    }
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
@@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging {
             val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
-            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata)
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging {
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata)
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats =
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
-            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata)
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =
@@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging {
       }
     }.maxBy(_._2.gain)
 
-    assert(predict.isDefined, "must calculate predict for each node")
-
-    (bestSplit, bestSplitStats, predict.get)
+    (bestSplit, bestSplitStats, predictWithImpurity.get._1)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/14f222f7/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index a89e71e..9a50ecb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi
  * @param impurity current node impurity
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
-    val rightImpurity: Double) extends Serializable {
+    val rightImpurity: Double,
+    val leftPredict: Predict,
+    val rightPredict: Predict) extends Serializable {
 
   override def toString = {
     "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
@@ -58,5 +62,6 @@ private[tree] object InformationGainStats {
    * denote that current split doesn't satisfies minimum info gain or
    * minimum number of instances per node.
    */
-  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
+    new Predict(0.0, 0.0), new Predict(0.0, 0.0))
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/14f222f7/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 56c3e25..2179da8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
  *
  * @param id integer node id, from 1
  * @param predict predicted value at the node
- * @param isLeaf whether the leaf is a node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
  * @param split split to calculate left and right nodes
  * @param leftNode  left child
  * @param rightNode right child
@@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector
 @DeveloperApi
 class Node (
     val id: Int,
-    var predict: Double,
+    var predict: Predict,
+    var impurity: Double,
     var isLeaf: Boolean,
     var split: Option[Split],
     var leftNode: Option[Node],
@@ -49,7 +51,7 @@ class Node (
     var stats: Option[InformationGainStats]) extends Serializable with Logging {
 
   override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
-    "split = " + split + ", stats = " + stats
+    "impurity =  " + impurity + "split = " + split + ", stats = " + stats
 
   /**
    * build the left node and right nodes if not leaf
@@ -62,6 +64,7 @@ class Node (
     logDebug("id = " + id + ", split = " + split)
     logDebug("stats = " + stats)
     logDebug("predict = " + predict)
+    logDebug("impurity = " + impurity)
     if (!isLeaf) {
       leftNode = Some(nodes(Node.leftChildIndex(id)))
       rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -77,7 +80,7 @@ class Node (
    */
   def predict(features: Vector) : Double = {
     if (isLeaf) {
-      predict
+      predict.predict
     } else{
       if (split.get.featureType == Continuous) {
         if (features(split.get.feature) <= split.get.threshold) {
@@ -109,7 +112,7 @@ class Node (
     } else {
       Some(rightNode.get.deepCopy())
     }
-    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+    new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
   }
 
   /**
@@ -154,7 +157,7 @@ class Node (
     }
     val prefix: String = " " * indentFactor
     if (isLeaf) {
-      prefix + s"Predict: $predict\n"
+      prefix + s"Predict: ${predict.predict}\n"
     } else {
       prefix + s"If ${splitToString(split.get, left=true)}\n" +
         leftNode.get.subtreeToString(indentFactor + 1) +
@@ -170,7 +173,27 @@ private[tree] object Node {
   /**
    * Return a node with the given node id (but nothing else set).
    */
-  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
+    false, None, None, None, None)
+
+  /**
+   * Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
+   * This is used in `DecisionTree.findBestSplits` to construct child nodes
+   * after finding the best splits for parent nodes.
+   * Other fields are set at next level.
+   * @param nodeIndex integer node id, from 1
+   * @param predict predicted value at the node
+   * @param impurity current node impurity
+   * @param isLeaf whether the node is a leaf
+   * @return new node instance
+   */
+  def apply(
+      nodeIndex: Int,
+      predict: Predict,
+      impurity: Double,
+      isLeaf: Boolean): Node = {
+    new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
+  }
 
   /**
    * Return the index of the left child of this node.

http://git-wip-us.apache.org/repos/asf/spark/blob/14f222f7/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a48ed71..98a72b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(rootNode.predict === 1)
+    assert(rootNode.predict.predict === 1)
     assert(stats.impurity > 0.2)
   }
 
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(rootNode.predict === 0.6)
+    assert(rootNode.predict.predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.gain === 0)
     assert(stats.leftImpurity === 0)
     assert(stats.rightImpurity === 0)
-    assert(rootNode.predict === 1)
+    assert(rootNode.predict.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.gain === 0)
     assert(stats.leftImpurity === 0)
     assert(stats.rightImpurity === 0)
-    assert(rootNode.predict === 0)
+    assert(rootNode.predict.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.gain === 0)
     assert(stats.leftImpurity === 0)
     assert(stats.rightImpurity === 0)
-    assert(rootNode.predict === 1)
+    assert(rootNode.predict.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(stats1.impurity === stats2.impurity)
       assert(stats1.leftImpurity === stats2.leftImpurity)
       assert(stats1.rightImpurity === stats2.rightImpurity)
-      assert(children1(i).predict === children2(i).predict)
+      assert(children1(i).predict.predict === children2(i).predict.predict)
     }
   }
 
@@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val model = DecisionTree.train(rdd, strategy)
     assert(model.topNode.isLeaf)
-    assert(model.topNode.predict == 0.0)
+    assert(model.topNode.predict.predict == 0.0)
     val predicts = rdd.map(p => model.predict(p.features)).collect()
     predicts.foreach { predict =>
       assert(predict == 0.0)
@@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val model = DecisionTree.train(input, strategy)
     assert(model.topNode.isLeaf)
-    assert(model.topNode.predict == 0.0)
+    assert(model.topNode.predict.predict == 0.0)
     val predicts = input.map(p => model.predict(p.features)).collect()
     predicts.foreach { predict =>
       assert(predict == 0.0)
@@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val gain = rootNode.stats.get
     assert(gain == InformationGainStats.invalidInformationGainStats)
   }
+
+  test("Avoid aggregation on the last level") {
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+      numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+    val topNode = Node.emptyNode(nodeIndex = 1)
+    assert(topNode.predict.predict === Double.MinValue)
+    assert(topNode.impurity === -1.0)
+    assert(topNode.isLeaf === false)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+      )))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+    // don't enqueue leaf nodes into node queue
+    assert(nodeQueue.isEmpty)
+
+    // set impurity and predict for topNode
+    assert(topNode.predict.predict !== Double.MinValue)
+    assert(topNode.impurity !== -1.0)
+
+    // set impurity and predict for child nodes
+    assert(topNode.leftNode.get.predict.predict === 0.0)
+    assert(topNode.rightNode.get.predict.predict === 1.0)
+    assert(topNode.leftNode.get.impurity === 0.0)
+    assert(topNode.rightNode.get.impurity === 0.0)
+  }
+
+  test("Avoid aggregation if impurity is 0.0") {
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+    val topNode = Node.emptyNode(nodeIndex = 1)
+    assert(topNode.predict.predict === Double.MinValue)
+    assert(topNode.impurity === -1.0)
+    assert(topNode.isLeaf === false)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+    // don't enqueue a node into node queue if its impurity is 0.0
+    assert(nodeQueue.isEmpty)
+
+    // set impurity and predict for topNode
+    assert(topNode.predict.predict !== Double.MinValue)
+    assert(topNode.impurity !== -1.0)
+
+    // set impurity and predict for child nodes
+    assert(topNode.leftNode.get.predict.predict === 0.0)
+    assert(topNode.rightNode.get.predict.predict === 1.0)
+    assert(topNode.leftNode.get.impurity === 0.0)
+    assert(topNode.rightNode.get.impurity === 0.0)
+  }
 }
 
 object DecisionTreeSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org