You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/09/09 04:00:02 UTC

git commit: [SPARK-3443][MLLIB] update default values of decision tree

Repository: spark
Updated Branches:
  refs/heads/master 7db53391f -> 50a4fa774


[SPARK-3443][MLLIB] update default values of decision tree

Adjust the default parameter values of the decision tree, based on the memory requirements discussed in https://github.com/apache/spark/pull/2125:

1. maxMemoryInMB: 128 -> 256
2. maxBins: 100 -> 32
3. maxDepth: 4 -> 5 (in the example code, the suggested values in the Scaladoc, and the PySpark `trainClassifier`/`trainRegressor` defaults)

jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #2322 from mengxr/tree-defaults and squashes the following commits:

cda453a [Xiangrui Meng] fix tests
5900445 [Xiangrui Meng] update comments
8c81831 [Xiangrui Meng] update default values of tree:


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/50a4fa77
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/50a4fa77
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/50a4fa77

Branch: refs/heads/master
Commit: 50a4fa774a0e8a17d7743b33ce8941bf4041144d
Parents: 7db5339
Author: Xiangrui Meng <me...@databricks.com>
Authored: Mon Sep 8 18:59:57 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Sep 8 18:59:57 2014 -0700

----------------------------------------------------------------------
 docs/mllib-decision-tree.md                       | 16 ++++++++--------
 .../spark/examples/mllib/JavaDecisionTree.java    |  2 +-
 .../spark/examples/mllib/DecisionTreeRunner.scala |  4 ++--
 .../apache/spark/mllib/tree/DecisionTree.scala    |  8 ++++----
 .../spark/mllib/tree/configuration/Strategy.scala |  6 +++---
 .../spark/mllib/tree/DecisionTreeSuite.scala      | 18 ++++--------------
 python/pyspark/mllib/tree.py                      |  4 ++--
 7 files changed, 24 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/docs/mllib-decision-tree.md
----------------------------------------------------------------------
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 1166d9c..12a6afb 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -80,7 +80,7 @@ The ordered splits create "bins" and the maximum number of such
 bins can be specified using the `maxBins` parameter.
 
 Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario
-since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of
+since the default `maxBins` value is 32). The tree algorithm automatically reduces the number of
 bins if the condition is not satisfied.
 
 **Categorical features**
@@ -117,7 +117,7 @@ all nodes at each level of the tree. This could lead to high memory requirements
 of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB`
 training parameter specifies the maximum amount of memory at the workers (twice as much at the
 master) to be allocated to the histogram computation. The default value is conservatively chosen to
-be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
+be 256 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
 for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
 subsequent level are split into smaller tasks.
 
@@ -167,7 +167,7 @@ val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
 val impurity = "gini"
 val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
 
 val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
@@ -213,7 +213,7 @@ Integer numClasses = 2;
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "gini";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model for classification.
 final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
@@ -250,7 +250,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
 # Train a DecisionTree model.
 #  Empty categoricalFeaturesInfo indicates all features are continuous.
 model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
-                                     impurity='gini', maxDepth=5, maxBins=100)
+                                     impurity='gini', maxDepth=5, maxBins=32)
 
 # Evaluate model on training instances and compute training error
 predictions = model.predict(data.map(lambda x: x.features))
@@ -293,7 +293,7 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache
 val categoricalFeaturesInfo = Map[Int, Int]()
 val impurity = "variance"
 val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
 
 val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
@@ -338,7 +338,7 @@ JavaSparkContext sc = new JavaSparkContext(sparkConf);
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "variance";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model.
 final DecisionTreeModel model = DecisionTree.trainRegressor(data,
@@ -380,7 +380,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
 # Train a DecisionTree model.
 #  Empty categoricalFeaturesInfo indicates all features are continuous.
 model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
-                                    impurity='variance', maxDepth=5, maxBins=100)
+                                    impurity='variance', maxDepth=5, maxBins=32)
 
 # Evaluate model on training instances and compute training error
 predictions = model.predict(data.map(lambda x: x.features))

http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
index e4468e8..1f82e3f 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
@@ -63,7 +63,7 @@ public final class JavaDecisionTree {
     HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
     String impurity = "gini";
     Integer maxDepth = 5;
-    Integer maxBins = 100;
+    Integer maxBins = 32;
 
     // Train a DecisionTree model for classification.
     final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,

http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index cf3d2cc..72c3ab4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -52,9 +52,9 @@ object DecisionTreeRunner {
       input: String = null,
       dataFormat: String = "libsvm",
       algo: Algo = Classification,
-      maxDepth: Int = 4,
+      maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
-      maxBins: Int = 100,
+      maxBins: Int = 32,
       fracTest: Double = 0.2)
 
   def main(args: Array[String]) {

http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index dd766c1..d1309b2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -330,9 +330,9 @@ object DecisionTree extends Serializable with Logging {
    *                 Supported values: "gini" (recommended) or "entropy".
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                  (suggested value: 4)
+   *                  (suggested value: 5)
    * @param maxBins maximum number of bins used for splitting features
-   *                 (suggested value: 100)
+   *                 (suggested value: 32)
    * @return DecisionTreeModel that can be used for prediction
    */
   def trainClassifier(
@@ -374,9 +374,9 @@ object DecisionTree extends Serializable with Logging {
    *                 Supported values: "variance".
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                  (suggested value: 4)
+   *                  (suggested value: 5)
    * @param maxBins maximum number of bins used for splitting features
-   *                 (suggested value: 100)
+   *                 (suggested value: 32)
    * @return DecisionTreeModel that can be used for prediction
    */
   def trainRegressor(

http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index cfc8192..23f74d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -50,7 +50,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                                1, 2, ... , k-1. It's important to note that features are
  *                                zero-indexed.
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
- *                      128 MB.
+ *                      256 MB.
  */
 @Experimental
 class Strategy (
@@ -58,10 +58,10 @@ class Strategy (
     val impurity: Impurity,
     val maxDepth: Int,
     val numClassesForClassification: Int = 2,
-    val maxBins: Int = 100,
+    val maxBins: Int = 32,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable {
+    val maxMemoryInMB: Int = 256) extends Serializable {
 
   if (algo == Classification) {
     require(numClassesForClassification >= 2)

http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8e556c9..69482f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
-
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   def validateClassifier(
@@ -353,8 +352,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
@@ -381,8 +378,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -410,8 +405,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -439,8 +432,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -464,8 +455,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     // Train a 1-node model
     val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
@@ -600,7 +589,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3)
+      numClassesForClassification = 3, maxBins = 100)
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
@@ -626,7 +615,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+      numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(metadata.isUnordered(featureIndex = 0))
@@ -652,7 +641,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+      numClassesForClassification = 3, maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(!metadata.isUnordered(featureIndex = 0))

http://git-wip-us.apache.org/repos/asf/spark/blob/50a4fa77/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index a2fade6..ccc000a 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,7 @@ class DecisionTree(object):
 
     @staticmethod
     def trainClassifier(data, numClasses, categoricalFeaturesInfo,
-                        impurity="gini", maxDepth=4, maxBins=100):
+                        impurity="gini", maxDepth=5, maxBins=32):
         """
         Train a DecisionTreeModel for classification.
 
@@ -170,7 +170,7 @@ class DecisionTree(object):
 
     @staticmethod
     def trainRegressor(data, categoricalFeaturesInfo,
-                       impurity="variance", maxDepth=4, maxBins=100):
+                       impurity="variance", maxDepth=5, maxBins=32):
         """
         Train a DecisionTreeModel for regression.
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org