You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2017/11/10 11:17:31 UTC

spark git commit: [SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor

Repository: spark
Updated Branches:
  refs/heads/master 28ab5bf59 -> 9b9827759


[SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor

## What changes were proposed in this pull request?

Provided featureSubsetStrategy to GBTClassifier and GBTRegressor:
a) Moved featureSubsetStrategy to TreeEnsembleParams.
b) Changed GBTClassifier and GBTRegressor to pass featureSubsetStrategy through to each base tree, e.g.:
val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)

## How was this patch tested?
a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy to its GBTClassifier.

b) Added test cases in GBTClassifierSuite and GBTRegressorSuite.

Author: Pralabh Kumar <pr...@gmail.com>

Closes #18118 from pralabhkumar/develop.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b982775
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b982775
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b982775

Branch: refs/heads/master
Commit: 9b9827759af2ca3eea146a6032f9165f640ce152
Parents: 28ab5bf
Author: Pralabh Kumar <pr...@gmail.com>
Authored: Fri Nov 10 13:17:25 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Fri Nov 10 13:17:25 2017 +0200

----------------------------------------------------------------------
 .../GradientBoostedTreeClassifierExample.scala  |  1 +
 .../spark/ml/classification/GBTClassifier.scala |  9 ++-
 .../classification/RandomForestClassifier.scala |  2 +-
 .../ml/regression/DecisionTreeRegressor.scala   |  8 +-
 .../spark/ml/regression/GBTRegressor.scala      |  9 ++-
 .../ml/regression/RandomForestRegressor.scala   |  2 +-
 .../ml/tree/impl/DecisionTreeMetadata.scala     |  4 +-
 .../ml/tree/impl/GradientBoostedTrees.scala     | 25 +++---
 .../org/apache/spark/ml/tree/treeParams.scala   | 82 ++++++++++----------
 .../spark/mllib/tree/GradientBoostedTrees.scala |  4 +-
 .../apache/spark/mllib/tree/RandomForest.scala  |  2 +-
 .../ml/classification/GBTClassifierSuite.scala  | 29 +++++++
 .../spark/ml/regression/GBTRegressorSuite.scala | 29 +++++++
 .../tree/impl/GradientBoostedTreesSuite.scala   |  4 +-
 14 files changed, 146 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
index 9a39acf..3656773 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala
@@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample {
       .setLabelCol("indexedLabel")
       .setFeaturesCol("indexedFeatures")
       .setMaxIter(10)
+      .setFeatureSubsetStrategy("auto")
 
     // Convert indexed labels back to original labels.
     val labelConverter = new IndexToString()

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3da809c..f11bc1d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") (
   @Since("1.4.0")
   override def setStepSize(value: Double): this.type = set(stepSize, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  override def setFeatureSubsetStrategy(value: String): this.type =
+    set(featureSubsetStrategy, value)
+
   // Parameters from GBTClassifierParams:
 
   /** @group setParam */
@@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") (
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
+      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
     instr.logNumFeatures(numFeatures)
     instr.logNumClasses(numClasses)
 
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
-      $(seed))
+      $(seed), $(featureSubsetStrategy))
     val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
     instr.logSuccess(m)
     m

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index ab4c235..78a4972 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi
   /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
   @Since("1.4.0")
   final val supportedFeatureSubsetStrategies: Array[String] =
-    RandomForestParams.supportedFeatureSubsetStrategies
+    TreeEnsembleParams.supportedFeatureSubsetStrategies
 
   @Since("2.0.0")
   override def load(path: String): RandomForestClassifier = super.load(path)

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 01c5cc1..0291a57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   }
 
   /** (private[ml]) Train a decision tree on an RDD */
