You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/09/16 02:43:31 UTC
git commit: [SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Repository: spark
Updated Branches:
refs/heads/master 983d6a9c4 -> fdb302f49
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Added minInstancesPerNode, minInfoGain params to:
* DecisionTreeRunner.scala example
* Python API (tree.py)
Also:
* Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements"
* small style fixes
CC: mengxr
Author: qiping.lqp <qi...@alibaba-inc.com>
Author: Joseph K. Bradley <jo...@gmail.com>
Author: chouqin <li...@gmail.com>
Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits:
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fdb302f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fdb302f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fdb302f4
Branch: refs/heads/master
Commit: fdb302f49c021227026909bdcdade7496059013f
Parents: 983d6a9
Author: qiping.lqp <qi...@alibaba-inc.com>
Authored: Mon Sep 15 17:43:26 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Sep 15 17:43:26 2014 -0700
----------------------------------------------------------------------
.../spark/examples/mllib/DecisionTreeRunner.scala | 13 ++++++++++++-
.../spark/mllib/api/python/PythonMLLibAPI.scala | 8 ++++++--
.../org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++--
.../spark/mllib/tree/configuration/Strategy.scala | 2 ++
.../org/apache/spark/mllib/tree/model/Predict.scala | 6 +-----
.../apache/spark/mllib/tree/DecisionTreeSuite.scala | 4 ++--
python/pyspark/mllib/tree.py | 16 ++++++++++++----
7 files changed, 37 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 72c3ab4..4683e6e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -55,6 +55,8 @@ object DecisionTreeRunner {
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
fracTest: Double = 0.2)
def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
- numClassesForClassification = numClasses)
+ numClassesForClassification = numClasses,
+ minInstancesPerNode = params.minInstancesPerNode,
+ minInfoGain = params.minInfoGain)
val model = DecisionTree.train(training, strategy)
println(model)
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 4343124..fa0fa69 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
impurityStr: String,
maxDepth: Int,
- maxBins: Int): DecisionTreeModel = {
+ maxBins: Int,
+ minInstancesPerNode: Int,
+ minInfoGain: Double): DecisionTreeModel = {
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+ categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ minInstancesPerNode = minInstancesPerNode,
+ minInfoGain = minInfoGain)
DecisionTree.train(data, strategy)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 56bb881..c7f2576 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
var groupIndex = 0
var doneTraining = true
while (groupIndex < numGroups) {
- val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+ val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
topNode, splits, bins, timer, numGroups, groupIndex)
doneTraining = doneTraining && doneTrainingGroup
groupIndex += 1
@@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)
- require(predict.isDefined, "must calculate predict for each node")
+ assert(predict.isDefined, "must calculate predict for each node")
(bestSplit, bestSplitStats, predict.get)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 31d1e8a..caaccbf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -77,6 +77,8 @@ class Strategy (
}
require(minInstancesPerNode >= 1,
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 6fac2be..d8476b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,18 +17,14 @@
package org.apache.spark.mllib.tree.model
-import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
-@DeveloperApi
private[tree] class Predict(
val predict: Double,
- val prob: Double = 0.0) extends Serializable{
+ val prob: Double = 0.0) extends Serializable {
override def toString = {
"predict = %f, prob = %f".format(predict, prob)
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 1bd7ea0..2b2e579 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
- test("don't choose split that doesn't satisfy min instance per node requirements") {
- // if a split doesn't satisfy min instances per node requirements,
+ test("do not choose split that does not satisfy min instance per node requirements") {
+ // if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
http://git-wip-us.apache.org/repos/asf/spark/blob/fdb302f4/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index ccc000a..5b13ab6 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,8 @@ class DecisionTree(object):
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32):
+ impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for classification.
@@ -154,6 +155,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -164,13 +168,14 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
- impurity="variance", maxDepth=5, maxBins=32):
+ impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for regression.
@@ -185,6 +190,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -195,7 +203,7 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org