You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/25 21:27:24 UTC

[1/2] spark git commit: [SPARK-6113] [ML] Tree ensembles for Pipelines API

Repository: spark
Updated Branches:
  refs/heads/master a61d65fc8 -> a7160c4e3


http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
new file mode 100644
index 0000000..2171ef3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Estimator that trains a [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]]
+ * (an ensemble of decision trees) for regression.
+ * Both continuous and categorical features are supported.
+ */
+@AlphaComponent
+final class RandomForestRegressor
+  extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
+  with RandomForestParams with TreeRegressorParams {
+
+  // Every setter below only delegates to the inherited implementation. They are re-declared
+  // here for Java API compatibility (Java callers chain these setters builder-style).
+
+  // Forest-specific parameters (RandomForestParams):
+
+  override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+
+  override def setFeatureSubsetStrategy(value: String): this.type =
+    super.setFeatureSubsetStrategy(value)
+
+  // Ensemble-level parameters (TreeEnsembleParams):
+
+  override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+  override def setSeed(value: Long): this.type = super.setSeed(value)
+
+  // Tree-growth parameters and impurity:
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+  override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+  /** Trains a forest through the old RDD-based API and wraps the result as a new-API model. */
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): RandomForestRegressionModel = {
+    // Categorical feature info is read from the features column's schema metadata.
+    val featuresField = dataset.schema(paramMap(featuresCol))
+    val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(featuresField)
+    val trainingData: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    // Regression: numClasses is 0 and the impurity comes from getOldImpurity.
+    val oldStrategy = super.getOldStrategy(
+      categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+    // The old API is given an Int seed, so the Long seed param is truncated with toInt.
+    val oldSeed: Int = getSeed.toInt
+    val oldForest = OldRandomForest.trainRegressor(
+      trainingData, oldStrategy, getNumTrees, getFeatureSubsetStrategy, oldSeed)
+    RandomForestRegressionModel.fromOld(oldForest, this, paramMap, categoricalFeatures)
+  }
+}
+
+object RandomForestRegressor {
+
+  /** Impurity measures that setImpurity accepts: variance */
+  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+
+  /** Strategies that setFeatureSubsetStrategy accepts: auto, all, onethird, sqrt, log2 */
+  final val supportedFeatureSubsetStrategies: Array[String] =
+    RandomForestParams.supportedFeatureSubsetStrategies
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Regression model produced by [[RandomForestRegressor]]: a
+ * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] whose prediction is the mean
+ * of its trees' predictions. It supports both continuous and categorical features.
+ * @param _trees  Decision trees in the ensemble (must be non-empty).
+ */
+@AlphaComponent
+final class RandomForestRegressionModel private[ml] (
+    override val parent: RandomForestRegressor,
+    override val fittingParamMap: ParamMap,
+    private val _trees: Array[DecisionTreeRegressionModel])
+  extends PredictionModel[Vector, RandomForestRegressionModel]
+  with TreeEnsembleModel with Serializable {
+
+  require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+
+  // The cast does not copy: callers receive the model's own backing array.
+  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+  // Every tree currently has weight 1.0; weights based on tree performance may be added later.
+  private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+
+  override def treeWeights: Array[Double] = _treeWeights
+
+  override protected def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model.  SPARK-7127
+    // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
+    // Unweighted mean of the per-tree predictions (all weights are 1.0 for now), accumulated
+    // left to right starting from 0.0.
+    val total = _trees.foldLeft(0.0) { (acc, tree) => acc + tree.rootNode.predict(features) }
+    total / numTrees
+  }
+
+  override protected def copy(): RandomForestRegressionModel = {
+    val duplicate = new RandomForestRegressionModel(parent, fittingParamMap, _trees)
+    Params.inheritValues(this.extractParamMap(), this, duplicate)
+    duplicate
+  }
+
+  override def toString: String = s"RandomForestRegressionModel with $numTrees trees"
+
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldRandomForestModel =
+    new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
+}
+
+private[ml] object RandomForestRegressionModel {
+
+  /**
+   * (private[ml]) Convert a model from the old API.
+   * The old model must have algo == Regression; otherwise `require` throws
+   * IllegalArgumentException.
+   */
+  def fromOld(
+      oldModel: OldRandomForestModel,
+      parent: RandomForestRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+    require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
+      s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
+    // Each tree gets a null parent and fittingParamMap: there is no good value to give them.
+    val convertedTrees: Array[DecisionTreeRegressionModel] = oldModel.trees.map(
+      DecisionTreeRegressionModel.fromOld(_, null, null, categoricalFeatures))
+    new RandomForestRegressionModel(parent, fittingParamMap, convertedTrees)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index d6e2203..d2dec0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -28,9 +28,9 @@ import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformation
 sealed abstract class Node extends Serializable {
 
   // TODO: Add aggregate stats (once available).  This will happen after we move the DecisionTree
-  //       code into the new API and deprecate the old API.
+  //       code into the new API and deprecate the old API.  SPARK-3727
 
-  /** Prediction this node makes (or would make, if it is an internal node) */
+  /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */
   def prediction: Double
 
   /** Impurity measure at this node (for training data) */
@@ -194,7 +194,7 @@ private object InternalNode {
           s"$featureStr > ${contSplit.threshold}"
         }
       case catSplit: CategoricalSplit =>
-        val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+        val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}")
         if (left) {
           s"$featureStr in $categoriesStr"
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 708c769..90f1d05 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -44,7 +44,7 @@ private[tree] object Split {
     oldSplit.featureType match {
       case OldFeatureType.Categorical =>
         new CategoricalSplit(featureIndex = oldSplit.feature,
-          leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+          _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
       case OldFeatureType.Continuous =>
         new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
     }
@@ -54,30 +54,30 @@ private[tree] object Split {
 /**
  * Split which tests a categorical feature.
  * @param featureIndex  Index of the feature to test
- * @param leftCategories  If the feature value is in this set of categories, then the split goes
- *                        left. Otherwise, it goes right.
+ * @param _leftCategories  If the feature value is in this set of categories, then the split goes
+ *                         left. Otherwise, it goes right.
  * @param numCategories  Number of categories for this feature.
  */
 final class CategoricalSplit private[ml] (
     override val featureIndex: Int,
-    leftCategories: Array[Double],
+    _leftCategories: Array[Double],
     private val numCategories: Int)
   extends Split {
 
-  require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
-    s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+  require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+    s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}")
 
   /**
    * If true, then "categories" is the set of categories for splitting to the left, and vice versa.
    */
-  private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+  private val isLeft: Boolean = _leftCategories.length <= numCategories / 2
 
   /** Set of categories determining the splitting rule, along with [[isLeft]]. */
   private val categories: Set[Double] = {
     if (isLeft) {
-      leftCategories.toSet
+      _leftCategories.toSet
     } else {
-      setComplement(leftCategories.toSet)
+      setComplement(_leftCategories.toSet)
     }
   }
 
@@ -107,13 +107,13 @@ final class CategoricalSplit private[ml] (
   }
 
   /** Get sorted categories which split to the left */
-  def getLeftCategories: Array[Double] = {
+  def leftCategories: Array[Double] = {
     val cats = if (isLeft) categories else setComplement(categories)
     cats.toArray.sorted
   }
 
   /** Get sorted categories which split to the right */
-  def getRightCategories: Array[Double] = {
+  def rightCategories: Array[Double] = {
     val cats = if (isLeft) setComplement(categories) else categories
     cats.toArray.sorted
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 8e3bc38..1929f9d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,18 +17,13 @@
 
 package org.apache.spark.ml.tree
 
-import org.apache.spark.annotation.AlphaComponent
-
 
 /**
- * :: AlphaComponent ::
- *
  * Abstraction for Decision Tree models.
  *
- * TODO: Add support for predicting probabilities and raw predictions
+ * TODO: Add support for predicting probabilities and raw predictions  SPARK-3727
  */
-@AlphaComponent
-trait DecisionTreeModel {
+private[ml] trait DecisionTreeModel {
 
   /** Root of the decision tree */
   def rootNode: Node
@@ -58,3 +53,40 @@ trait DecisionTreeModel {
     header + rootNode.subtreeToString(2)
   }
 }
+
+/**
+ * Abstraction for models which are ensembles of decision trees
+ *
+ * TODO: Add support for predicting probabilities and raw predictions  SPARK-3727
+ */
+private[ml] trait TreeEnsembleModel {
+
+  // Note: `trees` is declared as Array[DecisionTreeModel] even though subclasses of
+  //       TreeEnsembleModel store arrays of more specific DecisionTreeModel subclasses.
+
+  /** Trees in this ensemble. Warning: These have null parent Estimators. */
+  def trees: Array[DecisionTreeModel]
+
+  /** Weights for each tree, zippable with [[trees]] */
+  def treeWeights: Array[Double]
+
+  /** Summary of the model */
+  override def toString: String = {
+    // Implementing classes should generally override this method to be more descriptive.
+    s"TreeEnsembleModel with $numTrees trees"
+  }
+
+  /** Full description of model: the summary line, then each tree with its index and weight. */
+  def toDebugString: String = {
+    val header = toString + "\n"
+    // mkString joins in a single pass instead of repeated pairwise string concatenation.
+    header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) =>
+      s"  Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4)
+    }.mkString
+  }
+
+  /**
+   * Number of trees in ensemble.
+   * Declared as a def rather than a strict val so it is not evaluated during trait
+   * initialization, before an implementing class has initialized the fields backing `trees`.
+   */
+  def numTrees: Int = trees.length
+
+  /** Total number of nodes, summed over all trees in the ensemble. */
+  lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 43b8787..60f25e5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.File;
 import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
@@ -32,7 +31,6 @@ import org.apache.spark.ml.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
 
 
 public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +55,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
     double B = -1.5;
 
     JavaRDD<LabeledPoint> data = sc.parallelize(
-        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
     DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
 
@@ -71,8 +69,8 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
       .setCacheNodeIds(false)
       .setCheckpointInterval(10)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
-      dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+    for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+      dt.setImpurity(impurity);
     }
     DecisionTreeClassificationModel model = dt.fit(dataFrame);
 
@@ -82,7 +80,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
     model.toDebugString();
 
     /*
-    // TODO: Add test once save/load are implemented.
+    // TODO: Add test once save/load are implemented.  SPARK-6725
     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
     String path = tempDir.toURI().toString();
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
new file mode 100644
index 0000000..3c69467
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTClassifierSuite implements Serializable {
+
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  /**
+   * Java API smoke test for GBTClassifier: setters chain, fit() runs, and the model's
+   * accessors are callable from Java.
+   */
+  @Test
+  public void runGBT() {
+    int nPoints = 20;
+    double A = 2.0;
+    double B = -1.5;
+
+    JavaRDD<LabeledPoint> data = sc.parallelize(
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+    Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+    // This tests setters. Training with various options is tested in Scala.
+    GBTClassifier gbt = new GBTClassifier()
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setMinInstancesPerNode(5)
+      .setMinInfoGain(0.0)
+      .setMaxMemoryInMB(256)
+      .setCacheNodeIds(false)
+      .setCheckpointInterval(10)
+      .setSubsamplingRate(1.0)
+      .setSeed(1234)
+      .setMaxIter(3)
+      .setStepSize(0.1)
+      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+    // Each iteration overwrites the previous loss type; this only checks every value is accepted.
+    for (String lossType: GBTClassifier.supportedLossTypes()) {
+      gbt.setLossType(lossType);
+    }
+    GBTClassificationModel model = gbt.fit(dataFrame);
+
+    model.transform(dataFrame);
+    model.totalNumNodes();
+    model.toDebugString();
+    model.trees();
+    model.treeWeights();
+
+    /*
+    // TODO: Add test once save/load are implemented.  SPARK-6725
+    File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+    String path = tempDir.toURI().toString();
+    try {
+      model.save(sc.sc(), path);
+      GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path);
+      TreeTests.checkEqual(model, sameModel);
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+    */
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
new file mode 100644
index 0000000..32d0b38
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestClassifierSuite implements Serializable {
+
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  /**
+   * Java API smoke test for RandomForestClassifier: setters chain, fit() runs, and the model's
+   * accessors are callable from Java.
+   */
+  @Test
+  public void runRF() {
+    int nPoints = 20;
+    double A = 2.0;
+    double B = -1.5;
+
+    JavaRDD<LabeledPoint> data = sc.parallelize(
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+    Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+    // This tests setters. Training with various options is tested in Scala.
+    RandomForestClassifier rf = new RandomForestClassifier()
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setMinInstancesPerNode(5)
+      .setMinInfoGain(0.0)
+      .setMaxMemoryInMB(256)
+      .setCacheNodeIds(false)
+      .setCheckpointInterval(10)
+      .setSubsamplingRate(1.0)
+      .setSeed(1234)
+      .setNumTrees(3)
+      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+    // Each loop iteration overwrites the previous value; these only check every value is accepted.
+    for (String impurity: RandomForestClassifier.supportedImpurities()) {
+      rf.setImpurity(impurity);
+    }
+    for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+      rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+    }
+    RandomForestClassificationModel model = rf.fit(dataFrame);
+
+    model.transform(dataFrame);
+    model.totalNumNodes();
+    model.toDebugString();
+    model.trees();
+    model.treeWeights();
+
+    /*
+    // TODO: Add test once save/load are implemented.  SPARK-6725
+    File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+    String path = tempDir.toURI().toString();
+    try {
+      model.save(sc.sc(), path);
+      RandomForestClassificationModel sameModel =
+          RandomForestClassificationModel.load(sc.sc(), path);
+      TreeTests.checkEqual(model, sameModel);
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+    */
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index a3a3390..71b0418 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.regression;
 
-import java.io.File;
 import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
@@ -32,7 +31,6 @@ import org.apache.spark.ml.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
 
 
 public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,22 +55,22 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
     double B = -1.5;
 
     JavaRDD<LabeledPoint> data = sc.parallelize(
-        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
     DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
 
     // This tests setters. Training with various options is tested in Scala.
     DecisionTreeRegressor dt = new DecisionTreeRegressor()
-        .setMaxDepth(2)
-        .setMaxBins(10)
-        .setMinInstancesPerNode(5)
-        .setMinInfoGain(0.0)
-        .setMaxMemoryInMB(256)
-        .setCacheNodeIds(false)
-        .setCheckpointInterval(10)
-        .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
-      dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setMinInstancesPerNode(5)
+      .setMinInfoGain(0.0)
+      .setMaxMemoryInMB(256)
+      .setCacheNodeIds(false)
+      .setCheckpointInterval(10)
+      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+    for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+      dt.setImpurity(impurity);
     }
     DecisionTreeRegressionModel model = dt.fit(dataFrame);
 
@@ -82,7 +80,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
     model.toDebugString();
 
     /*
-    // TODO: Add test once save/load are implemented.
+    // TODO: Add test once save/load are implemented.   SPARK-6725
     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
     String path = tempDir.toURI().toString();
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
new file mode 100644
index 0000000..fc8c13d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTRegressorSuite implements Serializable {
+
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  /**
+   * Java API smoke test for GBTRegressor: setters chain, fit() runs, and the model's
+   * accessors are callable from Java.
+   */
+  @Test
+  public void runGBT() {
+    int nPoints = 20;
+    double A = 2.0;
+    double B = -1.5;
+
+    JavaRDD<LabeledPoint> data = sc.parallelize(
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+    Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+    // This tests setters. Training with various options is tested in Scala.
+    GBTRegressor gbt = new GBTRegressor()
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setMinInstancesPerNode(5)
+      .setMinInfoGain(0.0)
+      .setMaxMemoryInMB(256)
+      .setCacheNodeIds(false)
+      .setCheckpointInterval(10)
+      .setSubsamplingRate(1.0)
+      .setSeed(1234)
+      .setMaxIter(3)
+      .setStepSize(0.1)
+      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+    // Each iteration overwrites the previous loss type; this only checks every value is accepted.
+    for (String lossType: GBTRegressor.supportedLossTypes()) {
+      gbt.setLossType(lossType);
+    }
+    GBTRegressionModel model = gbt.fit(dataFrame);
+
+    model.transform(dataFrame);
+    model.totalNumNodes();
+    model.toDebugString();
+    model.trees();
+    model.treeWeights();
+
+    /*
+    // TODO: Add test once save/load are implemented.  SPARK-6725
+    File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+    String path = tempDir.toURI().toString();
+    try {
+      model.save(sc.sc(), path);
+      GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path);
+      TreeTests.checkEqual(model, sameModel);
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+    */
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
new file mode 100644
index 0000000..e306eba
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+/** Java API test for {@link RandomForestRegressor} setters, fit and model accessors. */
+public class JavaRandomForestRegressorSuite implements Serializable {
+
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runRF() {
+    int nPoints = 20;
+    double A = 2.0;
+    double B = -1.5;
+
+    JavaRDD<LabeledPoint> data = sc.parallelize(
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+    Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+    // This tests setters. Training with various options is tested in Scala.
+    RandomForestRegressor rf = new RandomForestRegressor()
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setMinInstancesPerNode(5)
+      .setMinInfoGain(0.0)
+      .setMaxMemoryInMB(256)
+      .setCacheNodeIds(false)
+      .setCheckpointInterval(10)
+      .setSubsamplingRate(1.0)
+      .setSeed(1234)
+      .setNumTrees(3)
+      .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+    for (String impurity: RandomForestRegressor.supportedImpurities()) {
+      rf.setImpurity(impurity);
+    }
+    for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+      rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+    }
+    RandomForestRegressionModel model = rf.fit(dataFrame);
+    // Call the model's public API from Java; the return values are not checked here.
+    model.transform(dataFrame);
+    model.totalNumNodes();
+    model.toDebugString();
+    model.trees();
+    model.treeWeights();
+
+    /*
+    // TODO: Add test once save/load are implemented.   SPARK-6725
+    File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+    String path = tempDir.toURI().toString();
+    try {
+      model.save(sc.sc(), path);
+      RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path);
+      TreeTests.checkEqual(model, sameModel);
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+    */
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index af88595..9b31ade 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -230,7 +230,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: Reinstate test once save/load are implemented
+  // TODO: Reinstate test once save/load are implemented   SPARK-6725
   /*
   test("model save/load") {
     val tempDir = Utils.createTempDir()

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
new file mode 100644
index 0000000..e6ccc2c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Test suite for [[GBTClassifier]], comparing the new API against the old mllib API.
+ */
+class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+  import GBTClassifierSuite.compareAPIs
+
+  // (maxIter, learning rate a.k.a. stepSize, subsamplingRate) combinations to test.
+  private val testCombinations =
+    Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+  private var data: RDD[LabeledPoint] = _
+  // trainData and validationData are only used by the disabled runWithValidation test below.
+  private var trainData: RDD[LabeledPoint] = _
+  private var validationData: RDD[LabeledPoint] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
+    trainData =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2)
+    validationData =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
+  }
+
+  test("Binary classification with continuous features: Log Loss") {
+    val categoricalFeatures = Map.empty[Int, Int]
+    testCombinations.foreach {
+      case (maxIter, learningRate, subsamplingRate) =>
+        val gbt = new GBTClassifier()
+          .setMaxDepth(2)
+          .setSubsamplingRate(subsamplingRate)
+          .setLossType("logistic")
+          .setMaxIter(maxIter)
+          .setStepSize(learningRate)
+        compareAPIs(data, None, gbt, categoricalFeatures)
+    }
+  }
+
+  // TODO: Reinstate test once runWithValidation is implemented   SPARK-7132
+  /*
+  test("runWithValidation stops early and performs better on a validation dataset") {
+    val categoricalFeatures = Map.empty[Int, Int]
+    // Set maxIter large enough so that it stops early.
+    val maxIter = 20
+    GBTClassifier.supportedLossTypes.foreach { loss =>
+      val gbt = new GBTClassifier()
+        .setMaxIter(maxIter)
+        .setMaxDepth(2)
+        .setLossType(loss)
+        .setValidationTol(0.0)
+      compareAPIs(trainData, None, gbt, categoricalFeatures)
+      compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures)
+    }
+  }
+  */
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of model save/load
+  /////////////////////////////////////////////////////////////////////////////
+
+  // TODO: Reinstate test once save/load are implemented  SPARK-6725
+  /*
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+    val treeWeights = Array(0.1, 0.3, 1.1)
+    val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
+    val newModel = GBTClassificationModel.fromOld(oldModel)
+
+    // Save model, load it back, and compare.
+    try {
+      newModel.save(sc, path)
+      val sameNewModel = GBTClassificationModel.load(sc, path)
+      TreeTests.checkEqual(newModel, sameNewModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+  */
+}
+
+private object GBTClassifierSuite {
+
+  /**
+   * Train models on `data` with the old and new APIs, convert the old one to the new format, and
+   * fail unless they are exactly equal. `validationData` must be None for now (see SPARK-7132).
+   */
+  def compareAPIs(
+      data: RDD[LabeledPoint],
+      validationData: Option[RDD[LabeledPoint]],
+      gbt: GBTClassifier,
+      categoricalFeatures: Map[Int, Int]): Unit = {
+    require(validationData.isEmpty, "validationData is not supported yet: SPARK-7132")
+    val oldBoostingStrategy =
+      gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+    val oldModel = new OldGBT(oldBoostingStrategy).run(data)
+    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+    val newModel = gbt.fit(newData)
+    // Borrow parent and fittingParamMap from newModel; TreeTests.checkEqual ignores them.
+    val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent,
+      newModel.fittingParamMap, categoricalFeatures)
+    TreeTests.checkEqual(oldModelAsNew, newModel)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
new file mode 100644
index 0000000..ed41a96
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Test suite for [[RandomForestClassifier]], comparing the new API against the old mllib API.
+ */
+class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+  import RandomForestClassifierSuite.compareAPIs
+
+  private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+  private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    orderedLabeledPoints50_1000 =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
+    orderedLabeledPoints5_20 =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20))
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests calling train()
+  /////////////////////////////////////////////////////////////////////////////
+
+  /** Trains `rf` as one depth-2 Gini tree via both APIs and checks the models are equal. */
+  def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier): Unit = {
+    val categoricalFeatures = Map.empty[Int, Int]
+    val numClasses = 2
+    val newRF = rf
+      .setImpurity("Gini")
+      .setMaxDepth(2)
+      .setNumTrees(1)
+      .setFeatureSubsetStrategy("auto")
+      .setSeed(123)
+    compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
+  }
+
+  test("Binary classification with continuous features:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val rf = new RandomForestClassifier()
+    binaryClassificationTestWithContinuousFeatures(rf)
+  }
+
+  test("Binary classification with continuous features and node Id cache:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val rf = new RandomForestClassifier()
+      .setCacheNodeIds(true)
+    binaryClassificationTestWithContinuousFeatures(rf)
+  }
+
+  test("alternating categorical and continuous features with multiclass labels to test indexing") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)),
+      LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+    )
+    val rdd = sc.parallelize(arr)
+    val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4)
+    val numClasses = 3
+
+    val rf = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(5)
+      .setNumTrees(2)
+      .setFeatureSubsetStrategy("sqrt")
+      .setSeed(12345)
+    compareAPIs(rdd, rf, categoricalFeatures, numClasses)
+  }
+
+  test("subsampling rate in RandomForest") {
+    val rdd = orderedLabeledPoints5_20
+    val categoricalFeatures = Map.empty[Int, Int]
+    val numClasses = 2
+
+    val rf1 = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(2)
+      .setCacheNodeIds(true)
+      .setNumTrees(3)
+      .setFeatureSubsetStrategy("auto")
+      .setSeed(123)
+    compareAPIs(rdd, rf1, categoricalFeatures, numClasses)
+
+    val rf2 = rf1.setSubsamplingRate(0.5)
+    compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of model save/load
+  /////////////////////////////////////////////////////////////////////////////
+
+  // TODO: Reinstate test once save/load are implemented  SPARK-6725
+  /*
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val trees =
+      Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
+    val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
+    val newModel = RandomForestClassificationModel.fromOld(oldModel)
+
+    // Save model, load it back, and compare.
+    try {
+      newModel.save(sc, path)
+      val sameNewModel = RandomForestClassificationModel.load(sc, path)
+      TreeTests.checkEqual(newModel, sameNewModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+  */
+}
+
+private object RandomForestClassifierSuite {
+
+  /**
+   * Train 2 models on the given dataset, one using the old API and one using the new API.
+   * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+   */
+  def compareAPIs(
+      data: RDD[LabeledPoint],
+      rf: RandomForestClassifier,
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int): Unit = {
+    val oldStrategy =
+      rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
+    val oldModel = OldRandomForest.trainClassifier(
+      data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
+    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+    val newModel = rf.fit(newData)
+    // Borrow parent and fittingParamMap from newModel; TreeTests.checkEqual ignores them.
+    val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent,
+      newModel.fittingParamMap, categoricalFeatures)
+    TreeTests.checkEqual(oldModelAsNew, newModel)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 2e57d4c..1505ad8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -23,8 +23,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
-import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.ml.tree._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{SQLContext, DataFrame}
@@ -111,22 +110,19 @@ private[ml] object TreeTests extends FunSuite {
     }
   }
 
-  // TODO: Reinstate after adding ensembles
   /**
    * Check if the two models are exactly the same.
    * If the models are not equal, this throws an exception.
    */
-  /*
   def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
     try {
-      a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+      a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
         TreeTests.checkEqual(treeA, treeB)
       }
-      assert(a.getTreeWeights === b.getTreeWeights)
+      assert(a.treeWeights === b.treeWeights)
     } catch {
       case ex: Exception => throw new AssertionError(
         "checkEqual failed since the two tree ensembles were not identical")
     }
   }
-  */
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 0b40fe3..c87a171 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -66,7 +66,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: test("model save/load")
+  // TODO: test("model save/load")   SPARK-6725
 }
 
 private[ml] object DecisionTreeRegressorSuite extends FunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
new file mode 100644
index 0000000..4aec369
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Test suite for [[GBTRegressor]], comparing the new API against the old mllib API.
+ */
+class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+  import GBTRegressorSuite.compareAPIs
+
+  // (maxIter, learning rate a.k.a. stepSize, subsamplingRate) combinations to test.
+  private val testCombinations =
+    Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+  private var data: RDD[LabeledPoint] = _
+  // trainData and validationData are only used by the disabled runWithValidation test below.
+  private var trainData: RDD[LabeledPoint] = _
+  private var validationData: RDD[LabeledPoint] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
+    trainData =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2)
+    validationData =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
+  }
+
+  test("Regression with continuous features: all supported loss types") {
+    val categoricalFeatures = Map.empty[Int, Int]
+    GBTRegressor.supportedLossTypes.foreach { loss =>
+      testCombinations.foreach {
+        case (maxIter, learningRate, subsamplingRate) =>
+          val gbt = new GBTRegressor()
+            .setMaxDepth(2)
+            .setSubsamplingRate(subsamplingRate)
+            .setLossType(loss)
+            .setMaxIter(maxIter)
+            .setStepSize(learningRate)
+          compareAPIs(data, None, gbt, categoricalFeatures)
+      }
+    }
+  }
+
+  // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
+  /*
+  test("runWithValidation stops early and performs better on a validation dataset") {
+    val categoricalFeatures = Map.empty[Int, Int]
+    // Set maxIter large enough so that it stops early.
+    val maxIter = 20
+    GBTRegressor.supportedLossTypes.foreach { loss =>
+      val gbt = new GBTRegressor()
+        .setMaxIter(maxIter)
+        .setMaxDepth(2)
+        .setLossType(loss)
+        .setValidationTol(0.0)
+      compareAPIs(trainData, None, gbt, categoricalFeatures)
+      compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures)
+    }
+  }
+  */
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of model save/load
+  /////////////////////////////////////////////////////////////////////////////
+
+  // TODO: Reinstate test once save/load are implemented  SPARK-6725
+  /*
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+    val treeWeights = Array(0.1, 0.3, 1.1)
+    val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
+    val newModel = GBTRegressionModel.fromOld(oldModel)
+
+    // Save model, load it back, and compare.
+    try {
+      newModel.save(sc, path)
+      val sameNewModel = GBTRegressionModel.load(sc, path)
+      TreeTests.checkEqual(newModel, sameNewModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+  */
+}
+
+private object GBTRegressorSuite {
+
+  /**
+   * Train models on `data` with the old and new APIs, convert the old one to the new format, and
+   * fail unless they are exactly equal. `validationData` must be None for now (see SPARK-7132).
+   */
+  def compareAPIs(
+      data: RDD[LabeledPoint],
+      validationData: Option[RDD[LabeledPoint]],
+      gbt: GBTRegressor,
+      categoricalFeatures: Map[Int, Int]): Unit = {
+    require(validationData.isEmpty, "validationData is not supported yet: SPARK-7132")
+    val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+    val oldModel = new OldGBT(oldBoostingStrategy).run(data)
+    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+    val newModel = gbt.fit(newData)
+    // Borrow parent and fittingParamMap from newModel; TreeTests.checkEqual ignores them.
+    val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent,
+      newModel.fittingParamMap, categoricalFeatures)
+    TreeTests.checkEqual(oldModelAsNew, newModel)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
new file mode 100644
index 0000000..c6dc1cc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Test suite for [[RandomForestRegressor]], comparing the new API against the old mllib API.
+ */
+class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+  import RandomForestRegressorSuite.compareAPIs
+
+  private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    orderedLabeledPoints50_1000 =
+      sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests calling train()
+  /////////////////////////////////////////////////////////////////////////////
+
+  /** Trains `rf` as one depth-2 variance tree via both APIs and checks the models are equal. */
+  def regressionTestWithContinuousFeatures(rf: RandomForestRegressor): Unit = {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val newRF = rf
+      .setImpurity("variance")
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setNumTrees(1)
+      .setFeatureSubsetStrategy("auto")
+      .setSeed(123)
+    compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo)
+  }
+
+  test("Regression with continuous features:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val rf = new RandomForestRegressor()
+    regressionTestWithContinuousFeatures(rf)
+  }
+
+  test("Regression with continuous features and node Id cache:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val rf = new RandomForestRegressor()
+      .setCacheNodeIds(true)
+    regressionTestWithContinuousFeatures(rf)
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of model save/load
+  /////////////////////////////////////////////////////////////////////////////
+
+  // TODO: Reinstate test once save/load are implemented  SPARK-6725
+  /*
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+    val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
+    val newModel = RandomForestRegressionModel.fromOld(oldModel)
+
+    // Save model, load it back, and compare.
+    try {
+      newModel.save(sc, path)
+      val sameNewModel = RandomForestRegressionModel.load(sc, path)
+      TreeTests.checkEqual(newModel, sameNewModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+  */
+}
+
+private object RandomForestRegressorSuite {
+
+  /**
+   * Train 2 models on the given dataset, one using the old API and one using the new API.
+   * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+   */
+  def compareAPIs(
+      data: RDD[LabeledPoint],
+      rf: RandomForestRegressor,
+      categoricalFeatures: Map[Int, Int]): Unit = {
+    val oldStrategy =
+      rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
+    val oldModel = OldRandomForest.trainRegressor(
+      data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
+    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+    val newModel = rf.fit(newData)
+    // Borrow parent and fittingParamMap from newModel; TreeTests.checkEqual ignores them.
+    val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent,
+      newModel.fittingParamMap, categoricalFeatures)
+    TreeTests.checkEqual(oldModelAsNew, newModel)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 249b8ea..ce983eb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -998,7 +998,7 @@ object DecisionTreeSuite extends FunSuite {
         node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
           categories = List(0.0, 1.0)))
     }
-    // TODO: The information gain stats should be consistent with the same info stored in children.
+    // TODO: The information gain stats should be consistent with info in children: SPARK-7131
     node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
       leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
     node
@@ -1006,9 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
 
   /**
    * Create a tree model.  This is deterministic and contains a variety of node and feature types.
-   * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
+   * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131
    */
-  private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
+  private[spark] def createModel(algo: Algo): DecisionTreeModel = {
     val topNode = createInternalNode(id = 1, Continuous)
     val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
     val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-6113] [ML] Tree ensembles for Pipelines API

Posted by me...@apache.org.
[SPARK-6113] [ML] Tree ensembles for Pipelines API

This is a continuation of [https://github.com/apache/spark/pull/5530] (which was for Decision Trees), but for ensembles: Random Forests and Gradient-Boosted Trees.  Please refer to the JIRA [https://issues.apache.org/jira/browse/SPARK-6113], the design doc linked from the JIRA, and the previous PR linked above for design discussions.

This PR follows the example set by the previous PR for Decision Trees.  It includes a few cleanups to Decision Trees.

Note: There is one issue which will be addressed in a separate PR: Ensembles' component Models have no parent or fittingParamMap.  I plan to submit a separate PR which makes those values in Model be Options.  It does not matter much which PR gets merged first.

CC: mengxr manishamde codedeft chouqin

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5626 from jkbradley/dt-api-ensembles and squashes the following commits:

729167a [Joseph K. Bradley] small cleanups based on code review
bbae2a2 [Joseph K. Bradley] Updated per all comments in code review
855aa9a [Joseph K. Bradley] scala style fix
ea3d901 [Joseph K. Bradley] Added GBT to spark.ml, with tests and examples
c0f30c1 [Joseph K. Bradley] Added random forests and test suites to spark.ml.  Not tested yet.  Need to add example as well
d045ebd [Joseph K. Bradley] some more updates, but far from done
ee1a10b [Joseph K. Bradley] Added files from old PR and did some initial updates.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7160c4e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7160c4e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7160c4e

Branch: refs/heads/master
Commit: a7160c4e3aae22600d05e257d0b4d2428754b8ea
Parents: a61d65f
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Sat Apr 25 12:27:19 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sat Apr 25 12:27:19 2015 -0700

----------------------------------------------------------------------
 .../spark/examples/ml/DecisionTreeExample.scala | 139 ++++++-----
 .../apache/spark/examples/ml/GBTExample.scala   | 238 ++++++++++++++++++
 .../spark/examples/ml/RandomForestExample.scala | 248 ++++++++++++++++++
 .../mllib/GradientBoostedTreesRunner.scala      |   1 +
 .../main/scala/org/apache/spark/ml/Model.scala  |   2 +
 .../classification/DecisionTreeClassifier.scala |  24 +-
 .../spark/ml/classification/GBTClassifier.scala | 228 +++++++++++++++++
 .../classification/RandomForestClassifier.scala | 185 ++++++++++++++
 .../apache/spark/ml/impl/tree/treeParams.scala  | 249 ++++++++++++++++---
 .../ml/param/shared/SharedParamsCodeGen.scala   |   4 +-
 .../spark/ml/param/shared/sharedParams.scala    |  20 ++
 .../ml/regression/DecisionTreeRegressor.scala   |  14 +-
 .../spark/ml/regression/GBTRegressor.scala      | 218 ++++++++++++++++
 .../ml/regression/RandomForestRegressor.scala   | 167 +++++++++++++
 .../scala/org/apache/spark/ml/tree/Node.scala   |   6 +-
 .../scala/org/apache/spark/ml/tree/Split.scala  |  22 +-
 .../org/apache/spark/ml/tree/treeModels.scala   |  46 +++-
 .../JavaDecisionTreeClassifierSuite.java        |  10 +-
 .../classification/JavaGBTClassifierSuite.java  | 100 ++++++++
 .../JavaRandomForestClassifierSuite.java        | 103 ++++++++
 .../JavaDecisionTreeRegressorSuite.java         |  26 +-
 .../ml/regression/JavaGBTRegressorSuite.java    |  99 ++++++++
 .../JavaRandomForestRegressorSuite.java         | 102 ++++++++
 .../DecisionTreeClassifierSuite.scala           |   2 +-
 .../ml/classification/GBTClassifierSuite.scala  | 136 ++++++++++
 .../RandomForestClassifierSuite.scala           | 166 +++++++++++++
 .../org/apache/spark/ml/impl/TreeTests.scala    |  10 +-
 .../regression/DecisionTreeRegressorSuite.scala |   2 +-
 .../spark/ml/regression/GBTRegressorSuite.scala | 137 ++++++++++
 .../regression/RandomForestRegressorSuite.scala | 122 +++++++++
 .../spark/mllib/tree/DecisionTreeSuite.scala    |   6 +-
 31 files changed, 2658 insertions(+), 174 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 2cd515c..9002e99 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -22,10 +22,9 @@ import scala.language.reflectiveCalls
 
 import scopt.OptionParser
 
-import org.apache.spark.ml.tree.DecisionTreeModel
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
 import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
 import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
@@ -64,8 +63,6 @@ object DecisionTreeExample {
       maxBins: Int = 32,
       minInstancesPerNode: Int = 1,
       minInfoGain: Double = 0.0,
-      numTrees: Int = 1,
-      featureSubsetStrategy: String = "auto",
       fracTest: Double = 0.2,
       cacheNodeIds: Boolean = false,
       checkpointDir: Option[String] = None,
@@ -123,8 +120,8 @@ object DecisionTreeExample {
         .required()
         .action((x, c) => c.copy(input = x))
       checkConfig { params =>
-        if (params.fracTest < 0 || params.fracTest > 1) {
-          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+        if (params.fracTest < 0 || params.fracTest >= 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
         } else {
           success
         }
@@ -200,9 +197,18 @@ object DecisionTreeExample {
           throw new IllegalArgumentException("Algo ${params.algo} not supported.")
       }
     }
-    val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+    val dataframes = splits.map(_.toDF()).map(labelsToStrings)
+    val training = dataframes(0).cache()
+    val test = dataframes(1).cache()
 
-    (dataframes(0), dataframes(1))
+    val numTraining = training.count()
+    val numTest = test.count()
+    val numFeatures = training.select("features").first().getAs[Vector](0).size
+    println("Loaded data:")
+    println(s"  numTraining = $numTraining, numTest = $numTest")
+    println(s"  numFeatures = $numFeatures")
+
+    (training, test)
   }
 
   def run(params: Params) {
@@ -217,13 +223,6 @@ object DecisionTreeExample {
     val (training: DataFrame, test: DataFrame) =
       loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
 
-    val numTraining = training.count()
-    val numTest = test.count()
-    val numFeatures = training.select("features").first().getAs[Vector](0).size
-    println("Loaded data:")
-    println(s"  numTraining = $numTraining, numTest = $numTest")
-    println(s"  numFeatures = $numFeatures")
-
     // Set up Pipeline
     val stages = new mutable.ArrayBuffer[PipelineStage]()
     // (1) For classification, re-index classes.
@@ -241,7 +240,7 @@ object DecisionTreeExample {
       .setOutputCol("indexedFeatures")
       .setMaxCategories(10)
     stages += featuresIndexer
-    // (3) Learn DecisionTree
+    // (3) Learn Decision Tree
     val dt = algo match {
       case "classification" =>
         new DecisionTreeClassifier()
@@ -275,62 +274,86 @@ object DecisionTreeExample {
     println(s"Training time: $elapsedTime seconds")
 
     // Get the trained Decision Tree from the fitted PipelineModel
-    val treeModel: DecisionTreeModel = algo match {
+    algo match {
       case "classification" =>
-        pipelineModel.getModel[DecisionTreeClassificationModel](
+        val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
           dt.asInstanceOf[DecisionTreeClassifier])
+        if (treeModel.numNodes < 20) {
+          println(treeModel.toDebugString) // Print full model.
+        } else {
+          println(treeModel) // Print model summary.
+        }
       case "regression" =>
-        pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
-      case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
-    }
-    if (treeModel.numNodes < 20) {
-      println(treeModel.toDebugString) // Print full model.
-    } else {
-      println(treeModel) // Print model summary.
-    }
-
-    // Predict on training
-    val trainingFullPredictions = pipelineModel.transform(training).cache()
-    val trainingPredictions = trainingFullPredictions.select("prediction")
-      .map(_.getDouble(0))
-    val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
-    // Predict on test data
-    val testFullPredictions = pipelineModel.transform(test).cache()
-    val testPredictions = testFullPredictions.select("prediction")
-      .map(_.getDouble(0))
-    val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
-
-    // For classification, print number of classes for reference.
-    if (algo == "classification") {
-      val numClasses =
-        MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
-          case Some(n) => n
-          case None => throw new RuntimeException(
-            "DecisionTreeExample had unknown failure when indexing labels for classification.")
+        val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
+          dt.asInstanceOf[DecisionTreeRegressor])
+        if (treeModel.numNodes < 20) {
+          println(treeModel.toDebugString) // Print full model.
+        } else {
+          println(treeModel) // Print model summary.
         }
-      println(s"numClasses = $numClasses.")
+      case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
     }
 
     // Evaluate model on training, test data
     algo match {
       case "classification" =>
-        val trainingAccuracy =
-          new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
-        println(s"Train accuracy = $trainingAccuracy")
-        val testAccuracy =
-          new MulticlassMetrics(testPredictions.zip(testLabels)).precision
-        println(s"Test accuracy = $testAccuracy")
+        println("Training data results:")
+        evaluateClassificationModel(pipelineModel, training, labelColName)
+        println("Test data results:")
+        evaluateClassificationModel(pipelineModel, test, labelColName)
       case "regression" =>
-        val trainingRMSE =
-          new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
-        println(s"Training root mean squared error (RMSE) = $trainingRMSE")
-        val testRMSE =
-          new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
-        println(s"Test root mean squared error (RMSE) = $testRMSE")
+        println("Training data results:")
+        evaluateRegressionModel(pipelineModel, training, labelColName)
+        println("Test data results:")
+        evaluateRegressionModel(pipelineModel, test, labelColName)
       case _ =>
         throw new IllegalArgumentException("Algo ${params.algo} not supported.")
     }
 
     sc.stop()
   }
+
+  /**
+   * Evaluate the given ClassificationModel on data.  Print the results.
+   * The transformed predictions are cached while metrics are computed and unpersisted afterwards,
+   * so repeated calls (e.g. on training and then test data) do not accumulate cached data.
+   * @param model  Must fit ClassificationModel abstraction
+   * @param data  Input DataFrame for `model`; the output of `model.transform(data)` must contain
+   *              "prediction" and labelColName columns
+   * @param labelColName  Name of the labelCol parameter for the model
+   * @throws RuntimeException if the label column carries no class-count metadata
+   *
+   * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995
+   */
+  private[ml] def evaluateClassificationModel(
+      model: Transformer,
+      data: DataFrame,
+      labelColName: String): Unit = {
+    val fullPredictions = model.transform(data).cache()
+    try {
+      val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+      val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+      // Print number of classes for reference
+      val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName))
+        .getOrElse(throw new RuntimeException(
+          "Unknown failure when indexing labels for classification."))
+      val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
+      println(s"  Accuracy ($numClasses classes): $accuracy")
+    } finally {
+      // Release the cache even if metric computation fails.
+      fullPredictions.unpersist()
+    }
+  }
+
+  /**
+   * Evaluate the given RegressionModel on data.  Print the results.
+   * The transformed predictions are cached while metrics are computed and unpersisted afterwards,
+   * so repeated calls (e.g. on training and then test data) do not accumulate cached data.
+   * @param model  Must fit RegressionModel abstraction
+   * @param data  Input DataFrame for `model`; the output of `model.transform(data)` must contain
+   *              "prediction" and labelColName columns
+   * @param labelColName  Name of the labelCol parameter for the model
+   *
+   * TODO: Change model type to RegressionModel once that API is public. SPARK-5995
+   */
+  private[ml] def evaluateRegressionModel(
+      model: Transformer,
+      data: DataFrame,
+      labelColName: String): Unit = {
+    val fullPredictions = model.transform(data).cache()
+    try {
+      val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+      val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+      val rmse = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
+      println(s"  Root mean squared error (RMSE): $rmse")
+    } finally {
+      // Release the cache even if metric computation fails.
+      fullPredictions.unpersist()
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
new file mode 100644
index 0000000..5fccb14
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for Gradient-Boosted Trees (GBTs). Run with
+ * {{{
+ * ./bin/run-example ml.GBTExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory.  If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g
+ *   [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object GBTExample {
+
+  /** Command-line configuration; the field defaults are the defaults reported by the parser. */
+  case class Params(
+      input: String = null,
+      testInput: String = "",
+      dataFormat: String = "libsvm",
+      algo: String = "classification",
+      maxDepth: Int = 5,
+      maxBins: Int = 32,
+      minInstancesPerNode: Int = 1,
+      minInfoGain: Double = 0.0,
+      maxIter: Int = 10,
+      fracTest: Double = 0.2,
+      cacheNodeIds: Boolean = false,
+      checkpointDir: Option[String] = None,
+      checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+  /** Parses the command line and runs the example; exits with status 1 if parsing fails. */
+  def main(args: Array[String]): Unit = {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("GBTExample") {
+      head("GBTExample: an example Gradient-Boosted Trees app.")
+      opt[String]("algo")
+        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = x))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("maxBins")
+        .text(s"max number of bins, default: ${defaultParams.maxBins}")
+        .action((x, c) => c.copy(maxBins = x))
+      opt[Int]("minInstancesPerNode")
+        .text(s"min number of instances required at child nodes to create the parent split," +
+        s" default: ${defaultParams.minInstancesPerNode}")
+        .action((x, c) => c.copy(minInstancesPerNode = x))
+      opt[Double]("minInfoGain")
+        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+        .action((x, c) => c.copy(minInfoGain = x))
+      opt[Int]("maxIter")
+        .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}")
+        .action((x, c) => c.copy(maxIter = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+        s"this option is ignored. default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      opt[Boolean]("cacheNodeIds")
+        .text(s"whether to use node Id cache during training, " +
+        s"default: ${defaultParams.cacheNodeIds}")
+        .action((x, c) => c.copy(cacheNodeIds = x))
+      opt[String]("checkpointDir")
+        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+        s"default: ${defaultParams.checkpointDir.getOrElse("None")}")
+        .action((x, c) => c.copy(checkpointDir = Some(x)))
+      opt[Int]("checkpointInterval")
+        .text(s"how often to checkpoint the node Id cache, " +
+        s"default: ${defaultParams.checkpointInterval}")
+        .action((x, c) => c.copy(checkpointInterval = x))
+      opt[String]("testInput")
+        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
+        s" default: ${defaultParams.testInput}")
+        .action((x, c) => c.copy(testInput = x))
+      // Long option name is "dataFormat" (i.e. --dataFormat); "<dataFormat>" was not usable.
+      opt[String]("dataFormat")
+        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+        .action((x, c) => c.copy(dataFormat = x))
+      arg[String]("<input>")
+        .text("input path to labeled examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.fracTest < 0 || params.fracTest >= 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams) match {
+      case Some(params) => run(params)
+      case None => sys.exit(1)
+    }
+  }
+
+  /**
+   * Loads the data, fits a Pipeline (label re-indexing for classification, feature indexing,
+   * then a GBT), prints the learned model, and reports metrics on the training and test sets.
+   * The SparkContext is stopped even if any of these steps throws.
+   * @throws IllegalArgumentException if `params.algo` is neither classification nor regression
+   */
+  def run(params: Params): Unit = {
+    val conf = new SparkConf().setAppName(s"GBTExample with $params")
+    val sc = new SparkContext(conf)
+    try {
+      params.checkpointDir.foreach(sc.setCheckpointDir)
+      val algo = params.algo.toLowerCase
+
+      println(s"GBTExample with parameters:\n$params")
+
+      // Load training and test data and cache it.
+      val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc,
+        params.input, params.dataFormat, params.testInput, algo, params.fracTest)
+
+      // Set up Pipeline
+      val stages = new mutable.ArrayBuffer[PipelineStage]()
+      // (1) For classification, re-index classes.
+      val labelColName = if (algo == "classification") "indexedLabel" else "label"
+      if (algo == "classification") {
+        val labelIndexer = new StringIndexer()
+          .setInputCol("labelString")
+          .setOutputCol(labelColName)
+        stages += labelIndexer
+      }
+      // (2) Identify categorical features using VectorIndexer.
+      //     Features with more than maxCategories values will be treated as continuous.
+      val featuresIndexer = new VectorIndexer()
+        .setInputCol("features")
+        .setOutputCol("indexedFeatures")
+        .setMaxCategories(10)
+      stages += featuresIndexer
+      // (3) Learn GBT
+      val gbt = algo match {
+        case "classification" =>
+          new GBTClassifier()
+            .setFeaturesCol("indexedFeatures")
+            .setLabelCol(labelColName)
+            .setMaxDepth(params.maxDepth)
+            .setMaxBins(params.maxBins)
+            .setMinInstancesPerNode(params.minInstancesPerNode)
+            .setMinInfoGain(params.minInfoGain)
+            .setCacheNodeIds(params.cacheNodeIds)
+            .setCheckpointInterval(params.checkpointInterval)
+            .setMaxIter(params.maxIter)
+        case "regression" =>
+          new GBTRegressor()
+            .setFeaturesCol("indexedFeatures")
+            .setLabelCol(labelColName)
+            .setMaxDepth(params.maxDepth)
+            .setMaxBins(params.maxBins)
+            .setMinInstancesPerNode(params.minInstancesPerNode)
+            .setMinInfoGain(params.minInfoGain)
+            .setCacheNodeIds(params.cacheNodeIds)
+            .setCheckpointInterval(params.checkpointInterval)
+            .setMaxIter(params.maxIter)
+        case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+      }
+      stages += gbt
+      val pipeline = new Pipeline().setStages(stages.toArray)
+
+      // Fit the Pipeline
+      val startTime = System.nanoTime()
+      val pipelineModel = pipeline.fit(training)
+      val elapsedTime = (System.nanoTime() - startTime) / 1e9
+      println(s"Training time: $elapsedTime seconds")
+
+      // Get the trained GBT from the fitted PipelineModel
+      algo match {
+        case "classification" =>
+          val gbtModel =
+            pipelineModel.getModel[GBTClassificationModel](gbt.asInstanceOf[GBTClassifier])
+          if (gbtModel.totalNumNodes < 30) {
+            println(gbtModel.toDebugString) // Print full model.
+          } else {
+            println(gbtModel) // Print model summary.
+          }
+        case "regression" =>
+          val gbtModel =
+            pipelineModel.getModel[GBTRegressionModel](gbt.asInstanceOf[GBTRegressor])
+          if (gbtModel.totalNumNodes < 30) {
+            println(gbtModel.toDebugString) // Print full model.
+          } else {
+            println(gbtModel) // Print model summary.
+          }
+        case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+      }
+
+      // Evaluate model on training, test data
+      algo match {
+        case "classification" =>
+          println("Training data results:")
+          DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+          println("Test data results:")
+          DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+        case "regression" =>
+          println("Training data results:")
+          DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+          println("Test data results:")
+          DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+        case _ =>
+          throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+      }
+    } finally {
+      sc.stop()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
new file mode 100644
index 0000000..9b90932
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for random forests (classification or regression) built with the
+ * spark.ml Pipelines API. Run with
+ * {{{
+ * ./bin/run-example ml.RandomForestExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory.  If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
+ *   [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object RandomForestExample {
+
+  /**
+   * Command-line options for this example.
+   * `input` is filled in from the required positional `<input>` argument; every other field
+   * keeps the default shown here unless the matching option is passed.
+   */
+  case class Params(
+      input: String = null,
+      testInput: String = "",
+      dataFormat: String = "libsvm",
+      algo: String = "classification",
+      maxDepth: Int = 5,
+      maxBins: Int = 32,
+      minInstancesPerNode: Int = 1,
+      minInfoGain: Double = 0.0,
+      numTrees: Int = 10,
+      featureSubsetStrategy: String = "auto",
+      fracTest: Double = 0.2,
+      cacheNodeIds: Boolean = false,
+      checkpointDir: Option[String] = None,
+      checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+  /**
+   * Parses the command-line arguments into [[Params]] and runs the example.
+   * Exits the JVM with status 1 when the arguments cannot be parsed or fail validation.
+   */
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("RandomForestExample") {
+      head("RandomForestExample: an example random forest app.")
+      opt[String]("algo")
+        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = x))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("maxBins")
+        .text(s"max number of bins, default: ${defaultParams.maxBins}")
+        .action((x, c) => c.copy(maxBins = x))
+      opt[Int]("minInstancesPerNode")
+        .text(s"min number of instances required at child nodes to create the parent split," +
+        s" default: ${defaultParams.minInstancesPerNode}")
+        .action((x, c) => c.copy(minInstancesPerNode = x))
+      opt[Double]("minInfoGain")
+        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+        .action((x, c) => c.copy(minInfoGain = x))
+      opt[Int]("numTrees")
+        .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
+        .action((x, c) => c.copy(numTrees = x))
+      opt[String]("featureSubsetStrategy")
+        .text(s"number of features to use per node (supported:" +
+        s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
+        s" default: ${defaultParams.featureSubsetStrategy}")
+        .action((x, c) => c.copy(featureSubsetStrategy = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+        s"this option is ignored. default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      opt[Boolean]("cacheNodeIds")
+        .text(s"whether to use node Id cache during training, " +
+        s"default: ${defaultParams.cacheNodeIds}")
+        .action((x, c) => c.copy(cacheNodeIds = x))
+      opt[String]("checkpointDir")
+        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+        s"default: ${defaultParams.checkpointDir.getOrElse("None")}")
+        .action((x, c) => c.copy(checkpointDir = Some(x)))
+      opt[Int]("checkpointInterval")
+        .text(s"how often to checkpoint the node Id cache, " +
+        s"default: ${defaultParams.checkpointInterval}")
+        .action((x, c) => c.copy(checkpointInterval = x))
+      opt[String]("testInput")
+        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
+        s" default: ${defaultParams.testInput}")
+        .action((x, c) => c.copy(testInput = x))
+      // A named option (passed as --dataFormat), not a positional argument.
+      opt[String]("dataFormat")
+        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+        .action((x, c) => c.copy(dataFormat = x))
+      arg[String]("<input>")
+        .text("input path to labeled examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.fracTest < 0 || params.fracTest >= 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  /**
+   * Builds and fits a Pipeline ending in a random forest, prints the learned model, and
+   * evaluates it on the training and test sets via the DecisionTreeExample helpers.
+   * @param params  parsed command-line options; `params.algo` is compared case-insensitively
+   * @throws IllegalArgumentException if `params.algo` is neither "classification" nor
+   *                                  "regression"
+   */
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
+    val sc = new SparkContext(conf)
+    params.checkpointDir.foreach(sc.setCheckpointDir)
+    val algo = params.algo.toLowerCase
+
+    println(s"RandomForestExample with parameters:\n$params")
+
+    // Load training and test data and cache it.
+    val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+      params.dataFormat, params.testInput, algo, params.fracTest)
+
+    // Set up Pipeline
+    val stages = new mutable.ArrayBuffer[PipelineStage]()
+    // (1) For classification, re-index classes.
+    val labelColName = if (algo == "classification") "indexedLabel" else "label"
+    if (algo == "classification") {
+      val labelIndexer = new StringIndexer()
+        .setInputCol("labelString")
+        .setOutputCol(labelColName)
+      stages += labelIndexer
+    }
+    // (2) Identify categorical features using VectorIndexer.
+    //     Features with more than maxCategories values will be treated as continuous.
+    val featuresIndexer = new VectorIndexer()
+      .setInputCol("features")
+      .setOutputCol("indexedFeatures")
+      .setMaxCategories(10)
+    stages += featuresIndexer
+    // (3) Learn Random Forest
+    val rf = algo match {
+      case "classification" =>
+        new RandomForestClassifier()
+          .setFeaturesCol("indexedFeatures")
+          .setLabelCol(labelColName)
+          .setMaxDepth(params.maxDepth)
+          .setMaxBins(params.maxBins)
+          .setMinInstancesPerNode(params.minInstancesPerNode)
+          .setMinInfoGain(params.minInfoGain)
+          .setCacheNodeIds(params.cacheNodeIds)
+          .setCheckpointInterval(params.checkpointInterval)
+          .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+          .setNumTrees(params.numTrees)
+      case "regression" =>
+        new RandomForestRegressor()
+          .setFeaturesCol("indexedFeatures")
+          .setLabelCol(labelColName)
+          .setMaxDepth(params.maxDepth)
+          .setMaxBins(params.maxBins)
+          .setMinInstancesPerNode(params.minInstancesPerNode)
+          .setMinInfoGain(params.minInfoGain)
+          .setCacheNodeIds(params.cacheNodeIds)
+          .setCheckpointInterval(params.checkpointInterval)
+          .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+          .setNumTrees(params.numTrees)
+      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+    }
+    stages += rf
+    val pipeline = new Pipeline().setStages(stages.toArray)
+
+    // Fit the Pipeline
+    val startTime = System.nanoTime()
+    val pipelineModel = pipeline.fit(training)
+    val elapsedTime = (System.nanoTime() - startTime) / 1e9
+    println(s"Training time: $elapsedTime seconds")
+
+    // Get the trained Random Forest from the fitted PipelineModel
+    algo match {
+      case "classification" =>
+        val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
+          rf.asInstanceOf[RandomForestClassifier])
+        if (rfModel.totalNumNodes < 30) {
+          println(rfModel.toDebugString) // Print full model.
+        } else {
+          println(rfModel) // Print model summary.
+        }
+      case "regression" =>
+        val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
+          rf.asInstanceOf[RandomForestRegressor])
+        if (rfModel.totalNumNodes < 30) {
+          println(rfModel.toDebugString) // Print full model.
+        } else {
+          println(rfModel) // Print model summary.
+        }
+      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+    }
+
+    // Evaluate model on training, test data
+    algo match {
+      case "classification" =>
+        println("Training data results:")
+        DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+        println("Test data results:")
+        DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+      case "regression" =>
+        println("Training data results:")
+        DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+        println("Test data results:")
+        DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+      case _ =>
+        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+    }
+
+    sc.stop()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 431ead8..0763a77 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
 import org.apache.spark.util.Utils
 
+
 /**
  * An example runner for Gradient Boosting using decision trees as weak learners. Run with
  * {{{

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/Model.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index cae5082..a491bc7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap
 abstract class Model[M <: Model[M]] extends Transformer {
   /**
    * The parent estimator that produced this model.
+   * Note: For the component Models of an ensemble (e.g., the individual trees inside a
+   * random forest or GBT model), this value can be null.
    */
   val parent: Estimator[M]
 
   /**
    * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
+   * Note: For the component Models of an ensemble (e.g., the individual trees inside a
+   * random forest or GBT model), this value can be null.
    */
   val fittingParamMap: ParamMap
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3855e39..ee2a8dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -43,8 +43,7 @@ import org.apache.spark.sql.DataFrame
 @AlphaComponent
 final class DecisionTreeClassifier
   extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
-  with DecisionTreeParams
-  with TreeClassifierParams {
+  with DecisionTreeParams with TreeClassifierParams {
 
   // Override parameter setters from parent trait for Java API compatibility.
 
@@ -59,11 +58,9 @@ final class DecisionTreeClassifier
 
   override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
 
-  override def setCacheNodeIds(value: Boolean): this.type =
-    super.setCacheNodeIds(value)
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
-  override def setCheckpointInterval(value: Int): this.type =
-    super.setCheckpointInterval(value)
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
@@ -75,8 +72,9 @@ final class DecisionTreeClassifier
     val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
       case Some(n: Int) => n
       case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
-        s" with invalid label column, without the number of classes specified.")
-        // TODO: Automatically index labels.
+        s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+        " specified. See StringIndexer.")
+        // TODO: Automatically index labels: SPARK-7126
     }
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
@@ -85,18 +83,16 @@ final class DecisionTreeClassifier
   }
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
-  override private[ml] def getOldStrategy(
+  private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): OldStrategy = {
-    val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
-    strategy.algo = OldAlgo.Classification
-    strategy.setImpurity(getOldImpurity)
-    strategy
+    super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
+      subsamplingRate = 1.0)
   }
 }
 
 object DecisionTreeClassifier {
-  /** Accessor for supported impurities */
+  /** Accessor for supported impurities: entropy, gini */
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
new file mode 100644
index 0000000..d2e052f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Param, Params, ParamMap}
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * learning algorithm for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ * Note: Multiclass labels are not currently supported.
+ */
+@AlphaComponent
+final class GBTClassifier
+  extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+  with GBTParams with TreeClassifierParams with Logging {
+
+  // Setters are re-declared from the parent traits for Java API compatibility; each one
+  // simply delegates to the inherited implementation.
+
+  // Parameters for the individual decision trees:
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+  /**
+   * The impurity setting is ignored for GBT models: this only logs a warning and returns
+   * this instance without storing the value.
+   * Individual trees are built using impurity "Variance."
+   */
+  override def setImpurity(value: String): this.type = {
+    logWarning("GBTClassifier.setImpurity should NOT be used")
+    this
+  }
+
+  // Parameters shared by tree ensembles:
+
+  override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+  /** Stores the seed, after logging a warning that Gradient Boosting currently ignores it. */
+  override def setSeed(value: Long): this.type = {
+    logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
+    super.setSeed(value)
+  }
+
+  // Boosting parameters (GBTParams):
+
+  override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+
+  override def setStepSize(value: Double): this.type = super.setStepSize(value)
+
+  // Parameters specific to GBTClassifier:
+
+  /**
+   * Loss function which GBT tries to minimize. (case-insensitive)
+   * Supported: "logistic"
+   * (default = logistic)
+   * @group param
+   */
+  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+    " tries to minimize (case-insensitive). Supported options:" +
+    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+
+  setDefault(lossType -> "logistic")
+
+  /**
+   * Sets the loss type. The value is lower-cased before validation and storage.
+   * @throws IllegalArgumentException if the lower-cased value is not one of
+   *                                  [[GBTClassifier.supportedLossTypes]]
+   * @group setParam
+   */
+  def setLossType(value: String): this.type = {
+    val normalized = value.toLowerCase
+    val supported = GBTClassifier.supportedLossTypes
+    require(supported.contains(normalized), "GBTClassifier was given bad loss" +
+      s" type: $value. Supported options: ${supported.mkString(", ")}")
+    set(lossType, normalized)
+    this
+  }
+
+  /** @group getParam */
+  def getLossType: String = getOrDefault(lossType)
+
+  /** (private[ml]) Maps the configured loss type onto the equivalent old-API loss object. */
+  override private[ml] def getOldLossType: OldLoss = {
+    val loss = getLossType
+    if (loss == "logistic") {
+      OldLogLoss
+    } else {
+      // Not expected to happen: setLossType rejects unsupported values.
+      throw new RuntimeException(s"GBTClassifier was given bad loss type: $loss")
+    }
+  }
+
+  /**
+   * Trains via the old-API GradientBoostedTrees implementation and wraps the result.
+   * @throws IllegalArgumentException if the label column carries no number-of-classes
+   *                                  metadata, or if that number is not 2
+   */
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): GBTClassificationModel = {
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    // TODO: Automatically index labels: SPARK-7126
+    val numClasses: Int =
+      MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))).getOrElse {
+        throw new IllegalArgumentException("GBTClassifier was given input" +
+          s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+          " specified. See StringIndexer.")
+      }
+    require(numClasses == 2,
+      s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
+    val labeledPoints: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    val oldStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+    val oldModel = new OldGBT(oldStrategy).run(labeledPoints)
+    GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+  }
+}
+
+object GBTClassifier {
+  // Keep these entries lower-case: setLossType lower-cases its argument before the lookup.
+  /** Accessor for supported loss settings: logistic */
+  final val supportedLossTypes: Array[String] = Array("logistic").map(name => name.toLowerCase)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * model for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ * Note: Multiclass labels are not currently supported.
+ * @param _trees  Decision trees in the ensemble. When built via `fromOld`, these have
+ *                null parents.
+ * @param _treeWeights  Weights for the decision trees in the ensemble; must have the same
+ *                      length as `_trees`.
+ */
+@AlphaComponent
+final class GBTClassificationModel(
+    override val parent: GBTClassifier,
+    override val fittingParamMap: ParamMap,
+    private val _trees: Array[DecisionTreeRegressionModel],
+    private val _treeWeights: Array[Double])
+  extends PredictionModel[Vector, GBTClassificationModel]
+  with TreeEnsembleModel with Serializable {
+
+  require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+  require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
+    s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+  override def treeWeights: Array[Double] = _treeWeights
+
+  /** Returns 1.0 when the weighted sum of tree outputs is positive, else 0.0. */
+  override protected def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model: SPARK-7127
+    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
+    val perTreeOutputs: Array[Double] = _trees.map(tree => tree.rootNode.predict(features))
+    val margin = blas.ddot(numTrees, perTreeOutputs, 1, _treeWeights, 1)
+    if (margin > 0.0) 1.0 else 0.0
+  }
+
+  override protected def copy(): GBTClassificationModel = {
+    val duplicate = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights)
+    Params.inheritValues(this.extractParamMap(), this, duplicate)
+    duplicate
+  }
+
+  override def toString: String = s"GBTClassificationModel with $numTrees trees"
+
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldGBTModel =
+    new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
+}
+
+private[ml] object GBTClassificationModel {
+
+  /**
+   * (private[ml]) Convert a model from the old API.
+   * @throws IllegalArgumentException if `oldModel` was not trained for classification
+   */
+  def fromOld(
+      oldModel: OldGBTModel,
+      parent: GBTClassifier,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
+    require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
+      s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
+    // Each component tree gets a null parent and fittingParamMap, since there are no good
+    // values to record for them.
+    val convertedTrees: Array[DecisionTreeRegressionModel] = oldModel.trees.map { oldTree =>
+      DecisionTreeRegressionModel.fromOld(oldTree, null, null, categoricalFeatures)
+    }
+    new GBTClassificationModel(parent, fittingParamMap, convertedTrees, oldModel.treeWeights)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
new file mode 100644
index 0000000..cfd6508
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] learning algorithm for
+ * classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class RandomForestClassifier
+  extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  with RandomForestParams with TreeClassifierParams {
+
+  // Setters are re-declared from the parent traits for Java API compatibility; each one
+  // simply delegates to the inherited implementation.
+
+  // Parameters for the individual decision trees:
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+  override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+  // Parameters shared by tree ensembles:
+
+  override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+  override def setSeed(value: Long): this.type = super.setSeed(value)
+
+  // Random-forest parameters (RandomForestParams):
+
+  override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+
+  override def setFeatureSubsetStrategy(value: String): this.type =
+    super.setFeatureSubsetStrategy(value)
+
+  /**
+   * Trains via the old-API RandomForest implementation and wraps the result.
+   * Note: the Long seed is truncated with `toInt` before being handed to the old API.
+   * @throws IllegalArgumentException if the label column carries no number-of-classes metadata
+   */
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): RandomForestClassificationModel = {
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    // TODO: Automatically index labels: SPARK-7126
+    val numClasses: Int =
+      MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))).getOrElse {
+        throw new IllegalArgumentException("RandomForestClassifier was given input" +
+          s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+          " specified. See StringIndexer.")
+      }
+    val labeledPoints: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    val oldStrategy: OldStrategy =
+      super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+    val oldModel = OldRandomForest.trainClassifier(
+      labeledPoints, oldStrategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
+    RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+  }
+}
+
+object RandomForestClassifier {
+  /** Impurity settings this classifier supports: entropy, gini */
+  final val supportedImpurities: Array[String] =
+    TreeClassifierParams.supportedImpurities
+
+  /** featureSubsetStrategy settings this classifier supports: auto, all, onethird, sqrt, log2 */
+  final val supportedFeatureSubsetStrategies: Array[String] =
+    RandomForestParams.supportedFeatureSubsetStrategies
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ * @param _trees  Decision trees in the ensemble.
+ *               Warning: These have null parents.
+ */
+@AlphaComponent
+final class RandomForestClassificationModel private[ml] (
+    override val parent: RandomForestClassifier,
+    override val fittingParamMap: ParamMap,
+    private val _trees: Array[DecisionTreeClassificationModel])
+  extends PredictionModel[Vector, RandomForestClassificationModel]
+  with TreeEnsembleModel with Serializable {
+
+  require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+
+  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+  // Every tree currently has weight 1.0; performance-based weights may be added later.
+  private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+
+  override def treeWeights: Array[Double] = _treeWeights
+
+  /** Majority vote: each tree casts one vote for the class index it predicts. */
+  override protected def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model.  SPARK-7127
+    // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
+    // Tree weights are not consulted because they are all 1.0 for now.
+    val tally = mutable.Map.empty[Int, Double]
+    for (tree <- _trees) {
+      val votedClass = tree.rootNode.predict(features).toInt
+      tally(votedClass) = tally.getOrElse(votedClass, 0.0) + 1.0 // 1.0 = weight
+    }
+    tally.maxBy(_._2)._1
+  }
+
+  override protected def copy(): RandomForestClassificationModel = {
+    val duplicate = new RandomForestClassificationModel(parent, fittingParamMap, _trees)
+    Params.inheritValues(this.extractParamMap(), this, duplicate)
+    duplicate
+  }
+
+  override def toString: String = s"RandomForestClassificationModel with $numTrees trees"
+
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldRandomForestModel =
+    new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
+}
+
+private[ml] object RandomForestClassificationModel {
+
+  /**
+   * (private[ml]) Convert a model from the old API.
+   * @throws IllegalArgumentException if `oldModel` was not trained for classification
+   */
+  def fromOld(
+      oldModel: OldRandomForestModel,
+      parent: RandomForestClassifier,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+    require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
+      s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
+    // Each component tree gets a null parent and fittingParamMap, since there are no good
+    // values to record for them.
+    val convertedTrees: Array[DecisionTreeClassificationModel] = oldModel.trees.map { oldTree =>
+      DecisionTreeClassificationModel.fromOld(oldTree, null, null, categoricalFeatures)
+    }
+    new RandomForestClassificationModel(parent, fittingParamMap, convertedTrees)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index eb2609f..ab6281b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -20,9 +20,12 @@ package org.apache.spark.ml.impl.tree
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.impl.estimator.PredictorParams
 import org.apache.spark.ml.param._
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo,
+  BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
   Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
 
 
 /**
@@ -117,79 +120,68 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   def setMaxDepth(value: Int): this.type = {
     require(value >= 0, s"maxDepth parameter must be >= 0.  Given bad value: $value")
     set(maxDepth, value)
-    this
   }
 
   /** @group getParam */
-  def getMaxDepth: Int = getOrDefault(maxDepth)
+  final def getMaxDepth: Int = getOrDefault(maxDepth)
 
   /** @group setParam */
   def setMaxBins(value: Int): this.type = {
     require(value >= 2, s"maxBins parameter must be >= 2.  Given bad value: $value")
     set(maxBins, value)
-    this
   }
 
   /** @group getParam */
-  def getMaxBins: Int = getOrDefault(maxBins)
+  final def getMaxBins: Int = getOrDefault(maxBins)
 
   /** @group setParam */
   def setMinInstancesPerNode(value: Int): this.type = {
     require(value >= 1, s"minInstancesPerNode parameter must be >= 1.  Given bad value: $value")
     set(minInstancesPerNode, value)
-    this
   }
 
   /** @group getParam */
-  def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+  final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
 
   /** @group setParam */
-  def setMinInfoGain(value: Double): this.type = {
-    set(minInfoGain, value)
-    this
-  }
+  def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
 
   /** @group getParam */
-  def getMinInfoGain: Double = getOrDefault(minInfoGain)
+  final def getMinInfoGain: Double = getOrDefault(minInfoGain)
 
   /** @group expertSetParam */
   def setMaxMemoryInMB(value: Int): this.type = {
     require(value > 0, s"maxMemoryInMB parameter must be > 0.  Given bad value: $value")
     set(maxMemoryInMB, value)
-    this
   }
 
   /** @group expertGetParam */
-  def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+  final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
 
   /** @group expertSetParam */
-  def setCacheNodeIds(value: Boolean): this.type = {
-    set(cacheNodeIds, value)
-    this
-  }
+  def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
 
   /** @group expertGetParam */
-  def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+  final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
 
   /** @group expertSetParam */
   def setCheckpointInterval(value: Int): this.type = {
     require(value >= 1, s"checkpointInterval parameter must be >= 1.  Given bad value: $value")
     set(checkpointInterval, value)
-    this
   }
 
   /** @group expertGetParam */
-  def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+  final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
 
-  /**
-   * Create a Strategy instance to use with the old API.
-   * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
-   *       the default for single trees).
-   */
+  /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
-      numClasses: Int): OldStrategy = {
-    val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity,
+      subsamplingRate: Double): OldStrategy = {
+    val strategy = OldStrategy.defaultStategy(oldAlgo)
+    strategy.impurity = oldImpurity
     strategy.checkpointInterval = getCheckpointInterval
     strategy.maxBins = getMaxBins
     strategy.maxDepth = getMaxDepth
@@ -199,13 +191,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
     strategy.useNodeIdCache = getCacheNodeIds
     strategy.numClasses = numClasses
     strategy.categoricalFeaturesInfo = categoricalFeatures
-    strategy.subsamplingRate = 1.0 // default for individual trees
+    strategy.subsamplingRate = subsamplingRate
     strategy
   }
 }
 
 /**
- * (private trait) Parameters for Decision Tree-based classification algorithms.
+ * Parameters for Decision Tree-based classification algorithms.
  */
 private[ml] trait TreeClassifierParams extends Params {
 
@@ -215,7 +207,7 @@ private[ml] trait TreeClassifierParams extends Params {
    * (default = gini)
    * @group param
    */
-  val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
     " information gain calculation (case-insensitive). Supported options:" +
     s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
 
@@ -228,11 +220,10 @@ private[ml] trait TreeClassifierParams extends Params {
       s"Tree-based classifier was given unrecognized impurity: $value." +
       s"  Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
     set(impurity, impurityStr)
-    this
   }
 
   /** @group getParam */
-  def getImpurity: String = getOrDefault(impurity)
+  final def getImpurity: String = getOrDefault(impurity)
 
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
@@ -249,11 +240,11 @@ private[ml] trait TreeClassifierParams extends Params {
 
 private[ml] object TreeClassifierParams {
   // These options should be lowercase.
-  val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+  final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
 }
 
 /**
- * (private trait) Parameters for Decision Tree-based regression algorithms.
+ * Parameters for Decision Tree-based regression algorithms.
  */
 private[ml] trait TreeRegressorParams extends Params {
 
@@ -263,7 +254,7 @@ private[ml] trait TreeRegressorParams extends Params {
    * (default = variance)
    * @group param
    */
-  val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
     " information gain calculation (case-insensitive). Supported options:" +
     s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
 
@@ -276,11 +267,10 @@ private[ml] trait TreeRegressorParams extends Params {
       s"Tree-based regressor was given unrecognized impurity: $value." +
         s"  Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
     set(impurity, impurityStr)
-    this
   }
 
   /** @group getParam */
-  def getImpurity: String = getOrDefault(impurity)
+  final def getImpurity: String = getOrDefault(impurity)
 
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
@@ -296,5 +286,186 @@ private[ml] trait TreeRegressorParams extends Params {
 
 private[ml] object TreeRegressorParams {
   // These options should be lowercase.
-  val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+  final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters shared by Decision Tree-based ensemble algorithms (random forests, GBTs).
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+
+  /**
+   * Fraction of the training data used for learning each decision tree, in range (0, 1].
+   * (default = 1.0)
+   * @group param
+   */
+  final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
+    "Fraction of the training data used for learning each decision tree.")
+
+  // Must follow the val above: trait initializers run top to bottom.
+  setDefault(subsamplingRate -> 1.0)
+
+  /**
+   * @group setParam
+   * @throws IllegalArgumentException if `value` is not in (0, 1]
+   */
+  def setSubsamplingRate(value: Double): this.type = {
+    val inRange = 0.0 < value && value <= 1.0
+    require(inRange, s"Subsampling rate must be in range (0,1]. Bad rate: $value")
+    set(subsamplingRate, value)
+  }
+
+  /** @group getParam */
+  final def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
+
+  /** @group setParam */
+  def setSeed(value: Long): this.type = set(seed, value)
+
+  /**
+   * (private[ml]) Create a Strategy instance to use with the old API, filling in this
+   * ensemble's subsampling rate.
+   * NOTE: Impurity is supplied by the caller. The seed is not set on the returned Strategy;
+   *       callers that support it must pass it to the old API separately.
+   */
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity): OldStrategy = {
+    val rate = getSubsamplingRate
+    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, rate)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Random Forest algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+  /**
+   * Number of trees to train (>= 1).
+   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
+   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
+   * (default = 20)
+   * @group param
+   */
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)")
+
+  /**
+   * The number of features to consider for splits at each tree node.
+   * Supported options:
+   *  - "auto": Choose automatically for task:
+   *            If numTrees == 1, set to "all".
+   *            If numTrees > 1 (forest), set to "sqrt" for classification and
+   *              to "onethird" for regression.
+   *  - "all": use all features
+   *  - "onethird": use 1/3 of the features
+   *  - "sqrt": use sqrt(number of features)
+   *  - "log2": use log2(number of features)
+   * (default = "auto")
+   *
+   * These various settings are based on the following references:
+   *  - log2: tested in Breiman (2001)
+   *  - sqrt: recommended by Breiman manual for random forests
+   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+   *    package.
+   * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf  Breiman (2001)]]
+   * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf  Breiman manual for
+   *     random forests]]
+   *
+   * @group param
+   */
+  final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+    "The number of features to consider for splits at each tree node." +
+      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+
+  // Defaults must follow both param vals, since trait initializers run top to bottom.
+  setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+
+  /**
+   * @group setParam
+   * @throws IllegalArgumentException if `value` < 1
+   */
+  def setNumTrees(value: Int): this.type = {
+    require(value > 0, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.")
+    set(numTrees, value)
+  }
+
+  /** @group getParam */
+  final def getNumTrees: Int = getOrDefault(numTrees)
+
+  /**
+   * Sets the strategy after lowercasing it, so matching is case-insensitive.
+   * @group setParam
+   * @throws IllegalArgumentException if the lowercased value is not a supported strategy
+   */
+  def setFeatureSubsetStrategy(value: String): this.type = {
+    val supported = RandomForestParams.supportedFeatureSubsetStrategies
+    val normalized = value.toLowerCase
+    require(supported.contains(normalized),
+      s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" +
+        s" options: ${supported.mkString(", ")}")
+    set(featureSubsetStrategy, normalized)
+  }
+
+  /** @group getParam */
+  final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy)
+}
+
+/** Constants backing the validation in [[RandomForestParams]]. */
+private[ml] object RandomForestParams {
+  /**
+   * Allowed values for `featureSubsetStrategy`. Entries must be lowercase because the setter
+   * lowercases its input before checking membership.
+   */
+  final val supportedFeatureSubsetStrategies: Array[String] = {
+    val strategies = Array("auto", "all", "onethird", "sqrt", "log2")
+    strategies.map(_.toLowerCase)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Gradient-Boosted Tree algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
+
+  /**
+   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+   * estimator.
+   * (default = 0.1)
+   * @group param
+   */
+  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
+    " learning rate) in interval (0, 1] for shrinking the contribution of each estimator")
+
+  /* TODO: Add this doc when we add this param.  SPARK-7132
+   * Threshold for stopping early when runWithValidation is used.
+   * If the error rate on the validation input changes by less than the validationTol,
+   * then learning will stop early (before [[maxIter]] iterations).
+   * This parameter is ignored when run is used.
+   * (default = 1e-5)
+   * @group param
+   */
+  // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
+  // validationTol -> 1e-5
+
+  // Defaults must follow the stepSize val, since trait initializers run top to bottom.
+  setDefault(maxIter -> 20, stepSize -> 0.1)
+
+  /**
+   * @group setParam
+   * @throws IllegalArgumentException if `value` < 1
+   */
+  def setMaxIter(value: Int): this.type = {
+    require(value > 0, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.")
+    set(maxIter, value)
+  }
+
+  /**
+   * @group setParam
+   * @throws IllegalArgumentException if `value` is not in (0, 1]
+   */
+  def setStepSize(value: Double): this.type = {
+    val inRange = 0.0 < value && value <= 1.0
+    require(inRange, s"GBT given invalid step size ($value).  Value should be in (0,1].")
+    set(stepSize, value)
+  }
+
+  /** @group getParam */
+  final def getStepSize: Double = getOrDefault(stepSize)
+
+  /**
+   * (private[ml]) Create a BoostingStrategy instance to use with the old API.
+   * The per-tree Strategy always uses variance impurity and numClasses = 2; the loss,
+   * iteration count, and step size come from this instance's params.
+   */
+  private[ml] def getOldBoostingStrategy(
+      categoricalFeatures: Map[Int, Int],
+      oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
+    val treeStrategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo,
+      OldVariance)
+    // NOTE: The old API does not support "seed" so we ignore it.
+    new OldBoostingStrategy(treeStrategy, getOldLossType, getMaxIter, getStepSize)
+  }
+
+  /** (private[ml]) Old-API Gradient Boosting loss; provided by each concrete GBT algorithm. */
+  private[ml] def getOldLossType: OldLoss
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 95d7e64..e88c487 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -45,7 +45,8 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name"),
       ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
-      ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")))
+      ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
+      ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")))
 
     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
@@ -154,6 +155,7 @@ private[shared] object SharedParamsCodeGen {
         |
         |import org.apache.spark.annotation.DeveloperApi
         |import org.apache.spark.ml.param._
+        |import org.apache.spark.util.Utils
         |
         |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
         |

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 72b08bf..a860b88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.param.shared
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param._
+import org.apache.spark.util.Utils
 
 // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
 
@@ -256,4 +257,23 @@ trait HasFitIntercept extends Params {
   /** @group getParam */
   final def getFitIntercept: Boolean = getOrDefault(fitIntercept)
 }
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param seed (default: Utils.random.nextLong()).
+ */
+@DeveloperApi
+trait HasSeed extends Params {
+
+  /**
+   * Param for random seed.
+   * @group param
+   */
+  final val seed: LongParam = new LongParam(this, "seed", "random seed")
+
+  setDefault(seed, Utils.random.nextLong())
+
+  /** @group getParam */
+  final def getSeed: Long = getOrDefault(seed)
+}
 // scalastyle:on

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 49a8b77..756725a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -42,8 +42,7 @@ import org.apache.spark.sql.DataFrame
 @AlphaComponent
 final class DecisionTreeRegressor
   extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
-  with DecisionTreeParams
-  with TreeRegressorParams {
+  with DecisionTreeParams with TreeRegressorParams {
 
   // Override parameter setters from parent trait for Java API compatibility.
 
@@ -60,8 +59,7 @@ final class DecisionTreeRegressor
 
   override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
-  override def setCheckpointInterval(value: Int): this.type =
-    super.setCheckpointInterval(value)
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
@@ -78,15 +76,13 @@ final class DecisionTreeRegressor
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
-    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
-    strategy.algo = OldAlgo.Regression
-    strategy.setImpurity(getOldImpurity)
-    strategy
+    super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
+      subsamplingRate = 1.0)
   }
 }
 
 object DecisionTreeRegressor {
-  /** Accessor for supported impurities */
+  /** Accessor for supported impurities: variance */
   final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a7160c4e/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
new file mode 100644
index 0000000..c784cf3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap, Param}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
+  SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class GBTRegressor
+  extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
+  with GBTParams with TreeRegressorParams with Logging {
+
+  // Override parameter setters from parent trait for Java API compatibility.
+
+  // Parameters from DecisionTreeParams (inherited via GBTParams):
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+  // Parameters from TreeRegressorParams:
+
+  /**
+   * The impurity setting is ignored for GBT models.
+   * Individual trees are built using impurity "Variance."
+   * This only logs a warning; the impurity param is left unchanged.
+   */
+  override def setImpurity(value: String): this.type = {
+    logWarning("GBTRegressor.setImpurity should NOT be used")
+    this
+  }
+
+  // Parameters from TreeEnsembleParams:
+
+  override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+  /** Stores the seed, but warns that Gradient Boosting currently does not use it. */
+  override def setSeed(value: Long): this.type = {
+    logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
+    super.setSeed(value)
+  }
+
+  // Parameters from GBTParams:
+
+  override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+
+  override def setStepSize(value: Double): this.type = super.setStepSize(value)
+
+  // Parameters for GBTRegressor:
+
+  /**
+   * Loss function which GBT tries to minimize. (case-insensitive)
+   * Supported: "squared" (L2) and "absolute" (L1)
+   * (default = squared)
+   * @group param
+   */
+  final val lossType: Param[String] = new Param[String](this, "lossType",
+    "Loss function which GBT tries to minimize (case-insensitive). Supported options:" +
+      s" ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+
+  setDefault(lossType -> "squared")
+
+  /**
+   * Sets the loss type after lowercasing it, so matching is case-insensitive.
+   * @group setParam
+   * @throws IllegalArgumentException if the lowercased value is not a supported loss type
+   */
+  def setLossType(value: String): this.type = {
+    val lossStr = value.toLowerCase
+    require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" +
+      s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+    set(lossType, lossStr)
+  }
+
+  /** @group getParam */
+  final def getLossType: String = getOrDefault(lossType)
+
+  /** (private[ml]) Convert new loss to old loss. */
+  override private[ml] def getOldLossType: OldLoss = getLossType match {
+    case "squared" => OldSquaredError
+    case "absolute" => OldAbsoluteError
+    case other =>
+      // Not expected: setLossType validates its input. This is only reachable if lossType
+      // was set some other way.
+      throw new RuntimeException(s"GBTRegressor was given bad loss type: $other")
+  }
+
+  /**
+   * Trains by delegating to the old-API GradientBoostedTrees implementation and converting
+   * the resulting model to the new API.
+   * NOTE(review): tree and boosting settings are read from this estimator's own params, not
+   * from `paramMap`; only featuresCol (and label extraction) use `paramMap`. Confirm intended.
+   */
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): GBTRegressionModel = {
+    // Categorical feature arities come from the features column's metadata.
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+    val oldGBT = new OldGBT(boostingStrategy)
+    val oldModel = oldGBT.run(oldDataset)
+    GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+  }
+}
+
+/** Companion exposing the loss types accepted by `GBTRegressor.setLossType`. */
+object GBTRegressor {
+  /**
+   * Accessor for supported loss settings: squared (L2), absolute (L1).
+   * Entries must be lowercase because the setter lowercases its input before checking.
+   */
+  final val supportedLossTypes: Array[String] = {
+    val losses = Array("squared", "absolute")
+    losses.map(_.toLowerCase)
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * model for regression.
+ * It supports both continuous and categorical features.
+ * @param _trees  Decision trees in the ensemble.
+ * @param _treeWeights  Weights for the decision trees in the ensemble.
+ */
+@AlphaComponent
+final class GBTRegressionModel(
+    override val parent: GBTRegressor,
+    override val fittingParamMap: ParamMap,
+    private val _trees: Array[DecisionTreeRegressionModel],
+    private val _treeWeights: Array[Double])
+  extends PredictionModel[Vector, GBTRegressionModel]
+  with TreeEnsembleModel with Serializable {
+
+  require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+  require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
+    s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+  // Cast is safe at runtime: JVM arrays are covariant.
+  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+  override def treeWeights: Array[Double] = _treeWeights
+
+  /**
+   * Predicts a real-valued label as the weighted sum of the individual tree predictions.
+   * The sum must NOT be thresholded as in a GBT classifier: doing so would collapse every
+   * regression prediction to 0.0 or 1.0.
+   */
+  override protected def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model. SPARK-7127
+    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
+    val treePredictions = _trees.map(_.rootNode.predict(features))
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
+  /**
+   * Creates a new model sharing this model's parent, fitting param map, trees, and weights,
+   * then uses Params.inheritValues to carry this model's extracted param values over to it.
+   */
+  override protected def copy(): GBTRegressionModel = {
+    val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
+  }
+
+  override def toString: String = {
+    s"GBTRegressionModel with $numTrees trees"
+  }
+
+  /** (private[ml]) Convert to a model in the old API, tagged with algo = Regression. */
+  private[ml] def toOld: OldGBTModel = {
+    new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
+  }
+}
+
+/** Conversion helpers between old-API (mllib) and new-API (ml) GBT regression models. */
+private[ml] object GBTRegressionModel {
+
+  /**
+   * (private[ml]) Convert a model from the old API.
+   *
+   * @param oldModel  old-API GBT model; must have algo = Regression
+   * @param parent  estimator recorded as the new model's parent
+   * @param fittingParamMap  param map recorded on the new model
+   * @param categoricalFeatures  categorical feature info passed to every converted tree
+   * @throws IllegalArgumentException if `oldModel` is not a regression model
+   */
+  def fromOld(
+      oldModel: OldGBTModel,
+      parent: GBTRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+    val isRegression = oldModel.algo == OldAlgo.Regression
+    require(isRegression, "Cannot convert GradientBoostedTreesModel" +
+      s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
+    // Individual trees get a null parent and fittingParamMap: there is no meaningful
+    // per-tree estimator or param map to attach.
+    val convertedTrees = oldModel.trees.map { oldTree =>
+      DecisionTreeRegressionModel.fromOld(oldTree, null, null, categoricalFeatures)
+    }
+    new GBTRegressionModel(parent, fittingParamMap, convertedTrees, oldModel.treeWeights)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org