-  private[ml] def train(data: RDD[LabeledPoint],
-      oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+  private[ml] def train(
+      data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy,
+      featureSubsetStrategy: String): DecisionTreeRegressionModel = {
     val instr = Instrumentation.create(this, data)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 08d175c..f41d15b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setLossType(value: String): this.type = set(lossType, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  override def setFeatureSubsetStrategy(value: String): this.type =
+    set(featureSubsetStrategy, value)
+
   override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
+      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
     instr.logNumFeatures(numFeatures)
 
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
-      $(seed))
+      $(seed), $(featureSubsetStrategy))
     val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
     instr.logSuccess(m)
     m

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a58da50..200b234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
   /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
   @Since("1.4.0")
   final val supportedFeatureSubsetStrategies: Array[String] =
-    RandomForestParams.supportedFeatureSubsetStrategies
+    TreeEnsembleParams.supportedFeatureSubsetStrategies
 
   @Since("2.0.0")
   override def load(path: String): RandomForestRegressor = super.load(path)

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8a9dcb4..53189e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -22,7 +22,7 @@ import scala.util.Try
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.tree.RandomForestParams
+import org.apache.spark.ml.tree.TreeEnsembleParams
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.Strategy
@@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
             Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
               case Some(value) => math.ceil(value * numFeatures).toInt
               case _ => throw new IllegalArgumentException(s"Supported values:" +
-                s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
+                s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
                 s" (0.0-1.0], [1-n].")
             }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index e32447a..bd8c9af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging {
   def run(
       input: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
-      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      seed: Long,
+      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
-        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
+        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
+          seed, featureSubsetStrategy)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
-          seed)
+          seed, featureSubsetStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
     }
@@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging {
       input: RDD[LabeledPoint],
       validationInput: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
-      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      seed: Long,
+      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
-        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
+        GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
+          validate = true, seed, featureSubsetStrategy)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(
@@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging {
         val remappedValidationInput = validationInput.map(
           x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
-          validate = true, seed)
+          validate = true, seed, featureSubsetStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
@@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging {
       validationInput: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
       validate: Boolean,
-      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      seed: Long,
+      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging {
     val baseLearnerWeights = new Array[Double](numIterations)
     val loss = boostingStrategy.loss
     val learningRate = boostingStrategy.learningRate
+
     // Prepare strategy for individual trees, which use regression with variance impurity.
     val treeStrategy = boostingStrategy.treeStrategy.copy
     val validationTol = boostingStrategy.validationTol
@@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging {
     // Initialize tree
     timer.start("building tree 0")
     val firstTree = new DecisionTreeRegressor().setSeed(seed)
-    val firstTreeModel = firstTree.train(input, treeStrategy)
+    val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight
@@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging {
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
+
       val dt = new DecisionTreeRegressor().setSeed(seed + m)
-      val model = dt.train(data, treeStrategy)
+      val model = dt.train(data, treeStrategy, featureSubsetStrategy)
       timer.stop(s"building tree $m")
       // Update partial model
       baseLearners(m) = model

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 47079d9..81b6222 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
   }
 }
 
+private[spark] object TreeEnsembleParams {
+  // These options should be lowercase.
+  final val supportedFeatureSubsetStrategies: Array[String] =
+    Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
+}
+
 /**
  * Parameters for Decision Tree-based ensemble algorithms.
  *
@@ -359,38 +365,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
       oldImpurity: OldImpurity): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
   }
-}
-
-/**
- * Parameters for Random Forest algorithms.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
-  /**
-   * Number of trees to train (>= 1).
-   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
-   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
-   * (default = 20)
-   *
-   * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
-   * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
-   * are a bit different.
-   * @group param
-   */
-  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
-    ParamValidators.gtEq(1))
-
-  setDefault(numTrees -> 20)
-
-  /**
-   * @deprecated This method is deprecated and will be removed in 3.0.0.
-   * @group setParam
-   */
-  @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
-  def setNumTrees(value: Int): this.type = set(numTrees, value)
-
-  /** @group getParam */
-  final def getNumTrees: Int = $(numTrees)
 
   /**
    * The number of features to consider for splits at each tree node.
@@ -420,10 +394,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
    */
   final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
     "The number of features to consider for splits at each tree node." +
-      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
+      s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
       s", (0.0-1.0], [1-n].",
     (value: String) =>
-      RandomForestParams.supportedFeatureSubsetStrategies.contains(
+      TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
         value.toLowerCase(Locale.ROOT))
       || Try(value.toInt).filter(_ > 0).isSuccess
       || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
@@ -431,7 +405,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
   setDefault(featureSubsetStrategy -> "auto")
 
   /**
-   * @deprecated This method is deprecated and will be removed in 3.0.0.
+   * @deprecated This method is deprecated and will be removed in 3.0.0
    * @group setParam
    */
   @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
@@ -441,10 +415,38 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
   final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
 }
 
-private[spark] object RandomForestParams {
-  // These options should be lowercase.
-  final val supportedFeatureSubsetStrategies: Array[String] =
-    Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
+
+
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+  /**
+   * Number of trees to train (>= 1).
+   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
+   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
+   * (default = 20)
+   *
+   * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+   * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
+   * are a bit different.
+   * @group param
+   */
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+    ParamValidators.gtEq(1))
+
+  setDefault(numTrees -> 20)
+
+  /**
+   * @deprecated This method is deprecated and will be removed in 3.0.0.
+   * @group setParam
+   */
+  @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
+  def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+  /** @group getParam */
+  final def getNumTrees: Int = $(numTrees)
 }
 
 private[ml] trait RandomForestClassifierParams
@@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
 
   setDefault(maxIter -> 20, stepSize -> 0.1)
 
+  setDefault(featureSubsetStrategy -> "all")
+
   /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
   private[ml] def getOldBoostingStrategy(
       categoricalFeatures: Map[Int, Int],

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index df2c1b0..d24d8da 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] (
     val algo = boostingStrategy.treeStrategy.algo
     val (trees, treeWeights) = NewGBT.run(input.map { point =>
       NewLabeledPoint(point.label, point.features.asML)
-    }, boostingStrategy, seed.toLong)
+    }, boostingStrategy, seed.toLong, "all")
     new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
   }
 
