You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2017/11/10 11:17:31 UTC
spark git commit: [SPARK-20199][ML] : Provided featureSubsetStrategy
to GBTClassifier and GBTRegressor
Repository: spark
Updated Branches:
refs/heads/master 28ab5bf59 -> 9b9827759
[SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor
## What changes were proposed in this pull request?
Provided featureSubsetStrategy to GBTClassifier and GBTRegressor:
a) Moved featureSubsetStrategy to TreeEnsembleParams
b) Changed GBTClassifier and GBTRegressor to pass featureSubsetStrategy through to each boosted tree, e.g.
val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
## How was this patch tested?
a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy to its GBTClassifier
b) Added test cases in GBTClassifierSuite and GBTRegressorSuite
Author: Pralabh Kumar <pr...@gmail.com>
Closes #18118 from pralabhkumar/develop.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b982775
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b982775
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b982775
Branch: refs/heads/master
Commit: 9b9827759af2ca3eea146a6032f9165f640ce152
Parents: 28ab5bf
Author: Pralabh Kumar <pr...@gmail.com>
Authored: Fri Nov 10 13:17:25 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Fri Nov 10 13:17:25 2017 +0200
----------------------------------------------------------------------
.../GradientBoostedTreeClassifierExample.scala | 1 +
.../spark/ml/classification/GBTClassifier.scala | 9 ++-
.../classification/RandomForestClassifier.scala | 2 +-
.../ml/regression/DecisionTreeRegressor.scala | 8 +-
.../spark/ml/regression/GBTRegressor.scala | 9 ++-
.../ml/regression/RandomForestRegressor.scala | 2 +-
.../ml/tree/impl/DecisionTreeMetadata.scala | 4 +-
.../ml/tree/impl/GradientBoostedTrees.scala | 25 +++---
.../org/apache/spark/ml/tree/treeParams.scala | 82 ++++++++++----------
.../spark/mllib/tree/GradientBoostedTrees.scala | 4 +-
.../apache/spark/mllib/tree/RandomForest.scala | 2 +-
.../ml/classification/GBTClassifierSuite.scala | 29 +++++++
.../spark/ml/regression/GBTRegressorSuite.scala | 29 +++++++
.../tree/impl/GradientBoostedTreesSuite.scala | 4 +-
14 files changed, 146 insertions(+), 64 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
index 9a39acf..3656773 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
@@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample {
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setMaxIter(10)
+ .setFeatureSubsetStrategy("auto")
// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3da809c..f11bc1d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setStepSize(value: Double): this.type = set(stepSize, value)
+ /** @group setParam */
+ @Since("2.3.0")
+ override def setFeatureSubsetStrategy(value: String): this.type =
+ set(featureSubsetStrategy, value)
+
// Parameters from GBTClassifierParams:
/** @group setParam */
@@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") (
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
- seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
+ seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
instr.logNumFeatures(numFeatures)
instr.logNumClasses(numClasses)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
- $(seed))
+ $(seed), $(featureSubsetStrategy))
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index ab4c235..78a4972 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
- RandomForestParams.supportedFeatureSubsetStrategies
+ TreeEnsembleParams.supportedFeatureSubsetStrategies
@Since("2.0.0")
override def load(path: String): RandomForestClassifier = super.load(path)
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 01c5cc1..0291a57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
}
/** (private[ml]) Train a decision tree on an RDD */
- private[ml] def train(data: RDD[LabeledPoint],
- oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+ private[ml] def train(
+ data: RDD[LabeledPoint],
+ oldStrategy: OldStrategy,
+ featureSubsetStrategy: String): DecisionTreeRegressionModel = {
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+ val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 08d175c..f41d15b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
+ /** @group setParam */
+ @Since("2.3.0")
+ override def setFeatureSubsetStrategy(value: String): this.type =
+ set(featureSubsetStrategy, value)
+
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
- seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
+ seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
instr.logNumFeatures(numFeatures)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
- $(seed))
+ $(seed), $(featureSubsetStrategy))
val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a58da50..200b234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
- RandomForestParams.supportedFeatureSubsetStrategies
+ TreeEnsembleParams.supportedFeatureSubsetStrategies
@Since("2.0.0")
override def load(path: String): RandomForestRegressor = super.load(path)
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8a9dcb4..53189e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -22,7 +22,7 @@ import scala.util.Try
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.tree.RandomForestParams
+import org.apache.spark.ml.tree.TreeEnsembleParams
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
case Some(value) => math.ceil(value * numFeatures).toInt
case _ => throw new IllegalArgumentException(s"Supported values:" +
- s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" (0.0-1.0], [1-n].")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index e32447a..bd8c9af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging {
def run(
input: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
- seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ seed: Long,
+ featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
+ seed, featureSubsetStrategy)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
- seed)
+ seed, featureSubsetStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
@@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging {
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
- seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ seed: Long,
+ featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
+ validate = true, seed, featureSubsetStrategy)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging {
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true, seed)
+ validate = true, seed, featureSubsetStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging {
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
validate: Boolean,
- seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ seed: Long,
+ featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging {
val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
+
// Prepare strategy for individual trees, which use regression with variance impurity.
val treeStrategy = boostingStrategy.treeStrategy.copy
val validationTol = boostingStrategy.validationTol
@@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
val firstTree = new DecisionTreeRegressor().setSeed(seed)
- val firstTreeModel = firstTree.train(input, treeStrategy)
+ val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
@@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
+
val dt = new DecisionTreeRegressor().setSeed(seed + m)
- val model = dt.train(data, treeStrategy)
+ val model = dt.train(data, treeStrategy, featureSubsetStrategy)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 47079d9..81b6222 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
}
}
+private[spark] object TreeEnsembleParams {
+ // These options should be lowercase.
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
+}
+
/**
* Parameters for Decision Tree-based ensemble algorithms.
*
@@ -359,38 +365,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
oldImpurity: OldImpurity): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
}
-}
-
-/**
- * Parameters for Random Forest algorithms.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- *
- * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
- * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
- * are a bit different.
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
-
- setDefault(numTrees -> 20)
-
- /**
- * @deprecated This method is deprecated and will be removed in 3.0.0.
- * @group setParam
- */
- @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
- def setNumTrees(value: Int): this.type = set(numTrees, value)
-
- /** @group getParam */
- final def getNumTrees: Int = $(numTrees)
/**
* The number of features to consider for splits at each tree node.
@@ -420,10 +394,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
*/
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
"The number of features to consider for splits at each tree node." +
- s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
+ s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
s", (0.0-1.0], [1-n].",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(
+ TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
value.toLowerCase(Locale.ROOT))
|| Try(value.toInt).filter(_ > 0).isSuccess
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
@@ -431,7 +405,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
setDefault(featureSubsetStrategy -> "auto")
/**
- * @deprecated This method is deprecated and will be removed in 3.0.0.
+ * @deprecated This method is deprecated and will be removed in 3.0.0
* @group setParam
*/
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
@@ -441,10 +415,38 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
}
-private[spark] object RandomForestParams {
- // These options should be lowercase.
- final val supportedFeatureSubsetStrategies: Array[String] =
- Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
+
+
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ *
+ * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+ * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
+ * are a bit different.
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
+
+ /**
+ * @deprecated This method is deprecated and will be removed in 3.0.0.
+ * @group setParam
+ */
+ @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+ /** @group getParam */
+ final def getNumTrees: Int = $(numTrees)
}
private[ml] trait RandomForestClassifierParams
@@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
setDefault(maxIter -> 20, stepSize -> 0.1)
+ setDefault(featureSubsetStrategy -> "all")
+
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(
categoricalFeatures: Map[Int, Int],
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index df2c1b0..d24d8da 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] (
val algo = boostingStrategy.treeStrategy.algo
val (trees, treeWeights) = NewGBT.run(input.map { point =>
NewLabeledPoint(point.label, point.features.asML)
- }, boostingStrategy, seed.toLong)
+ }, boostingStrategy, seed.toLong, "all")
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
@@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] (
NewLabeledPoint(point.label, point.features.asML)
}, validationInput.map { point =>
NewLabeledPoint(point.label, point.features.asML)
- }, boostingStrategy, seed.toLong)
+ }, boostingStrategy, seed.toLong, "all")
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index d1331a5..a8c5286 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -23,7 +23,7 @@ import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
+import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 8000143..978f89c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
assert(gbt.getPredictionCol === "prediction")
assert(gbt.getRawPredictionCol === "rawPrediction")
assert(gbt.getProbabilityCol === "probability")
+ assert(gbt.getFeatureSubsetStrategy === "all")
val df = trainData.toDF()
val model = gbt.fit(df)
model.transform(df)
@@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.getPredictionCol === "prediction")
assert(model.getRawPredictionCol === "rawPrediction")
assert(model.getProbabilityCol === "probability")
+ assert(model.getFeatureSubsetStrategy === "all")
assert(model.hasParent)
MLTestingUtils.checkCopyAndUids(gbt, model)
@@ -357,6 +359,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
}
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature subset strategy
+ /////////////////////////////////////////////////////////////////////////////
+ test("Tests of feature subset strategy") {
+ val numClasses = 2
+ val gbt = new GBTClassifier()
+ .setSeed(123)
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setFeatureSubsetStrategy("all")
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+
+ // GBT with different featureSubsetStrategy
+ val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+ val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+ val mostIF = importanceFeatures.argmax
+ assert(mostImportantFeature !== mostIF)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 2da25f7..ecbb571 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -166,6 +166,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
}
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature subset strategy
+ /////////////////////////////////////////////////////////////////////////////
+ test("Tests of feature subset strategy") {
+ val numClasses = 2
+ val gbt = new GBTRegressor()
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setSeed(123)
+ .setFeatureSubsetStrategy("all")
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+
+ // GBT with different featureSubsetStrategy
+ val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+ val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+ val mostIF = importanceFeatures.argmax
+ assert(mostImportantFeature !== mostIF)
+ }
+
+
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index 4109a29..366d5ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val (validateTrees, validateTreeWeights) = GradientBoostedTrees
- .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+ .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
val numTrees = validateTrees.length
assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+ val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org