Posted to commits@spark.apache.org by me...@apache.org on 2014/07/18 23:00:19 UTC

[1/2] [MLlib] SPARK-1536: multiclass classification support for decision tree

Repository: spark
Updated Branches:
  refs/heads/master 586e716e4 -> d88f6be44


http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bcb1187..5961a61 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.Filter
 import org.apache.spark.mllib.tree.model.Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
@@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(bins.length === 2)
@@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       Classification,
       Gini,
       maxDepth = 3,
+      numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
@@ -130,8 +131,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       Classification,
       Gini,
       maxDepth = 3,
+      numClassesForClassification = 2,
       maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
 
     // Check splits.
@@ -231,6 +233,162 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(1)(3) === null)
   }
 
+  test("extract categories from a number for multiclass classification") {
+    val l = DecisionTree.extractMultiClassCategories(13, 10)
+    assert(l.length === 3)
+    assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
+  }
+
+  test("split and bin calculations for unordered categorical variables with multiclass " +
+    "classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Classification,
+      Gini,
+      maxDepth = 3,
+      numClassesForClassification = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+    // Expecting 2^(3-1) - 1 = 3 bins/splits
+    assert(splits(0)(0).feature === 0)
+    assert(splits(0)(0).threshold === Double.MinValue)
+    assert(splits(0)(0).featureType === Categorical)
+    assert(splits(0)(0).categories.length === 1)
+    assert(splits(0)(0).categories.contains(0.0))
+    assert(splits(1)(0).feature === 1)
+    assert(splits(1)(0).threshold === Double.MinValue)
+    assert(splits(1)(0).featureType === Categorical)
+    assert(splits(1)(0).categories.length === 1)
+    assert(splits(1)(0).categories.contains(0.0))
+
+    assert(splits(0)(1).feature === 0)
+    assert(splits(0)(1).threshold === Double.MinValue)
+    assert(splits(0)(1).featureType === Categorical)
+    assert(splits(0)(1).categories.length === 1)
+    assert(splits(0)(1).categories.contains(1.0))
+    assert(splits(1)(1).feature === 1)
+    assert(splits(1)(1).threshold === Double.MinValue)
+    assert(splits(1)(1).featureType === Categorical)
+    assert(splits(1)(1).categories.length === 1)
+    assert(splits(1)(1).categories.contains(1.0))
+
+    assert(splits(0)(2).feature === 0)
+    assert(splits(0)(2).threshold === Double.MinValue)
+    assert(splits(0)(2).featureType === Categorical)
+    assert(splits(0)(2).categories.length === 2)
+    assert(splits(0)(2).categories.contains(0.0))
+    assert(splits(0)(2).categories.contains(1.0))
+    assert(splits(1)(2).feature === 1)
+    assert(splits(1)(2).threshold === Double.MinValue)
+    assert(splits(1)(2).featureType === Categorical)
+    assert(splits(1)(2).categories.length === 2)
+    assert(splits(1)(2).categories.contains(0.0))
+    assert(splits(1)(2).categories.contains(1.0))
+
+    assert(splits(0)(3) === null)
+    assert(splits(1)(3) === null)
+
+
+    // Check bins.
+
+    assert(bins(0)(0).category === Double.MinValue)
+    assert(bins(0)(0).lowSplit.categories.length === 0)
+    assert(bins(0)(0).highSplit.categories.length === 1)
+    assert(bins(0)(0).highSplit.categories.contains(0.0))
+    assert(bins(1)(0).category === Double.MinValue)
+    assert(bins(1)(0).lowSplit.categories.length === 0)
+    assert(bins(1)(0).highSplit.categories.length === 1)
+    assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+    assert(bins(0)(1).category === Double.MinValue)
+    assert(bins(0)(1).lowSplit.categories.length === 1)
+    assert(bins(0)(1).lowSplit.categories.contains(0.0))
+    assert(bins(0)(1).highSplit.categories.length === 1)
+    assert(bins(0)(1).highSplit.categories.contains(1.0))
+    assert(bins(1)(1).category === Double.MinValue)
+    assert(bins(1)(1).lowSplit.categories.length === 1)
+    assert(bins(1)(1).lowSplit.categories.contains(0.0))
+    assert(bins(1)(1).highSplit.categories.length === 1)
+    assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+    assert(bins(0)(2).category === Double.MinValue)
+    assert(bins(0)(2).lowSplit.categories.length === 1)
+    assert(bins(0)(2).lowSplit.categories.contains(1.0))
+    assert(bins(0)(2).highSplit.categories.length === 2)
+    assert(bins(0)(2).highSplit.categories.contains(1.0))
+    assert(bins(0)(2).highSplit.categories.contains(0.0))
+    assert(bins(1)(2).category === Double.MinValue)
+    assert(bins(1)(2).lowSplit.categories.length === 1)
+    assert(bins(1)(2).lowSplit.categories.contains(1.0))
+    assert(bins(1)(2).highSplit.categories.length === 2)
+    assert(bins(1)(2).highSplit.categories.contains(1.0))
+    assert(bins(1)(2).highSplit.categories.contains(0.0))
+
+    assert(bins(0)(3) === null)
+    assert(bins(1)(3) === null)
+
+  }
+
+  test("split and bin calculations for ordered categorical variables with multiclass " +
+    "classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    assert(arr.length === 3000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Classification,
+      Gini,
+      maxDepth = 3,
+      numClassesForClassification = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+    // 2^(10-1) - 1 = 511 > 100 bins, so categorical variables will be ordered
+
+    assert(splits(0)(0).feature === 0)
+    assert(splits(0)(0).threshold === Double.MinValue)
+    assert(splits(0)(0).featureType === Categorical)
+    assert(splits(0)(0).categories.length === 1)
+    assert(splits(0)(0).categories.contains(1.0))
+
+    assert(splits(0)(1).feature === 0)
+    assert(splits(0)(1).threshold === Double.MinValue)
+    assert(splits(0)(1).featureType === Categorical)
+    assert(splits(0)(1).categories.length === 2)
+    assert(splits(0)(1).categories.contains(2.0))
+
+    assert(splits(0)(2).feature === 0)
+    assert(splits(0)(2).threshold === Double.MinValue)
+    assert(splits(0)(2).featureType === Categorical)
+    assert(splits(0)(2).categories.length === 3)
+    assert(splits(0)(2).categories.contains(2.0))
+    assert(splits(0)(2).categories.contains(1.0))
+
+    assert(splits(0)(10) === null)
+    assert(splits(1)(10) === null)
+
+
+    // Check bins.
+
+    assert(bins(0)(0).category === 1.0)
+    assert(bins(0)(0).lowSplit.categories.length === 0)
+    assert(bins(0)(0).highSplit.categories.length === 1)
+    assert(bins(0)(0).highSplit.categories.contains(1.0))
+    assert(bins(0)(1).category === 2.0)
+    assert(bins(0)(1).lowSplit.categories.length === 1)
+    assert(bins(0)(1).highSplit.categories.length === 2)
+    assert(bins(0)(1).highSplit.categories.contains(1.0))
+    assert(bins(0)(1).highSplit.categories.contains(2.0))
+
+    assert(bins(0)(10) === null)
+
+  }
+
+
   test("classification stump with all categorical variables") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
@@ -238,6 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategy = new Strategy(
       Classification,
       Gini,
+      numClassesForClassification = 2,
       maxDepth = 3,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -253,8 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = bestSplits(0)._2
     assert(stats.gain > 0)
-    assert(stats.predict > 0.5)
-    assert(stats.predict < 0.7)
+    assert(stats.predict === 1)
+    assert(stats.prob == 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -280,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = bestSplits(0)._2
     assert(stats.gain > 0)
-    assert(stats.predict > 0.5)
-    assert(stats.predict < 0.7)
+    assert(stats.predict == 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -289,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -312,7 +470,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -336,7 +494,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -360,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -380,11 +538,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("test second level node building with/without groups") {
+  test("second level node building with/without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -426,6 +584,82 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  test("stump with categorical variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+    assert(bestSplit.feature === 0)
+    assert(bestSplit.categories.length === 1)
+    assert(bestSplit.categories.contains(1))
+    assert(bestSplit.featureType === Categorical)
+  }
+
+  test("stump with continuous variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3)
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+
+    assert(bestSplit.feature === 1)
+    assert(bestSplit.featureType === Continuous)
+    assert(bestSplit.threshold > 1980)
+    assert(bestSplit.threshold < 2020)
+
+  }
+
+  test("stump with continuous + categorical variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+
+    assert(bestSplit.feature === 1)
+    assert(bestSplit.featureType === Continuous)
+    assert(bestSplit.threshold > 1980)
+    assert(bestSplit.threshold < 2020)
+  }
+
+  test("stump with categorical variables for ordered multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+    assert(bestSplit.feature === 0)
+    assert(bestSplit.categories.length === 1)
+    assert(bestSplit.categories.contains(1.0))
+    assert(bestSplit.featureType === Categorical)
+  }
+
+
 }
 
 object DecisionTreeSuite {
@@ -473,4 +707,47 @@ object DecisionTreeSuite {
     }
     arr
   }
+
+  def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 1000) {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      } else if (i < 2000) {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+      } else {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      }
+    }
+    arr
+  }
+
+  def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 2000) {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i))
+      } else {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i))
+      }
+    }
+    arr
+  }
+
+  def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
+    Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 1000) {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      } else if (i < 2000) {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+      } else {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
+      }
+    }
+    arr
+  }
+
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3487f7c..e0f433b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -82,7 +82,15 @@ object MimaExcludes {
       MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
       MimaBuild.excludeSparkClass("storage.Values") ++
       MimaBuild.excludeSparkClass("storage.Entry") ++
-      MimaBuild.excludeSparkClass("storage.MemoryStore$Entry")
+      MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
+      Seq(
+        ProblemFilters.exclude[IncompatibleMethTypeProblem](
+          "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
+        ProblemFilters.exclude[IncompatibleMethTypeProblem](
+          "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
+        ProblemFilters.exclude[IncompatibleMethTypeProblem](
+          "org.apache.spark.mllib.tree.impurity.Variance.calculate")
+      )
     case v if v.startsWith("1.0") =>
       Seq(
         MimaBuild.excludeSparkPackage("api.java"),


[2/2] git commit: [MLlib] SPARK-1536: multiclass classification support for decision tree

Posted by me...@apache.org.
[MLlib] SPARK-1536: multiclass classification support for decision tree

The ability to perform multiclass classification is a big advantage of decision trees and was a highly requested feature for MLlib. This pull request adds multiclass classification support to the MLlib decision tree. It also adds sample weights support using the WeightedLabeledPoint class for handling unbalanced datasets during classification, and it will also support algorithms such as AdaBoost, which require instances to be weighted.

It handles the special case where the categorical variables cannot be ordered for multiclass classification, so the optimizations used for speeding up binary classification cannot be applied directly. More specifically, for m categories in a categorical feature, it analyzes all ```2^(m-1) - 1``` categorical splits, provided the number of splits is less than the maxBins value provided in the input. This condition will not be met for features with a large number of categories -- using decision trees is not recommended for such datasets in general, since the categorical features are favored over continuous features. Moreover, the user can use a combination of tricks (increasing maxBins, using binary encoding for categorical features, or using a one-vs-all classification strategy) to work around these constraints.
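
For illustration, here is a sketch of the split-enumeration idea in Scala (my own reconstruction: extractMultiClassCategories is a real helper in this patch, but the body below is an assumption). Each candidate split is encoded as an integer whose set bits select the categories that go to one side:

```
// Sketch of bit-pattern decoding in the spirit of
// DecisionTree.extractMultiClassCategories. For input = 13 (binary 1101)
// and maxFeatureValue = 10 this yields List(3.0, 2.0, 0.0), matching the
// unit test added to DecisionTreeSuite below.
def extractCategories(input: Int, maxFeatureValue: Int): List[Double] = {
  var categories = List[Double]()
  var bits = input
  var j = 0
  while (j < maxFeatureValue) {
    if (bits % 2 != 0) {
      categories = j.toDouble :: categories // bit j set: category j is in this split
    }
    bits = bits >> 1
    j += 1
  }
  categories
}

// All 2^(m-1) - 1 candidate splits for an m-valued feature:
// (1 to (1 << (m - 1)) - 1).map(extractCategories(_, m))
```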

The new code is accompanied by unit tests and has also been tested on the iris and covtype datasets.

cc: mengxr, etrain, hirakendu, atalwalkar, srowen

Author: Manish Amde <ma...@gmail.com>
Author: manishamde <ma...@gmail.com>
Author: Evan Sparks <sp...@cs.berkeley.edu>

Closes #886 from manishamde/multiclass and squashes the following commits:

26f8acc [Manish Amde] another attempt at fixing mima
c5b2d04 [Manish Amde] more MIMA fixes
1ce7212 [Manish Amde] change problem filter for mima
10fdd82 [Manish Amde] fixing MIMA excludes
e1c970d [Manish Amde] merged master
abf2901 [Manish Amde] adding classes to MimaExcludes.scala
45e767a [Manish Amde] adding developer api annotation for overriden methods
c8428c4 [Manish Amde] fixing weird multiline bug
afced16 [Manish Amde] removed label weights support
2d85a48 [Manish Amde] minor: fixed scalastyle issues reprise
4e85f2c [Manish Amde] minor: fixed scalastyle issues
b2ae41f [Manish Amde] minor: scalastyle
e4c1321 [Manish Amde] using while loop for regression histograms
d75ac32 [Manish Amde] removed WeightedLabeledPoint from this PR
0fecd38 [Manish Amde] minor: add newline to EOF
2061cf5 [Manish Amde] merged from master
06b1690 [Manish Amde] fixed off-by-one error in bin to split conversion
9cc3e31 [Manish Amde] added implicit conversion import
5c1b2ca [Manish Amde] doc for PointConverter class
485eaae [Manish Amde] implicit conversion from LabeledPoint to WeightedLabeledPoint
3d7f911 [Manish Amde] updated doc
8e44ab8 [Manish Amde] updated doc
adc7315 [Manish Amde] support ordered categorical splits for multiclass classification
e3e8843 [Manish Amde] minor code formatting
23d4268 [Manish Amde] minor: another minor code style
34ee7b9 [Manish Amde] minor: code style
237762d [Manish Amde] renaming functions
12e6d0a [Manish Amde] minor: removing line in doc
9a90c93 [Manish Amde] Merge branch 'master' into multiclass
1892a2c [Manish Amde] tests and use multiclass binaggregate length when atleast one categorical feature is present
f5f6b83 [Manish Amde] multiclass for continous variables
8cfd3b6 [Manish Amde] working for categorical multiclass classification
828ff16 [Manish Amde] added categorical variable test
bce835f [Manish Amde] code cleanup
7e5f08c [Manish Amde] minor doc
1dd2735 [Manish Amde] bin search logic for multiclass
f16a9bb [Manish Amde] fixing while loop
d811425 [Manish Amde] multiclass bin aggregate logic
ab5cb21 [Manish Amde] multiclass logic
d8e4a11 [Manish Amde] sample weights
ed5a2df [Manish Amde] fixed classification requirements
d012be7 [Manish Amde] fixed while loop
18d2835 [Manish Amde] changing default values for num classes
6b912dc [Manish Amde] added numclasses to tree runner, predict logic for multiclass, add multiclass option to train
75f2bfc [Manish Amde] minor code style fix
e547151 [Manish Amde] minor modifications
34549d0 [Manish Amde] fixing error during merge
098e8c5 [Manish Amde] merged master
e006f9d [Manish Amde] changing variable names
5c78e1a [Manish Amde] added multiclass support
6c7af22 [Manish Amde] prepared for multiclass without breaking binary classification
46e06ee [Manish Amde] minor mods
3f85a17 [Manish Amde] tests for multiclass classification
4d5f70c [Manish Amde] added multiclass support for find splits bins
46f909c [Manish Amde] todo for multiclass support
455bea9 [Manish Amde] fixed tests
14aea48 [Manish Amde] changing instance format to weighted labeled point
a1a6e09 [Manish Amde] added weighted point class
968ca9d [Manish Amde] merged master
7fc9545 [Manish Amde] added docs
ce004a1 [Manish Amde] minor formatting
b27ad2c [Manish Amde] formatting
426bb28 [Manish Amde] programming guide blurb
8053fed [Manish Amde] more formatting
5eca9e4 [Manish Amde] grammar
4731cda [Manish Amde] formatting
5e82202 [Manish Amde] added documentation, fixed off by 1 error in max level calculation
cbd9f14 [Manish Amde] modified scala.math to math
dad9652 [Manish Amde] removed unused imports
e0426ee [Manish Amde] renamed parameter
718506b [Manish Amde] added unit test
1517155 [Manish Amde] updated documentation
9dbdabe [Manish Amde] merge from master
719d009 [Manish Amde] updating user documentation
fecf89a [manishamde] Merge pull request #6 from etrain/deep_tree
0287772 [Evan Sparks] Fixing scalastyle issue.
2f1e093 [Manish Amde] minor: added doc for maxMemory parameter
2f6072c [manishamde] Merge pull request #5 from etrain/deep_tree
abc5a23 [Evan Sparks] Parameterizing max memory.
50b143a [Manish Amde] adding support for very deep trees


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d88f6be4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d88f6be4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d88f6be4

Branch: refs/heads/master
Commit: d88f6be446e263251c446441c9ce7f5b11216909
Parents: 586e716
Author: Manish Amde <ma...@gmail.com>
Authored: Fri Jul 18 14:00:13 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Jul 18 14:00:13 2014 -0700

----------------------------------------------------------------------
 docs/mllib-decision-tree.md                     |   8 +-
 .../examples/mllib/DecisionTreeRunner.scala     |  21 +-
 .../apache/spark/mllib/tree/DecisionTree.scala  | 732 ++++++++++++++-----
 .../mllib/tree/configuration/Strategy.scala     |  12 +-
 .../spark/mllib/tree/impurity/Entropy.scala     |  36 +-
 .../apache/spark/mllib/tree/impurity/Gini.scala |  33 +-
 .../spark/mllib/tree/impurity/Impurity.scala    |   8 +-
 .../spark/mllib/tree/impurity/Variance.scala    |  11 +-
 .../org/apache/spark/mllib/tree/model/Bin.scala |   2 +-
 .../mllib/tree/model/InformationGainStats.scala |   8 +-
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 303 +++++++-
 project/MimaExcludes.scala                      |  10 +-
 12 files changed, 926 insertions(+), 258 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/docs/mllib-decision-tree.md
----------------------------------------------------------------------
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 9cd7685..9cbd880 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -77,15 +77,17 @@ bins if the condition is not satisfied.
 
 **Categorical features**
 
-For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
-binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
+For `$M$` categorical feature values, one could come up with `$2^{M-1}-1$` split candidates. For
+binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
 categorical feature values by the proportion of labels falling in one of the two classes (see
 Section 9.2.4 in
 [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
 details). For example, for a binary classification problem with one categorical feature with three
 categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
 features are ordered as A followed by C followed by B, or A, C, B. The two split candidates are A \| C, B
-and A , C \| B where \| denotes the split.
+and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
+when `$2^{M-1}-1$` is greater than the number of bins -- the impurity for each categorical feature value
+is used for ordering.
 
 ### Stopping rule
 
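To make the ordering heuristic above concrete, a small Scala sketch (names and values are illustrative, not from the patch):

```
// Order categories by the proportion of label 1, then consider only the
// m - 1 splits between consecutive categories. With A -> 0.2, B -> 0.6,
// C -> 0.4 the order is A, C, B, giving splits A | C, B and A, C | B.
val proportions = Seq("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)
val ordered = proportions.sortBy(_._2).map(_._1) // Seq(A, C, B)
val splits = (1 until ordered.length).map(i => (ordered.take(i), ordered.drop(i)))
// splits: Vector((Seq(A), Seq(C, B)), (Seq(A, C), Seq(B)))
```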

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b3cc361..43f13fe 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
   case class Params(
       input: String = null,
       algo: Algo = Classification,
+      numClassesForClassification: Int = 2,
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 100)
@@ -68,6 +69,10 @@ object DecisionTreeRunner {
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
         .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("numClassesForClassification")
+        .text(s"number of classes for classification, "
+          + s"default: ${defaultParams.numClassesForClassification}")
+        .action((x, c) => c.copy(numClassesForClassification = x))
       opt[Int]("maxBins")
         .text(s"max number of bins, default: ${defaultParams.maxBins}")
         .action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,13 @@ object DecisionTreeRunner {
       case Variance => impurity.Variance
     }
 
-    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val strategy
+      = new Strategy(
+          algo = params.algo,
+          impurity = impurityCalculator,
+          maxDepth = params.maxDepth,
+          maxBins = params.maxBins,
+          numClassesForClassification = params.numClassesForClassification)
     val model = DecisionTree.train(training, strategy)
 
     if (params.algo == Classification) {
@@ -139,12 +150,8 @@ object DecisionTreeRunner {
    */
   private def accuracyScore(
       model: DecisionTreeModel,
-      data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector): Double = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
     val count = data.count()
     correctCount.toDouble / count
   }
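
For context, a hypothetical end-to-end use of the new multiclass entry point (the data path and the SparkContext `sc` are assumptions; the five-argument `train` overload is the one added in this patch):

```
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.util.MLUtils

// Hypothetical usage sketch: labels must lie in {0, ..., numClasses - 1}.
val data = MLUtils.loadLibSVMFile(sc, "data/multiclass.txt") // path is illustrative
val model = DecisionTree.train(data, Classification, Gini, 5, 3) // maxDepth = 5, 3 classes
val accuracy =
  data.filter(p => model.predict(p.features) == p.label).count().toDouble / data.count()
```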

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 74d5d7b..ad32e3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -77,11 +77,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Max memory usage for aggregates
     val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    val numElementsPerNode =
-      strategy.algo match {
-        case Classification => 2 * numBins * numFeatures
-        case Regression => 3 * numBins * numFeatures
-      }
+    val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
+      strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
+      strategy.algo)
 
     logDebug("numElementsPerNode = " + numElementsPerNode)
     val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
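
The getElementsPerNode helper called here is not shown in this hunk; the following is a reconstruction of the sizes it must return, inferred from the aggregate layout used later in the patch (treat the exact signature and body as assumptions):

```
// Reconstructed per-node aggregate sizes (inferred, not verbatim):
// classification stores numClasses counters per (feature, bin); multiclass
// with categorical features doubles this for the separate left-child and
// right-child sections; regression stores count, sum and sum-of-squares.
def elementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
    isMulticlassWithCategoricalFeatures: Boolean, isClassification: Boolean): Int = {
  if (isClassification) {
    if (isMulticlassWithCategoricalFeatures) 2 * numClasses * numBins * numFeatures
    else numClasses * numBins * numFeatures
  } else {
    3 * numBins * numFeatures
  }
}
```
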
@@ -109,8 +107,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
-        level, filters, splits, bins, maxLevelForSingleGroup)
+      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+        strategy, level, filters, splits, bins, maxLevelForSingleGroup)
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
         // Extract info for nodes at the current level.
@@ -212,7 +210,7 @@ object DecisionTree extends Serializable with Logging {
    * @return a DecisionTreeModel that can be used for prediction
   */
   def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+    new DecisionTree(strategy).train(input)
   }
 
   /**
@@ -233,10 +231,33 @@ object DecisionTree extends Serializable with Logging {
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int): DecisionTreeModel = {
-    val strategy = new Strategy(algo,impurity,maxDepth)
-    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+    val strategy = new Strategy(algo, impurity, maxDepth)
+    new DecisionTree(strategy).train(input)
   }
 
+  /**
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For the
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   *              training data
+   * @param algo algorithm, classification or regression
+   * @param impurity impurity criterion used for information gain calculation
+   * @param maxDepth maxDepth maximum depth of the tree
+   * @param numClassesForClassification number of classes for classification. Default value of 2.
+   * @return a DecisionTreeModel that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      numClassesForClassification: Int): DecisionTreeModel = {
+    val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
+    new DecisionTree(strategy).train(input)
+  }
 
   /**
    * Method to train a decision tree model where the instances are represented as an RDD of
@@ -250,6 +271,7 @@ object DecisionTree extends Serializable with Logging {
    * @param algo classification or regression
    * @param impurity criterion used for information gain calculation
    * @param maxDepth  maximum depth of the tree
+   * @param numClassesForClassification number of classes for classification. Default value of 2.
    * @param maxBins maximum number of bins used for splitting features
    * @param quantileCalculationStrategy  algorithm for calculating quantiles
    * @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -264,12 +286,13 @@ object DecisionTree extends Serializable with Logging {
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int,
+      numClassesForClassification: Int,
       maxBins: Int,
       quantileCalculationStrategy: QuantileStrategy,
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
-      categoricalFeaturesInfo)
-    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+    val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+      quantileCalculationStrategy, categoricalFeaturesInfo)
+    new DecisionTree(strategy).train(input)
   }
 
   private val InvalidBinIndex = -1
@@ -381,6 +404,14 @@ object DecisionTree extends Serializable with Logging {
     logDebug("numFeatures = " + numFeatures)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
+    val numClasses = strategy.numClassesForClassification
+    logDebug("numClasses = " + numClasses)
+    val isMulticlassClassification = strategy.isMulticlassClassification
+    logDebug("isMulticlassClassification = " + isMulticlassClassification)
+    val isMulticlassClassificationWithCategoricalFeatures
+      = strategy.isMulticlassWithCategoricalFeatures
+    logDebug("isMultiClassWithCategoricalFeatures = " +
+      isMulticlassClassificationWithCategoricalFeatures)
 
     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
@@ -436,10 +467,8 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Find bin for one feature.
      */
-    def findBin(
-        featureIndex: Int,
-        labeledPoint: LabeledPoint,
-        isFeatureContinuous: Boolean): Int = {
+    def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
+        isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
       val binForFeatures = bins(featureIndex)
       val feature = labeledPoint.features(featureIndex)
 
@@ -468,16 +497,27 @@ object DecisionTree extends Serializable with Logging {
       }
 
       /**
+       * Sequential search helper method to find bin for categorical feature in multiclass
+       * classification. The category is returned since each category can belong to multiple
+       * splits. The actual left/right child allocation per split is performed in the
+       * sequential phase of the bin aggregate operation.
+       */
+      def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
+        labeledPoint.features(featureIndex).toInt
+      }
+
+      /**
        * Sequential search helper method to find bin for categorical feature.
        */
-      def sequentialBinSearchForCategoricalFeature(): Int = {
-        val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+      def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+        val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
         var binIndex = 0
         while (binIndex < numCategoricalBins) {
           val bin = bins(featureIndex)(binIndex)
-          val category = bin.category
+          val categories = bin.highSplit.categories
           val features = labeledPoint.features
-          if (category == features(featureIndex)) {
+          if (categories.contains(features(featureIndex))) {
             return binIndex
           }
           binIndex += 1
@@ -494,7 +534,13 @@ object DecisionTree extends Serializable with Logging {
         binIndex
       } else {
         // Perform sequential search to find bin for categorical features.
-        val binIndex = sequentialBinSearchForCategoricalFeature()
+        val binIndex = {
+          if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+            sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
+          } else {
+            sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+          }
+        }
         if (binIndex == -1){
           throw new UnknownError("no bin was found for categorical variable.")
         }
@@ -506,13 +552,16 @@ object DecisionTree extends Serializable with Logging {
      * Finds bins for all nodes (and all features) at a given level.
      * For l nodes, k features the storage is as follows:
      * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
-     * where b_ij is an integer between 0 and numBins - 1.
+     * where b_ij is an integer between 0 and numBins - 1 for regression and binary
+     * classification, and the categorical feature value for multiclass classification.
      * Invalid sample is denoted by noting bin for feature 1 as -1.
      */
     def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
       // Calculate bin index and label per feature per node.
       val arr = new Array[Double](1 + (numFeatures * numNodes))
+      // First element of the array is the label of the instance.
       arr(0) = labeledPoint.label
+      // Iterate over nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
         val parentFilters = findParentFilters(nodeIndex)
@@ -525,8 +574,19 @@ object DecisionTree extends Serializable with Logging {
         } else {
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-            arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
+            val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
+            val isFeatureContinuous = featureInfo.isEmpty
+            if (isFeatureContinuous) {
+              arr(shift + featureIndex)
+                = findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
+            } else {
+              val featureCategories = featureInfo.get
+              val isSpaceSufficientForAllCategoricalSplits
+                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+              arr(shift + featureIndex)
+                = findBin(featureIndex, labeledPoint, isFeatureContinuous,
+                isSpaceSufficientForAllCategoricalSplits)
+            }
             featureIndex += 1
           }
         }
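
A quick numeric check of the space condition used just above (the helper name is mine; the inequality is the one in the patch):

```
// The unordered (all-subsets) treatment applies only when all
// 2^(m-1) - 1 candidate splits fit within numBins.
def isSpaceSufficient(numBins: Int, featureCategories: Int): Boolean =
  numBins > math.pow(2, featureCategories - 1) - 1

isSpaceSufficient(100, 3)  // true:  2^2 - 1 = 3   <= 100 bins
isSpaceSufficient(100, 10) // false: 2^9 - 1 = 511 >  100 bins
```
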
@@ -535,18 +595,61 @@ object DecisionTree extends Serializable with Logging {
       arr
     }
 
+    // Find feature bins for all nodes at a level.
+    val binMappedRDD = input.map(x => findBinsForLevel(x))
+
+    def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
+        label: Double, featureIndex: Int) = {
+
+      // Find the bin index for this feature.
+      val arrShift = 1 + numFeatures * nodeIndex
+      val arrIndex = arrShift + featureIndex
+      // Update the left or right count for one bin.
+      val aggShift = numClasses * numBins * numFeatures * nodeIndex
+      val aggIndex
+        = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
+      val labelInt = label.toInt
+      agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+    }
+
+    def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
+        label: Double, agg: Array[Double], rightChildShift: Int) = {
+      // Find the bin index for this feature.
+      val arrShift = 1 + numFeatures * nodeIndex
+      val arrIndex = arrShift + featureIndex
+      // Update the left or right count for one bin.
+      val aggShift = numClasses * numBins * numFeatures * nodeIndex
+      val aggIndex
+        = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
+      // Find all matching bins and increment their values
+      val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+      val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+      var binIndex = 0
+      while (binIndex < numCategoricalBins) {
+        val labelInt = label.toInt
+        if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
+          agg(aggIndex + binIndex)
+            = agg(aggIndex + binIndex) + 1
+        } else {
+          agg(rightChildShift + aggIndex + binIndex)
+            = agg(rightChildShift + aggIndex + binIndex) + 1
+        }
+        binIndex += 1
+      }
+    }
+
     /**
      * Performs a sequential aggregation over a partition for classification. For l nodes,
      * k features, either the left count or the right count of one of the p bins is
      * incremented based upon whether the feature is classified as 0 or 1.
      *
      * @param agg Array[Double] storing aggregate calculation of size
-     *            2 * numSplits * numFeatures*numNodes for classification
+     *            numClasses * numSplits * numFeatures * numNodes for classification
      * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
      * @return Array[Double] storing aggregate calculation of size
      *         2 * numSplits * numFeatures * numNodes for classification
      */
-    def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+    def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
       // Iterate over all nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
@@ -559,15 +662,52 @@ object DecisionTree extends Serializable with Logging {
           // Iterate over all features.
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            // Find the bin index for this feature.
-            val arrShift = 1 + numFeatures * nodeIndex
-            val arrIndex = arrShift + featureIndex
-            // Update the left or right count for one bin.
-            val aggShift = 2 * numBins * numFeatures * nodeIndex
-            val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
-            label match {
-              case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
-              case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
+            updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+            featureIndex += 1
+          }
+        }
+        nodeIndex += 1
+      }
+    }
+
+    /**
+     * Performs a sequential aggregation over a partition for multiclass classification with
+     * unordered categorical features. For l nodes and k features, the class count for one of
+     * the p bins is incremented in either the left-child or the right-child section of the
+     * aggregate, depending on the split.
+     *
+     * @param agg Array[Double] storing aggregate calculation of size
+     *            numClasses * numSplits * numFeatures * numNodes for classification
+     * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+     * @return Array[Double] storing aggregate calculation of size
+     *         2 * numClasses * numSplits * numFeatures * numNodes for classification
+     */
+    def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
+      // Iterate over all nodes.
+      var nodeIndex = 0
+      while (nodeIndex < numNodes) {
+        // Check whether the instance was valid for this nodeIndex.
+        val validSignalIndex = 1 + numFeatures * nodeIndex
+        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+        if (isSampleValidForNode) {
+          val rightChildShift = numClasses * numBins * numFeatures * numNodes
+          // actual class label
+          val label = arr(0)
+          // Iterate over all features.
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+            if (isFeatureContinuous) {
+              updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+            } else {
+              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+              val isSpaceSufficientForAllCategoricalSplits
+                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+              if (isSpaceSufficientForAllCategoricalSplits) {
+                updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
+                  rightChildShift)
+              } else {
+                updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+              }
             }
             featureIndex += 1
           }
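
A worked example of the flat aggregate indexing used by the two update helpers above (the sizes are made up for illustration; the unordered case additionally offsets by rightChildShift for the right-child section):

```
// Layout: nodes x features x bins x classes, flattened into one array.
val numNodes = 2; val numFeatures = 2; val numBins = 4; val numClasses = 3
val agg = new Array[Double](numClasses * numBins * numFeatures * numNodes)

def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Int): Unit = {
  val aggShift = numClasses * numBins * numFeatures * nodeIndex
  val aggIndex = aggShift + numClasses * featureIndex * numBins + binIndex * numClasses
  agg(aggIndex + label) += 1 // one counter per (node, feature, bin, class)
}

update(1, 1, 2, 2) // lands at offset 24 + 12 + 6 + 2 = 44 in the 48-element array
```
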
@@ -586,7 +726,7 @@ object DecisionTree extends Serializable with Logging {
      * @return Array[Double] storing aggregate calculation of size
      *         3 * numSplits * numFeatures * numNodes for regression
      */
-    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
       // Iterate over all nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
@@ -620,17 +760,20 @@ object DecisionTree extends Serializable with Logging {
      */
     def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
       strategy.algo match {
-        case Classification => classificationBinSeqOp(arr, agg)
+        case Classification =>
+          if (isMulticlassClassificationWithCategoricalFeatures) {
+            unorderedClassificationBinSeqOp(arr, agg)
+          } else {
+            orderedClassificationBinSeqOp(arr, agg)
+          }
         case Regression => regressionBinSeqOp(arr, agg)
       }
       agg
     }
 
     // Calculate bin aggregate length for classification or regression.
-    val binAggregateLength = strategy.algo match {
-      case Classification => 2 * numBins * numFeatures * numNodes
-      case Regression =>  3 * numBins * numFeatures * numNodes
-    }
+    val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
+        isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
     logDebug("binAggregateLength = " + binAggregateLength)
 
     /**
@@ -649,9 +792,6 @@ object DecisionTree extends Serializable with Logging {
       combinedAggregate
     }
 
-    // Find feature bins for all nodes at a level.
-    val binMappedRDD = input.map(x => findBinsForLevel(x))
-
     // Calculate bin aggregates.
     val binAggregates = {
       binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
@@ -668,42 +808,55 @@ object DecisionTree extends Serializable with Logging {
      * @return information gain and statistics for all splits
      */
     def calculateGainForSplit(
-        leftNodeAgg: Array[Array[Double]],
+        leftNodeAgg: Array[Array[Array[Double]]],
         featureIndex: Int,
         splitIndex: Int,
-        rightNodeAgg: Array[Array[Double]],
+        rightNodeAgg: Array[Array[Array[Double]]],
         topImpurity: Double): InformationGainStats = {
       strategy.algo match {
         case Classification =>
-          val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
-          val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
-          val leftCount = left0Count + left1Count
-
-          val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
-          val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
-          val rightCount = right0Count + right1Count
+          var classIndex = 0
+          val leftCounts: Array[Double] = new Array[Double](numClasses)
+          val rightCounts: Array[Double] = new Array[Double](numClasses)
+          var leftTotalCount = 0.0
+          var rightTotalCount = 0.0
+          while (classIndex < numClasses) {
+            val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
+            val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
+            leftCounts(classIndex) = leftClassCount
+            leftTotalCount += leftClassCount
+            rightCounts(classIndex) = rightClassCount
+            rightTotalCount += rightClassCount
+            classIndex += 1
+          }
 
           val impurity = {
             if (level > 0) {
               topImpurity
             } else {
               // Calculate impurity for root node.
-              strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+              val rootNodeCounts = new Array[Double](numClasses)
+              var classIndex = 0
+              while (classIndex < numClasses) {
+                rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
+                classIndex += 1
+              }
+              strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
             }
           }
 
-          if (leftCount == 0) {
-            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
+          if (leftTotalCount == 0) {
+            return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
           }
-          if (rightCount == 0) {
-            return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
+          if (rightTotalCount == 0) {
+            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
           }
 
-          val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
-          val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+          val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
+          val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
 
-          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
-          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+          val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount)
+          val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount)
 
           val gain = {
             if (level > 0) {
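
The per-class counts assembled above are fed to the updated impurity interface, calculate(counts: Array[Double], totalCount: Double) -- the signature change covered by the MimaExcludes entries. A hedged sketch of the corresponding multiclass Gini computation (standard formula; values illustrative):

```
// Multiclass Gini impurity over per-class counts: 1 - sum((c_i / N)^2).
def gini(counts: Array[Double], totalCount: Double): Double =
  if (totalCount == 0) 0.0
  else 1.0 - counts.map { c => val f = c / totalCount; f * f }.sum

gini(Array(4.0, 4.0, 2.0), 10.0) // 1 - (0.16 + 0.16 + 0.04) = 0.64
```
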
@@ -713,17 +866,34 @@ object DecisionTree extends Serializable with Logging {
             }
           }
 
-          val predict = (left1Count + right1Count) / (leftCount + rightCount)
+          val totalCount = leftTotalCount + rightTotalCount
 
-          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+          // Sum of count for each label
+          val leftRightCounts: Array[Double]
+            = leftCounts.zip(rightCounts)
+              .map{case (leftCount, rightCount) => leftCount + rightCount}
+
+          def indexOfLargestArrayElement(array: Array[Double]): Int = {
+            val result = array.foldLeft(-1, Double.MinValue, 0) {
+              case ((maxIndex, maxValue, currentIndex), currentValue) =>
+                if (currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
+                else (maxIndex, maxValue, currentIndex + 1)
+            }
+            if (result._1 < 0) 0 else result._1
+          }
+
+          val predict = indexOfLargestArrayElement(leftRightCounts)
+          val prob = leftRightCounts(predict) / totalCount
+
+          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
         case Regression =>
-          val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
-          val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
-          val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
+          val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
+          val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
+          val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
 
-          val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
-          val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
-          val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
+          val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
+          val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
+          val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
 
           val impurity = {
             if (level > 0) {
@@ -768,104 +938,149 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Extracts left and right split aggregates.
      * @param binData Array[Double] of size 2*numFeatures*numSplits
-     * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
-     *         Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
+     * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
+     *         Array[Array[Array[Double\]\]\]) where each array is of size(numFeature,
+     *         (numBins - 1), numClasses)
      */
     def extractLeftRightNodeAggregates(
-        binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+        binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
+
+
+      def findAggForOrderedFeatureClassification(
+          leftNodeAgg: Array[Array[Array[Double]]],
+          rightNodeAgg: Array[Array[Array[Double]]],
+          featureIndex: Int) {
+
+        // shift for this featureIndex
+        val shift = numClasses * featureIndex * numBins
+
+        var classIndex = 0
+        while (classIndex < numClasses) {
+          // left node aggregate for the lowest split
+          leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
+          // right node aggregate for the highest split
+          rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
+            = binData(shift + (numClasses * (numBins - 1)) + classIndex)
+          classIndex += 1
+        }
+
+        // Iterate over all splits.
+        var splitIndex = 1
+        while (splitIndex < numBins - 1) {
+          // calculating left node aggregate for a split as a sum of left node aggregate of a
+          // lower split and the left bin aggregate of a bin where the split is a high split
+          var innerClassIndex = 0
+          while (innerClassIndex < numClasses) {
+            leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
+              = binData(shift + numClasses * splitIndex + innerClassIndex) +
+                leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
+            rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
+              binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
+                rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
+            innerClassIndex += 1
+          }
+          splitIndex += 1
+        }
+      }
+
+      def findAggForUnorderedFeatureClassification(
+          leftNodeAgg: Array[Array[Array[Double]]],
+          rightNodeAgg: Array[Array[Array[Double]]],
+          featureIndex: Int) {
+
+        val rightChildShift = numClasses * numBins * numFeatures
+        var splitIndex = 0
+        while (splitIndex < numBins - 1) {
+          var classIndex = 0
+          while (classIndex < numClasses) {
+            // shift for this featureIndex
+            val shift = numClasses * featureIndex * numBins + splitIndex * numClasses
+            val leftBinValue = binData(shift + classIndex)
+            val rightBinValue = binData(rightChildShift + shift + classIndex)
+            leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
+            rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
+            classIndex += 1
+          }
+          splitIndex += 1
+        }
+      }
+
+      def findAggForRegression(
+          leftNodeAgg: Array[Array[Array[Double]]],
+          rightNodeAgg: Array[Array[Array[Double]]],
+          featureIndex: Int) {
+
+        // shift for this featureIndex
+        val shift = 3 * featureIndex * numBins
+        // left node aggregate for the lowest split
+        leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
+        leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
+        leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
+
+        // right node aggregate for the highest split
+        rightNodeAgg(featureIndex)(numBins - 2)(0) =
+          binData(shift + (3 * (numBins - 1)))
+        rightNodeAgg(featureIndex)(numBins - 2)(1) =
+          binData(shift + (3 * (numBins - 1)) + 1)
+        rightNodeAgg(featureIndex)(numBins - 2)(2) =
+          binData(shift + (3 * (numBins - 1)) + 2)
+
+        // Iterate over all splits.
+        var splitIndex = 1
+        while (splitIndex < numBins - 1) {
+          var i = 0 // index for regression histograms
+          while (i < 3) { // count, sum, sum^2
+            // calculating left node aggregate for a split as a sum of left node aggregate of a
+            // lower split and the left bin aggregate of a bin where the split is a high split
+            leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
+              leftNodeAgg(featureIndex)(splitIndex - 1)(i)
+            // calculating right node aggregate for a split as a sum of right node aggregate of a
+            // higher split and the right bin aggregate of a bin where the split is a low split
+            rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
+              binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
+                rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
+            i += 1
+          }
+          splitIndex += 1
+        }
+      }
+
       strategy.algo match {
         case Classification =>
           // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
-          // Iterate over all features.
+          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            // shift for this featureIndex
-            val shift = 2 * featureIndex * numBins
-
-            // left node aggregate for the lowest split
-            leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-            leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-
-            // right node aggregate for the highest split
-            rightNodeAgg(featureIndex)(2 * (numBins - 2))
-              = binData(shift + (2 * (numBins - 1)))
-            rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
-              = binData(shift + (2 * (numBins - 1)) + 1)
-
-            // Iterate over all splits.
-            var splitIndex = 1
-            while (splitIndex < numBins - 1) {
-              // calculating left node aggregate for a split as a sum of left node aggregate of a
-              // lower split and the left bin aggregate of a bin where the split is a high split
-              leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
-                leftNodeAgg(featureIndex)(2 * splitIndex - 2)
-              leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
-                leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
-
-              // calculating right node aggregate for a split as a sum of right node aggregate of a
-              // higher split and the right bin aggregate of a bin where the split is a low split
-              rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
-                binData(shift + (2 *(numBins - 1 - splitIndex))) +
-                rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
-              rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
-                binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
-                  rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
-
-              splitIndex += 1
+            if (isMulticlassClassificationWithCategoricalFeatures) {
+              val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+              if (isFeatureContinuous) {
+                findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+              } else {
+                val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+                val isSpaceSufficientForAllCategoricalSplits
+                  = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+                if (isSpaceSufficientForAllCategoricalSplits) {
+                  findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+                } else {
+                  findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+                }
+              }
+            } else {
+              findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
             }
             featureIndex += 1
           }
+
           (leftNodeAgg, rightNodeAgg)
         case Regression =>
           // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
           // Iterate over all features.
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            // shift for this featureIndex
-            val shift = 3 * featureIndex * numBins
-            // left node aggregate for the lowest split
-            leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-            leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-            leftNodeAgg(featureIndex)(2) = binData(shift + 2)
-
-            // right node aggregate for the highest split
-            rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
-              binData(shift + (3 * (numBins - 1)))
-            rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
-              binData(shift + (3 * (numBins - 1)) + 1)
-            rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
-              binData(shift + (3 * (numBins - 1)) + 2)
-
-            // Iterate over all splits.
-            var splitIndex = 1
-            while (splitIndex < numBins - 1) {
-              // calculating left node aggregate for a split as a sum of left node aggregate of a
-              // lower split and the left bin aggregate of a bin where the split is a high split
-              leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
-                leftNodeAgg(featureIndex)(3 * splitIndex - 3)
-              leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
-                leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
-              leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
-                leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
-
-              // calculating right node aggregate for a split as a sum of right node aggregate of a
-              // higher split and the right bin aggregate of a bin where the split is a low split
-              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
-                binData(shift + (3 * (numBins - 1 - splitIndex))) +
-                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
-              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
-                binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
-                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
-              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
-                binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
-                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
-
-              splitIndex += 1
-            }
+            findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
             featureIndex += 1
           }
           (leftNodeAgg, rightNodeAgg)
@@ -876,8 +1091,8 @@ object DecisionTree extends Serializable with Logging {
      * Calculates information gain for all nodes splits.
      */
     def calculateGainsForAllNodeSplits(
-        leftNodeAgg: Array[Array[Double]],
-        rightNodeAgg: Array[Array[Double]],
+        leftNodeAgg: Array[Array[Array[Double]]],
+        rightNodeAgg: Array[Array[Array[Double]]],
         nodeImpurity: Double): Array[Array[InformationGainStats]] = {
       val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
 
@@ -918,7 +1133,22 @@ object DecisionTree extends Serializable with Logging {
         while (featureIndex < numFeatures) {
           // Iterate over all splits.
           var splitIndex = 0
-          while (splitIndex < numBins - 1) {
+          val maxSplitIndex: Int = {
+            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+            if (isFeatureContinuous) {
+              numBins - 1
+            } else { // Categorical feature
+              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+              val isSpaceSufficientForAllCategoricalSplits
+                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+              if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+                math.pow(2.0, featureCategories - 1).toInt - 1
+              } else { // Binary classification, or multiclass with insufficient bins
+                featureCategories
+              }
+            }
+          }
+          while (splitIndex < maxSplitIndex) {
             val gainStats = gains(featureIndex)(splitIndex)
             if (gainStats.gain > bestGainStats.gain) {
               bestGainStats = gainStats
@@ -944,9 +1174,23 @@ object DecisionTree extends Serializable with Logging {
     def getBinDataForNode(node: Int): Array[Double] = {
       strategy.algo match {
         case Classification =>
-          val shift = 2 * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
-          binsForNode
+          if (isMulticlassClassificationWithCategoricalFeatures) {
+            val shift = numClasses * node * numBins * numFeatures
+            val rightChildShift = numClasses * numBins * numFeatures * numNodes
+            val binsForNode = {
+              val leftChildData
+                = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+              val rightChildData
+                = binAggregates.slice(rightChildShift + shift,
+                  rightChildShift + shift + numClasses * numBins * numFeatures)
+              leftChildData ++ rightChildData
+            }
+            binsForNode
+          } else {
+            val shift = numClasses * node * numBins * numFeatures
+            val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+            binsForNode
+          }
         case Regression =>
           val shift = 3 * node * numBins * numFeatures
           val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
@@ -963,14 +1207,26 @@ object DecisionTree extends Serializable with Logging {
       val binsForNode: Array[Double] = getBinDataForNode(node)
       logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
       val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
-      logDebug("node impurity = " + parentNodeImpurity)
+      logDebug("parent node impurity = " + parentNodeImpurity)
       bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
       node += 1
     }
-
     bestSplits
   }
 
+  private def getElementsPerNode(
+      numFeatures: Int,
+      numBins: Int,
+      numClasses: Int,
+      isMulticlassClassificationWithCategoricalFeatures: Boolean,
+      algo: Algo): Int = {
+    algo match {
+      case Classification =>
+        if (isMulticlassClassificationWithCategoricalFeatures) {
+          2 * numClasses * numBins * numFeatures
+        } else {
+          numClasses * numBins * numFeatures
+        }
+      case Regression => 3 * numBins * numFeatures
+    }
+  }
+
   /**
    * Returns split and bins for decision tree calculation.
    * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -992,17 +1248,23 @@ object DecisionTree extends Serializable with Logging {
     val maxBins = strategy.maxBins
     val numBins = if (maxBins <= count) maxBins else count.toInt
     logDebug("numBins = " + numBins)
+    val isMulticlassClassification = strategy.isMulticlassClassification
+    logDebug("isMulticlassClassification = " + isMulticlassClassification)
 
     /*
-     * TODO: Add a require statement ensuring #bins is always greater than the categories.
+     * Ensure #bins is always greater than the categories. For multiclass classification,
+     * #bins should be greater than 2^(maxCategories - 1) - 1.
      * It's a limitation of the current implementation but a reasonable trade-off since features
      * with large number of categories get favored over continuous features.
      */
     if (strategy.categoricalFeaturesInfo.size > 0) {
       val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
-      require(numBins >= maxCategoriesForFeatures)
+      require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
+        "in categorical features")
     }
 
     // Calculate the number of samples for approximate quantile calculation.
     val requiredSamples = numBins * numBins
     val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1036,48 +1298,93 @@ object DecisionTree extends Serializable with Logging {
               val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
               splits(featureIndex)(index) = split
             }
-          } else {
-            val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
-            require(maxFeatureValue < numBins, "number of categories should be less than number " +
-              "of bins")
-
-            // For categorical variables, each bin is a category. The bins are sorted and they
-            // are ordered by calculating the centroid of their corresponding labels.
-            val centroidForCategories =
-              sampledInput.map(lp => (lp.features(featureIndex),lp.label))
-                .groupBy(_._1)
-                .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
-
-            // Check for missing categorical variables and putting them last in the sorted list.
-            val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
-            for (i <- 0 until maxFeatureValue) {
-              if (centroidForCategories.contains(i)) {
-                fullCentroidForCategories(i) = centroidForCategories(i)
-              } else {
-                fullCentroidForCategories(i) = Double.MaxValue
-              }
-            }
-
-            // bins sorted by centroids
-            val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
-            logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
-
-            var categoriesForSplit = List[Double]()
-            categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
-              case ((key, value), index) =>
-                categoriesForSplit = key :: categoriesForSplit
-                splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
-                  categoriesForSplit)
+          } else { // Categorical feature
+            val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+            val isSpaceSufficientForAllCategoricalSplits
+              = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+
+            // Use different bin/split calculation strategy for categorical features in multiclass
+            // classification that satisfy the space constraint
+            if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+              // 2^(featureCategories - 1) - 1 combinations
+              var index = 0
+              while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+                val categories: List[Double]
+                  = extractMultiClassCategories(index + 1, featureCategories)
+                splits(featureIndex)(index)
+                  = new Split(featureIndex, Double.MinValue, Categorical, categories)
                 bins(featureIndex)(index) = {
                   if (index == 0) {
-                    new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                      splits(featureIndex)(0), Categorical, key)
+                    new Bin(
+                      new DummyCategoricalSplit(featureIndex, Categorical),
+                      splits(featureIndex)(0),
+                      Categorical,
+                      Double.MinValue)
                   } else {
-                    new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
-                      Categorical, key)
+                    new Bin(
+                      splits(featureIndex)(index - 1),
+                      splits(featureIndex)(index),
+                      Categorical,
+                      Double.MinValue)
                   }
                 }
+                index += 1
+              }
+            } else {
+
+              val centroidForCategories = {
+                if (isMulticlassClassification) {
+                  // For categorical variables in multiclass classification,
+                  // each bin is a category. The bins are sorted and they
+                  // are ordered by calculating the impurity of their corresponding labels.
+                  sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+                    .groupBy(_._1)
+                    .mapValues(x => x.groupBy(_._2).mapValues(v => v.size.toDouble))
+                    .map(x => (x._1, x._2.values.toArray))
+                    .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
+                } else { // regression or binary classification
+                  // For categorical variables in regression and binary classification,
+                  // each bin is a category. The bins are sorted and they
+                  // are ordered by calculating the centroid of their corresponding labels.
+                  sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+                    .groupBy(_._1)
+                    .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+                }
+              }
+
+              logDebug("centriod for categories = " + centroidForCategories.mkString(","))
+
+              // Check for missing categorical variables and put them last in the sorted list.
+              val fullCentroidForCategories = scala.collection.mutable.Map[Double, Double]()
+              for (i <- 0 until featureCategories) {
+                if (centroidForCategories.contains(i)) {
+                  fullCentroidForCategories(i) = centroidForCategories(i)
+                } else {
+                  fullCentroidForCategories(i) = Double.MaxValue
+                }
+              }
+
+              // bins sorted by centroids
+              val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+              logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
+
+              var categoriesForSplit = List[Double]()
+              categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+                case ((key, value), index) =>
+                  categoriesForSplit = key :: categoriesForSplit
+                  splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
+                    Categorical, categoriesForSplit)
+                  bins(featureIndex)(index) = {
+                    if (index == 0) {
+                      new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+                        splits(featureIndex)(0), Categorical, key)
+                    } else {
+                      new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index),
+                        Categorical, key)
+                    }
+                  }
+              }
             }
           }
           featureIndex += 1
@@ -1107,4 +1414,29 @@ object DecisionTree extends Serializable with Logging {
         throw new UnsupportedOperationException("approximate histogram not supported yet.")
     }
   }
+
+  /**
+   * Extracts the list of eligible categories given an index, by reading off the
+   * positions of the ones in the binary representation of the input. If the binary
+   * representation of a number is 01101 (13), the output list is (3.0, 2.0, 0.0).
+   * The maxFeatureValue parameter gives the number of rightmost bits tested for ones.
+   */
+  private[tree] def extractMultiClassCategories(
+      input: Int,
+      maxFeatureValue: Int): List[Double] = {
+    var categories = List[Double]()
+    var j = 0
+    var bitShiftedInput = input
+    while (j < maxFeatureValue) {
+      if (bitShiftedInput % 2 != 0) {
+        // updating the list of categories.
+        categories = j.toDouble :: categories
+      }
+      // Right shift by one
+      bitShiftedInput = bitShiftedInput >> 1
+      j += 1
+    }
+    categories
+  }
+
 }
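
For reference, the bit-mask trick behind extractMultiClassCategories can be
exercised on its own: each integer in [1, 2^(featureCategories - 1) - 1]
encodes one left-child category subset, so all unordered splits can be
enumerated directly. A minimal, self-contained Scala sketch (hypothetical
names, not the MLlib entry point):

    object ExtractCategoriesSketch {
      // Positions of the ones in the binary representation of `input`,
      // scanning the `maxFeatureValue` rightmost bits.
      def extractCategories(input: Int, maxFeatureValue: Int): List[Double] = {
        var categories = List[Double]()
        var bits = input
        var j = 0
        while (j < maxFeatureValue) {
          if (bits % 2 != 0) categories = j.toDouble :: categories
          bits = bits >> 1
          j += 1
        }
        categories
      }

      def main(args: Array[String]): Unit = {
        // 13 = binary 01101 -> categories 0, 2 and 3 go to the left child.
        println(extractCategories(13, 10))  // List(3.0, 2.0, 0.0)
        // A 3-category feature has 2^(3 - 1) - 1 = 3 unordered splits.
        (1 to 3).foreach(i => println(extractCategories(i, 3)))
      }
    }

The mirrored prefix sums in findAggForOrderedFeatureClassification can be
checked the same way: leftNodeAgg(s) accumulates bins 0..s upward from the
lowest split, and rightNodeAgg(s) accumulates bins s+1..numBins-1 downward
from the highest, per class. A standalone REPL sketch with made-up bin counts
for a single feature (assumed layout, not the packed binData format):

    val numBins = 4
    val numClasses = 3
    // binCounts(bin)(class): per-bin label counts for one feature.
    val binCounts = Array(
      Array(2.0, 0.0, 1.0),
      Array(1.0, 3.0, 0.0),
      Array(0.0, 1.0, 2.0),
      Array(4.0, 0.0, 1.0))

    val leftAgg = Array.ofDim[Double](numBins - 1, numClasses)
    val rightAgg = Array.ofDim[Double](numBins - 1, numClasses)
    for (c <- 0 until numClasses) {
      leftAgg(0)(c) = binCounts(0)(c)
      rightAgg(numBins - 2)(c) = binCounts(numBins - 1)(c)
    }
    for (s <- 1 until numBins - 1; c <- 0 until numClasses) {
      leftAgg(s)(c) = leftAgg(s - 1)(c) + binCounts(s)(c)
      rightAgg(numBins - 2 - s)(c) =
        rightAgg(numBins - 1 - s)(c) + binCounts(numBins - 1 - s)(c)
    }
    // Every split partitions all the bins between the two children.
    assert(leftAgg(1).sum + rightAgg(1).sum == binCounts.map(_.sum).sum)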

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 1b505fd..7c027ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param algo classification or regression
  * @param impurity criterion used for information gain calculation
  * @param maxDepth maximum depth of the tree
+ * @param numClassesForClassification number of classes for classification. The default
+ *                                    value of 2 corresponds to binary classification.
  * @param maxBins maximum number of bins used for splitting features
  * @param quantileCalculationStrategy algorithm for calculating quantiles
  * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -44,7 +46,15 @@ class Strategy (
     val algo: Algo,
     val impurity: Impurity,
     val maxDepth: Int,
+    val numClassesForClassification: Int = 2,
     val maxBins: Int = 100,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable
+    val maxMemoryInMB: Int = 128) extends Serializable {
+
+  require(numClassesForClassification >= 2, "numClassesForClassification must be >= 2")
+  val isMulticlassClassification = numClassesForClassification > 2
+  val isMulticlassWithCategoricalFeatures
+    = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+
+}
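
As a usage sketch of the widened constructor (training call and data elided;
only the parameters shown in this patch are assumed):

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Three-class problem with two ternary categorical features; the new
    // numClassesForClassification parameter sits between maxDepth and maxBins.
    val strategy = new Strategy(
      Classification,
      Gini,
      maxDepth = 4,
      numClassesForClassification = 3,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))

    assert(strategy.isMulticlassClassification)
    assert(strategy.isMulticlassWithCategoricalFeatures)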

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 60f43e9..a0e2d91 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -31,23 +31,35 @@ object Entropy extends Impurity {
 
   /**
    * :: DeveloperApi ::
-   * entropy calculation
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
-   * @return entropy value
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
    */
   @DeveloperApi
-  override def calculate(c0: Double, c1: Double): Double = {
-    if (c0 == 0 || c1 == 0) {
-      0
-    } else {
-      val total = c0 + c1
-      val f0 = c0 / total
-      val f1 = c1 / total
-      -(f0 * log2(f0)) - (f1 * log2(f1))
+  override def calculate(counts: Array[Double], totalCount: Double): Double = {
+    val numClasses = counts.length
+    var impurity = 0.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val classCount = counts(classIndex)
+      if (classCount != 0) {
+        val freq = classCount / totalCount
+        impurity -= freq * log2(freq)
+      }
+      classIndex += 1
     }
+    impurity
   }
 
+  /**
+   * :: DeveloperApi ::
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Entropy.calculate")
 }
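
A worked check of the generalized entropy, written as standalone Scala
arithmetic rather than against the MLlib object:

    // Entropy over per-class counts, mirroring calculate(counts, totalCount).
    def log2(x: Double): Double = math.log(x) / math.log(2)

    def entropy(counts: Array[Double], totalCount: Double): Double =
      counts.filter(_ != 0).map { c =>
        val freq = c / totalCount
        -freq * log2(freq)
      }.sum

    // counts (4, 2, 2): -0.5*log2(0.5) - 2 * 0.25*log2(0.25) = 1.5 bits.
    println(entropy(Array(4.0, 2.0, 2.0), 8.0))  // 1.5
    println(entropy(Array(8.0, 0.0, 0.0), 8.0))  // 0.0 for a pure node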

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c51d76d..48144b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -30,23 +30,32 @@ object Gini extends Impurity {
 
   /**
    * :: DeveloperApi ::
-   * Gini coefficient calculation
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
-   * @return Gini coefficient value
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
    */
   @DeveloperApi
-  override def calculate(c0: Double, c1: Double): Double = {
-    if (c0 == 0 || c1 == 0) {
-      0
-    } else {
-      val total = c0 + c1
-      val f0 = c0 / total
-      val f1 = c1 / total
-      1 - f0 * f0 - f1 * f1
+  override def calculate(counts: Array[Double], totalCount: Double): Double = {
+    val numClasses = counts.length
+    var impurity = 1.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val freq = counts(classIndex) / totalCount
+      impurity -= freq * freq
+      classIndex += 1
     }
+    impurity
   }
 
+  /**
+   * :: DeveloperApi ::
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Gini.calculate")
 }
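
The same style of check for the multiclass Gini index:

    // Gini impurity 1 - sum(freq^2), mirroring calculate(counts, totalCount).
    def gini(counts: Array[Double], totalCount: Double): Double =
      1.0 - counts.map { c =>
        val freq = c / totalCount
        freq * freq
      }.sum

    // counts (4, 2, 2): 1 - (0.25 + 0.0625 + 0.0625) = 0.625.
    println(gini(Array(4.0, 2.0, 2.0), 8.0))  // 0.625
    println(gini(Array(8.0, 0.0, 0.0), 8.0))  // 0.0 for a pure node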

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 8eab247..7b2a932 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -28,13 +28,13 @@ trait Impurity extends Serializable {
 
   /**
    * :: DeveloperApi ::
-   * information calculation for binary classification
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
    * @return information value
    */
   @DeveloperApi
-  def calculate(c0 : Double, c1 : Double): Double
+  def calculate(counts: Array[Double], totalCount: Double): Double
 
   /**
    * :: DeveloperApi ::

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 47d0712..97149a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
  */
 @Experimental
 object Variance extends Impurity {
-   override def calculate(c0: Double, c1: Double): Double =
+
+  /**
+   * :: DeveloperApi ::
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
+   */
+  @DeveloperApi
+  override def calculate(counts: Array[Double], totalCount: Double): Double =
      throw new UnsupportedOperationException("Variance.calculate")
 
   /**

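Variance keeps the three-element regression signature (count, sum, sum of
squares) that feeds the per-bin regression aggregates above. A standalone
sketch of that formula, assuming the usual E[x^2] - E[x]^2 form:

    // Population variance from streaming aggregates.
    def variance(count: Double, sum: Double, sumSquares: Double): Double = {
      val mean = sum / count
      sumSquares / count - mean * mean
    }

    // Labels 1.0, 2.0, 3.0: count = 3, sum = 6, sumSquares = 14 -> 2/3.
    println(variance(3.0, 6.0, 14.0))  // 0.6666...
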
http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index 2d71e1e..c89c1e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
  * @param highSplit signifying the upper threshold for the continuous feature to be
  *                 accepted in the bin
  * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin
+ * @param category categorical label value accepted in the bin for binary classification
  */
 private[tree]
 case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index cc8a24c..fb12298 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
  * @param predict predicted value
+ * @param prob probability of the label (classification only)
  */
 @DeveloperApi
 class InformationGainStats(
@@ -34,10 +35,11 @@ class InformationGainStats(
     val impurity: Double,
     val leftImpurity: Double,
     val rightImpurity: Double,
-    val predict: Double) extends Serializable {
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
   }
 }
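
To illustrate the new predict/prob pair, a hand-built three-class example
(counts chosen arbitrarily; the tree code fills these in from the bin
aggregates):

    import org.apache.spark.mllib.tree.model.InformationGainStats

    // Combined left+right counts per class for a candidate split.
    val leftRightCounts = Array(3.0, 10.0, 7.0)
    val totalCount = leftRightCounts.sum

    // Predicted label = index of the largest count; prob = its frequency.
    val predict = leftRightCounts.indices.maxBy(leftRightCounts(_))
    val prob = leftRightCounts(predict) / totalCount

    val stats = new InformationGainStats(
      gain = 0.2, impurity = 0.9, leftImpurity = 0.6, rightImpurity = 0.7,
      predict = predict.toDouble, prob = prob)
    println(stats)  // ... predict = 1.000000, prob = 0.500000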