You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this text.
Posted to commits@spark.apache.org by me...@apache.org on 2014/09/16 02:43:31 UTC

git commit: [SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API

Repository: spark
Updated Branches:
  refs/heads/master 983d6a9c4 -> fdb302f49


[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API

Added the minInstancesPerNode and minInfoGain parameters to:
* DecisionTreeRunner.scala example
* Python API (tree.py)

Also:
* Fixed a typo in the tree suite test "do not choose split that does not satisfy min instance per node requirements"
* Small style fixes

CC: mengxr

Author: qiping.lqp <qi...@alibaba-inc.com>
Author: Joseph K. Bradley <jo...@gmail.com>
Author: chouqin <li...@gmail.com>

Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits:

61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fdb302f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fdb302f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fdb302f4

Branch: refs/heads/master
Commit: fdb302f49c021227026909bdcdade7496059013f
Parents: 983d6a9
Author: qiping.lqp <qi...@alibaba-inc.com>
Authored: Mon Sep 15 17:43:26 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Sep 15 17:43:26 2014 -0700

----------------------------------------------------------------------
 .../spark/examples/mllib/DecisionTreeRunner.scala   | 13 ++++++++++++-
 .../spark/mllib/api/python/PythonMLLibAPI.scala     |  8 ++++++--
 .../org/apache/spark/mllib/tree/DecisionTree.scala  |  4 ++--
 .../spark/mllib/tree/configuration/Strategy.scala   |  2 ++
 .../org/apache/spark/mllib/tree/model/Predict.scala |  6 +-----
 .../apache/spark/mllib/tree/DecisionTreeSuite.scala |  4 ++--
 python/pyspark/mllib/tree.py                        | 16 ++++++++++++----
 7 files changed, 37 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 72c3ab4..4683e6e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -55,6 +55,8 @@ object DecisionTreeRunner {
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 32,
+      minInstancesPerNode: Int = 1,
+      minInfoGain: Double = 0.0,
       fracTest: Double = 0.2)
 
   def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
       opt[Int]("maxBins")
         .text(s"max number of bins, default: ${defaultParams.maxBins}")
         .action((x, c) => c.copy(maxBins = x))
+      opt[Int]("minInstancesPerNode")
+        .text(s"min number of instances required at child nodes to create the parent split," +
+        s" default: ${defaultParams.minInstancesPerNode}")
+        .action((x, c) => c.copy(minInstancesPerNode = x))
+      opt[Double]("minInfoGain")
+        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+        .action((x, c) => c.copy(minInfoGain = x))
       opt[Double]("fracTest")
         .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
           impurity = impurityCalculator,
           maxDepth = params.maxDepth,
           maxBins = params.maxBins,
-          numClassesForClassification = numClasses)
+          numClassesForClassification = numClasses,
+          minInstancesPerNode = params.minInstancesPerNode,
+          minInfoGain = params.minInfoGain)
     val model = DecisionTree.train(training, strategy)
 
     println(model)

http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 4343124..fa0fa69 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
       categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
       impurityStr: String,
       maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = {
+      maxBins: Int,
+      minInstancesPerNode: Int,
+      minInfoGain: Double): DecisionTreeModel = {
 
     val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
 
@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
       maxDepth = maxDepth,
       numClassesForClassification = numClasses,
       maxBins = maxBins,
-      categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+      categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+      minInstancesPerNode = minInstancesPerNode,
+      minInfoGain = minInfoGain)
 
     DecisionTree.train(data, strategy)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 56bb881..c7f2576 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
       var groupIndex = 0
       var doneTraining = true
       while (groupIndex < numGroups) {
-        val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+        val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
           topNode, splits, bins, timer, numGroups, groupIndex)
         doneTraining = doneTraining && doneTrainingGroup
         groupIndex += 1
@@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
       }
     }.maxBy(_._2.gain)
 
-    require(predict.isDefined, "must calculate predict for each node")
+    assert(predict.isDefined, "must calculate predict for each node")
 
     (bestSplit, bestSplitStats, predict.get)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 31d1e8a..caaccbf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -77,6 +77,8 @@ class Strategy (
   }
   require(minInstancesPerNode >= 1,
     s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+  require(maxMemoryInMB <= 10240,
+    s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
 
   val isMulticlassClassification =
     algo == Classification && numClassesForClassification > 2

http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 6fac2be..d8476b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,18 +17,14 @@
 
 package org.apache.spark.mllib.tree.model
 
-import org.apache.spark.annotation.DeveloperApi
-
 /**
- * :: DeveloperApi ::
  * Predicted value for a node
  * @param predict predicted value
  * @param prob probability of the label (classification only)
  */
-@DeveloperApi
 private[tree] class Predict(
     val predict: Double,
-    val prob: Double = 0.0) extends Serializable{
+    val prob: Double = 0.0) extends Serializable {
 
   override def toString = {
     "predict = %f, prob = %f".format(predict, prob)

http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 1bd7ea0..2b2e579 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(gain == InformationGainStats.invalidInformationGainStats)
   }
 
-  test("don't choose split that doesn't satisfy min instance per node requirements") {
-    // if a split doesn't satisfy min instances per node requirements,
+  test("do not choose split that does not satisfy min instance per node requirements") {
+    // if a split does not satisfy min instances per node requirements,
     // this split is invalid, even though the information gain of split is large.
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))

http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index ccc000a..5b13ab6 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,8 @@ class DecisionTree(object):
 
     @staticmethod
     def trainClassifier(data, numClasses, categoricalFeaturesInfo,
-                        impurity="gini", maxDepth=5, maxBins=32):
+                        impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+                        minInfoGain=0.0):
         """
         Train a DecisionTreeModel for classification.
 
@@ -154,6 +155,9 @@ class DecisionTree(object):
                          E.g., depth 0 means 1 leaf node.
                          Depth 1 means 1 internal node + 2 leaf nodes.
         :param maxBins: Number of bins used for finding splits at each node.
+        :param minInstancesPerNode: Min number of instances required at child nodes to create
+                                    the parent split
+        :param minInfoGain: Min info gain required to create a split
         :return: DecisionTreeModel
         """
         sc = data.context
@@ -164,13 +168,14 @@ class DecisionTree(object):
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
             dataBytes._jrdd, "classification",
             numClasses, categoricalFeaturesInfoJMap,
-            impurity, maxDepth, maxBins)
+            impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)
 
     @staticmethod
     def trainRegressor(data, categoricalFeaturesInfo,
-                       impurity="variance", maxDepth=5, maxBins=32):
+                       impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+                       minInfoGain=0.0):
         """
         Train a DecisionTreeModel for regression.
 
@@ -185,6 +190,9 @@ class DecisionTree(object):
                          E.g., depth 0 means 1 leaf node.
                          Depth 1 means 1 internal node + 2 leaf nodes.
         :param maxBins: Number of bins used for finding splits at each node.
+        :param minInstancesPerNode: Min number of instances required at child nodes to create
+                                    the parent split
+        :param minInfoGain: Min info gain required to create a split
         :return: DecisionTreeModel
         """
         sc = data.context
@@ -195,7 +203,7 @@ class DecisionTree(object):
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
             dataBytes._jrdd, "regression",
             0, categoricalFeaturesInfoJMap,
-            impurity, maxDepth, maxBins)
+            impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org