@@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] (
       NewLabeledPoint(point.label, point.features.asML)
     }, validationInput.map { point =>
       NewLabeledPoint(point.label, point.features.asML)
-    }, boostingStrategy, seed.toLong)
+    }, boostingStrategy, seed.toLong, "all")
     new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index d1331a5..a8c5286 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -23,7 +23,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
+import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
 import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 8000143..978f89c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(gbt.getPredictionCol === "prediction")
     assert(gbt.getRawPredictionCol === "rawPrediction")
     assert(gbt.getProbabilityCol === "probability")
+    assert(gbt.getFeatureSubsetStrategy === "all")
     val df = trainData.toDF()
     val model = gbt.fit(df)
     model.transform(df)
@@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(model.getPredictionCol === "prediction")
     assert(model.getRawPredictionCol === "rawPrediction")
     assert(model.getProbabilityCol === "probability")
+    assert(model.getFeatureSubsetStrategy === "all")
     assert(model.hasParent)
 
     MLTestingUtils.checkCopyAndUids(gbt, model)
@@ -357,6 +359,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature subset  strategy
+  /////////////////////////////////////////////////////////////////////////////
+  test("Tests of feature subset strategy") {
+    val numClasses = 2
+    val gbt = new GBTClassifier()
+      .setSeed(123)
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setFeatureSubsetStrategy("all")
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+
+    // GBT with different featureSubsetStrategy
+    val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+    val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+    val mostIF = importanceFeatures.argmax
+    assert(mostImportantFeature !== mostIF)
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 2da25f7..ecbb571 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -166,6 +166,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature subset strategy
+  /////////////////////////////////////////////////////////////////////////////
+  test("Tests of feature subset strategy") {
+    val numClasses = 2
+    val gbt = new GBTRegressor()
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setSeed(123)
+      .setFeatureSubsetStrategy("all")
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+
+    // GBT with different featureSubsetStrategy
+    val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+    val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+    val mostIF = importanceFeatures.argmax
+    assert(mostImportantFeature !== mostIF)
+  }
+
+
+
+  /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index 4109a29..366d5ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
       val boostingStrategy =
         new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
       val (validateTrees, validateTreeWeights) = GradientBoostedTrees
-        .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+        .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
       val numTrees = validateTrees.length
       assert(numTrees !== numIterations)
 
       // Test that it performs better on the validation dataset.
-      val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+      val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
       val (errorWithoutValidation, errorWithValidation) = {
         if (algo == Classification) {
           val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org